Implement SSO/OIDC support. (closes #7)
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -110,6 +110,7 @@ ruma.workspace = true
|
||||
rustls.workspace = true
|
||||
rustyline-async.workspace = true
|
||||
rustyline-async.optional = true
|
||||
serde_html_form.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
serde_yaml.workspace = true
|
||||
|
||||
@@ -21,6 +21,7 @@ pub struct Service {
|
||||
pub sender: ClientLazylock,
|
||||
pub appservice: ClientLazylock,
|
||||
pub pusher: ClientLazylock,
|
||||
pub oauth: ClientLazylock,
|
||||
|
||||
pub cidr_range_denylist: Vec<IPAddress>,
|
||||
}
|
||||
@@ -112,6 +113,11 @@ impl crate::Service for Service {
|
||||
.pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout))
|
||||
.redirect(redirect::Policy::limited(2))),
|
||||
|
||||
oauth: create_client!(config, services; base(config)?
|
||||
.dns_resolver2(Arc::clone(&services.resolver.resolver))
|
||||
.redirect(redirect::Policy::limited(0))
|
||||
.pool_max_idle_per_host(1)),
|
||||
|
||||
cidr_range_denylist: config
|
||||
.ip_range_denylist
|
||||
.iter()
|
||||
|
||||
@@ -19,6 +19,7 @@ pub mod globals;
|
||||
pub mod key_backups;
|
||||
pub mod media;
|
||||
pub mod membership;
|
||||
pub mod oauth;
|
||||
pub mod presence;
|
||||
pub mod pusher;
|
||||
pub mod resolver;
|
||||
|
||||
265
src/service/oauth/mod.rs
Normal file
265
src/service/oauth/mod.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
pub mod providers;
|
||||
pub mod sessions;
|
||||
pub mod user_info;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
|
||||
use reqwest::{Method, header::ACCEPT};
|
||||
use ruma::UserId;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tuwunel_core::{
|
||||
Err, Result, err, implement,
|
||||
utils::{hash::sha256, result::LogErr},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
pub use self::{
|
||||
providers::Provider,
|
||||
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session},
|
||||
user_info::UserInfo,
|
||||
};
|
||||
use self::{providers::Providers, sessions::Sessions};
|
||||
use crate::SelfServices;
|
||||
|
||||
pub struct Service {
|
||||
services: SelfServices,
|
||||
pub providers: Arc<Providers>,
|
||||
pub sessions: Arc<Sessions>,
|
||||
}
|
||||
|
||||
impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let providers = Arc::new(Providers::build(args));
|
||||
let sessions = Arc::new(Sessions::build(args, providers.clone()));
|
||||
Ok(Arc::new(Self {
|
||||
services: args.services.clone(),
|
||||
sessions,
|
||||
providers,
|
||||
}))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn request_userinfo(
|
||||
&self,
|
||||
(provider, session): (&Provider, &Session),
|
||||
) -> Result<UserInfo> {
|
||||
let url = provider
|
||||
.userinfo_url
|
||||
.clone()
|
||||
.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::GET, url)
|
||||
.await
|
||||
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||
.log_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn request_tokeninfo(
|
||||
&self,
|
||||
(provider, session): (&Provider, &Session),
|
||||
) -> Result<UserInfo> {
|
||||
let url = provider
|
||||
.introspection_url
|
||||
.clone()
|
||||
.ok_or_else(|| {
|
||||
err!(Config("introspection_url", "Missing introspection URL in config"))
|
||||
})?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::GET, url)
|
||||
.await
|
||||
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||
.log_err()
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) -> Result {
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RevokeQuery<'a> {
|
||||
client_id: &'a str,
|
||||
client_secret: &'a str,
|
||||
}
|
||||
|
||||
let query = RevokeQuery {
|
||||
client_id: &provider.client_id,
|
||||
client_secret: &provider.client_secret,
|
||||
};
|
||||
|
||||
let query = serde_html_form::to_string(&query)?;
|
||||
let url = provider
|
||||
.revocation_url
|
||||
.clone()
|
||||
.map(|mut url| {
|
||||
url.set_query(Some(&query));
|
||||
url
|
||||
})
|
||||
.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::POST, url)
|
||||
.await
|
||||
.log_err()
|
||||
.map(|_| ())
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip_all, ret)]
|
||||
pub async fn request_token(
|
||||
&self,
|
||||
(provider, session): (&Provider, &Session),
|
||||
code: &str,
|
||||
) -> Result<Session> {
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenQuery<'a> {
|
||||
client_id: &'a str,
|
||||
client_secret: &'a str,
|
||||
grant_type: &'a str,
|
||||
code: &'a str,
|
||||
code_verifier: Option<&'a str>,
|
||||
redirect_uri: Option<&'a str>,
|
||||
}
|
||||
|
||||
let query = TokenQuery {
|
||||
client_id: &provider.client_id,
|
||||
client_secret: &provider.client_secret,
|
||||
grant_type: "authorization_code",
|
||||
code,
|
||||
code_verifier: session.code_verifier.as_deref(),
|
||||
redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
|
||||
};
|
||||
|
||||
let query = serde_html_form::to_string(&query)?;
|
||||
let url = provider
|
||||
.token_url
|
||||
.clone()
|
||||
.map(|mut url| {
|
||||
url.set_query(Some(&query));
|
||||
url
|
||||
})
|
||||
.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
|
||||
|
||||
self.request((Some(provider), Some(session)), Method::POST, url)
|
||||
.await
|
||||
.and_then(|value| serde_json::from_value(value).map_err(Into::into))
|
||||
.log_err()
|
||||
}
|
||||
|
||||
/// Send a request to a provider; this is somewhat abstract since URL's are
|
||||
/// formed prior to this call and could point at anything, however this function
|
||||
/// uses the oauth-specific http client and is configured for JSON with special
|
||||
/// casing for an `error` property in the response.
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(
|
||||
name = "request",
|
||||
level = "debug",
|
||||
ret(level = "trace"),
|
||||
skip(self)
|
||||
)]
|
||||
pub async fn request(
|
||||
&self,
|
||||
(provider, session): (Option<&Provider>, Option<&Session>),
|
||||
method: Method,
|
||||
url: Url,
|
||||
) -> Result<JsonValue> {
|
||||
let mut request = self
|
||||
.services
|
||||
.client
|
||||
.oauth
|
||||
.request(method, url)
|
||||
.header(ACCEPT, "application/json");
|
||||
|
||||
if let Some(session) = session {
|
||||
if let Some(access_token) = session.access_token.clone() {
|
||||
request = request.bearer_auth(access_token);
|
||||
}
|
||||
}
|
||||
|
||||
let response: JsonValue = request
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if let Some(response) = response.as_object().as_ref()
|
||||
&& let Some(error) = response.get("error").and_then(JsonValue::as_str)
|
||||
{
|
||||
let description = response
|
||||
.get("error_description")
|
||||
.and_then(JsonValue::as_str)
|
||||
.unwrap_or("(no description)");
|
||||
|
||||
return Err!(Request(Forbidden("Error from provider: {error}: {description}",)));
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn get_user(&self, user_id: &UserId) -> Result<(Provider, Session)> {
|
||||
let session = self.sessions.get_by_user(user_id).await?;
|
||||
let provider = self.sessions.provider(&session).await?;
|
||||
|
||||
Ok((provider, session))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unique_id((provider, session): (&Provider, &Session)) -> Result<String> {
|
||||
unique_id_parts((provider, session)).and_then(unique_id_iss_sub)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unique_id_sub((provider, sub): (&Provider, &str)) -> Result<String> {
|
||||
unique_id_sub_parts((provider, sub)).and_then(unique_id_iss_sub)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unique_id_iss((iss, session): (&str, &Session)) -> Result<String> {
|
||||
unique_id_iss_parts((iss, session)).and_then(unique_id_iss_sub)
|
||||
}
|
||||
|
||||
pub fn unique_id_iss_sub((iss, sub): (&str, &str)) -> Result<String> {
|
||||
let hash = sha256::delimited([iss, sub].iter());
|
||||
let b64 = b64encode.encode(hash);
|
||||
|
||||
Ok(b64)
|
||||
}
|
||||
|
||||
fn unique_id_parts<'a>(
|
||||
(provider, session): (&'a Provider, &'a Session),
|
||||
) -> Result<(&'a str, &'a str)> {
|
||||
provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.map(Url::as_str)
|
||||
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
|
||||
.and_then(|iss| unique_id_iss_parts((iss, session)))
|
||||
}
|
||||
|
||||
fn unique_id_sub_parts<'a>(
|
||||
(provider, sub): (&'a Provider, &'a str),
|
||||
) -> Result<(&'a str, &'a str)> {
|
||||
provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.map(Url::as_str)
|
||||
.ok_or_else(|| err!(Config("issuer_url", "issuer_url not found for this provider.")))
|
||||
.map(|iss| (iss, sub))
|
||||
}
|
||||
|
||||
fn unique_id_iss_parts<'a>((iss, session): (&'a str, &'a Session)) -> Result<(&'a str, &'a str)> {
|
||||
session
|
||||
.user_info
|
||||
.as_ref()
|
||||
.map(|user_info| user_info.sub.as_str())
|
||||
.ok_or_else(|| err!(Request(NotFound("user_info not found for this session."))))
|
||||
.map(|sub| (iss, sub))
|
||||
}
|
||||
246
src/service/oauth/providers.rs
Normal file
246
src/service/oauth/providers.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde_json::{Map as JsonObject, Value as JsonValue};
|
||||
use tokio::sync::RwLock;
|
||||
pub use tuwunel_core::config::IdentityProvider as Provider;
|
||||
use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
|
||||
use url::Url;
|
||||
|
||||
use crate::SelfServices;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Providers {
|
||||
services: SelfServices,
|
||||
providers: RwLock<BTreeMap<String, Provider>>,
|
||||
}
|
||||
|
||||
#[implement(Providers)]
|
||||
pub(super) fn build(args: &crate::Args<'_>) -> Self {
|
||||
Self {
|
||||
services: args.services.clone(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the Provider configuration after any discovery and adjustments
|
||||
/// made on top of the admin's configuration. This incurs network-based
|
||||
/// discovery on the first call but responds from cache on subsequent calls.
|
||||
#[implement(Providers)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn get(&self, id: &str) -> Result<Provider> {
|
||||
if let Some(provider) = self.providers.read().await.get(id).cloned() {
|
||||
return Ok(provider);
|
||||
}
|
||||
|
||||
let config = self.get_config(id)?;
|
||||
let mut map = self.providers.write().await;
|
||||
let config = self.configure(config).await?;
|
||||
|
||||
debug!(?id, ?config);
|
||||
_ = map.insert(id.into(), config.clone());
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Get the admin-configured Provider which exists prior to any
|
||||
/// reconciliation with the well-known discovery (the server's config is
|
||||
/// immutable); though it is important to note the server config can be
|
||||
/// reloaded. This will Err NotFound for a non-existent idp.
|
||||
#[implement(Providers)]
|
||||
pub fn get_config(&self, id: &str) -> Result<Provider> {
|
||||
self.services
|
||||
.config
|
||||
.identity_provider
|
||||
.iter()
|
||||
.find(|config| config.id() == id)
|
||||
.cloned()
|
||||
.ok_or_else(|| err!(Request(NotFound("Unrecognized Identity Provider"))))
|
||||
}
|
||||
|
||||
/// Configure an identity provider; takes the admin-configured instance from the
|
||||
/// server's config, queries the provider for discovery, and then returns an
|
||||
/// updated config based on the proper reconciliation. This final config is then
|
||||
/// cached in memory to avoid repeating this process.
|
||||
#[implement(Providers)]
|
||||
#[tracing::instrument(
|
||||
level = INFO_SPAN_LEVEL,
|
||||
ret(level = "debug"),
|
||||
skip(self),
|
||||
)]
|
||||
async fn configure(&self, mut provider: Provider) -> Result<Provider> {
|
||||
_ = provider
|
||||
.name
|
||||
.get_or_insert_with(|| provider.brand.clone());
|
||||
|
||||
if provider.issuer_url.is_none() {
|
||||
_ = provider
|
||||
.issuer_url
|
||||
.replace(match provider.brand.as_str() {
|
||||
| "github" => "https://github.com".try_into()?,
|
||||
| "gitlab" => "https://gitlab.com".try_into()?,
|
||||
| "google" => "https://accounts.google.com".try_into()?,
|
||||
| _ => return Err!(Config("issuer_url", "Required for this provider.")),
|
||||
});
|
||||
}
|
||||
|
||||
if provider.base_path.is_none() {
|
||||
provider.base_path = match provider.brand.as_str() {
|
||||
| "github" => Some("/login/oauth".to_owned()),
|
||||
| _ => None,
|
||||
};
|
||||
}
|
||||
|
||||
let response = self
|
||||
.discover(&provider)
|
||||
.await
|
||||
.and_then(|response| {
|
||||
response.as_object().cloned().ok_or_else(|| {
|
||||
err!(Request(NotJson("Expecting JSON object for discovery response")))
|
||||
})
|
||||
})
|
||||
.and_then(|response| check_issuer(response, &provider))?;
|
||||
|
||||
if provider.authorization_url.is_none() {
|
||||
response
|
||||
.get("authorization_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| make_url(&provider, "/authorize").ok())
|
||||
.map(|url| provider.authorization_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.revocation_url.is_none() {
|
||||
response
|
||||
.get("revocation_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| make_url(&provider, "/revocation").ok())
|
||||
.map(|url| provider.revocation_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.introspection_url.is_none() {
|
||||
response
|
||||
.get("introspection_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| make_url(&provider, "/introspection").ok())
|
||||
.map(|url| provider.introspection_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.userinfo_url.is_none() {
|
||||
response
|
||||
.get("userinfo_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| match provider.brand.as_str() {
|
||||
| "github" => "https://api.github.com/user".try_into().ok(),
|
||||
| _ => make_url(&provider, "/userinfo").ok(),
|
||||
})
|
||||
.map(|url| provider.userinfo_url.replace(url));
|
||||
}
|
||||
|
||||
if provider.token_url.is_none() {
|
||||
response
|
||||
.get("token_endpoint")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(Url::parse)
|
||||
.transpose()?
|
||||
.or_else(|| {
|
||||
let path = if provider.brand == "github" {
|
||||
"/access_token"
|
||||
} else {
|
||||
"/token"
|
||||
};
|
||||
|
||||
make_url(&provider, path).ok()
|
||||
})
|
||||
.map(|url| provider.token_url.replace(url));
|
||||
}
|
||||
|
||||
Ok(provider)
|
||||
}
|
||||
|
||||
#[implement(Providers)]
|
||||
#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
|
||||
pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
|
||||
self.services
|
||||
.client
|
||||
.oauth
|
||||
.get(discovery_url(provider)?)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn discovery_url(provider: &Provider) -> Result<Url> {
|
||||
let default_url = provider
|
||||
.discovery
|
||||
.then(|| make_url(provider, "/.well-known/openid-configuration"))
|
||||
.transpose()?;
|
||||
|
||||
let Some(url) = provider
|
||||
.discovery_url
|
||||
.clone()
|
||||
.filter(|_| provider.discovery)
|
||||
.or(default_url)
|
||||
else {
|
||||
return Err!(Config(
|
||||
"discovery_url",
|
||||
"Failed to determine URL for discovery of provider {}",
|
||||
provider.id()
|
||||
));
|
||||
};
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
/// Validate that the locally configured `issuer_url` matches the issuer claimed
|
||||
/// in any response. todo: cryptographic validation is not yet implemented here.
|
||||
fn check_issuer(
|
||||
response: JsonObject<String, JsonValue>,
|
||||
provider: &Provider,
|
||||
) -> Result<JsonObject<String, JsonValue>> {
|
||||
let expected = provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.map(Url::as_str)
|
||||
.map(|url| url.trim_end_matches('/'));
|
||||
|
||||
let responded = response
|
||||
.get("issuer")
|
||||
.and_then(JsonValue::as_str)
|
||||
.map(|url| url.trim_end_matches('/'));
|
||||
|
||||
if expected != responded {
|
||||
return Err!(Request(Unauthorized(
|
||||
"Configured issuer_url {expected:?} does not match discovered {responded:?}",
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Generate a full URL for a request to the idp based on the idp's derived
|
||||
/// configuration.
|
||||
fn make_url(provider: &Provider, path: &str) -> Result<Url> {
|
||||
let mut suffix = provider.base_path.clone().unwrap_or_default();
|
||||
|
||||
suffix.push_str(path);
|
||||
let url = provider
|
||||
.issuer_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
let id = &provider.client_id;
|
||||
err!(Config("issuer_url", "Provider {id:?} required field"))
|
||||
})?
|
||||
.join(&suffix)?;
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
234
src/service/oauth/sessions.rs
Normal file
234
src/service/oauth/sessions.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
use std::{sync::Arc, time::SystemTime};
|
||||
|
||||
use futures::{Stream, StreamExt, TryFutureExt};
|
||||
use ruma::{OwnedUserId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{Err, Result, at, implement, utils::stream::TryExpect};
|
||||
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||
use url::Url;
|
||||
|
||||
use super::{Provider, Providers, UserInfo, unique_id};
|
||||
use crate::SelfServices;
|
||||
|
||||
pub struct Sessions {
|
||||
_services: SelfServices,
|
||||
providers: Arc<Providers>,
|
||||
db: Data,
|
||||
}
|
||||
|
||||
struct Data {
|
||||
oauthid_session: Arc<Map>,
|
||||
oauthuniqid_oauthid: Arc<Map>,
|
||||
userid_oauthid: Arc<Map>,
|
||||
}
|
||||
|
||||
/// Session ultimately represents an OAuth authorization session yielding an
|
||||
/// associated matrix user registration.
|
||||
///
|
||||
/// Mixed-use structure capable of deserializing response values, maintaining
|
||||
/// the state between authorization steps, and maintaining the association to
|
||||
/// the matrix user until deactivation or revocation.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Session {
|
||||
/// Identity Provider ID (the `client_id` in the configuration) associated
|
||||
/// with this session.
|
||||
pub idp_id: Option<String>,
|
||||
|
||||
/// Session ID used as the index key for this session itself.
|
||||
pub sess_id: Option<String>,
|
||||
|
||||
/// Token type (bearer, mac, etc).
|
||||
pub token_type: Option<String>,
|
||||
|
||||
/// Access token to the provider.
|
||||
pub access_token: Option<String>,
|
||||
|
||||
/// Duration in seconds the access_token is valid for.
|
||||
pub expires_in: Option<u64>,
|
||||
|
||||
/// Point in time that the access_token expires.
|
||||
pub expires_at: Option<SystemTime>,
|
||||
|
||||
/// Token used to refresh the access_token.
|
||||
pub refresh_token: Option<String>,
|
||||
|
||||
/// Duration in seconds the refresh_token is valid for
|
||||
pub refresh_token_expires_in: Option<u64>,
|
||||
|
||||
/// Point in time that the refresh_token expires.
|
||||
pub refresh_token_expires_at: Option<SystemTime>,
|
||||
|
||||
/// Access scope actually granted (if supported).
|
||||
pub scope: Option<String>,
|
||||
|
||||
/// Redirect URL
|
||||
pub redirect_url: Option<Url>,
|
||||
|
||||
/// Challenge preimage
|
||||
pub code_verifier: Option<String>,
|
||||
|
||||
/// Random string passed exclusively in the grant session cookie.
|
||||
pub cookie_nonce: Option<String>,
|
||||
|
||||
/// Random single-use string passed in the provider redirect.
|
||||
pub query_nonce: Option<String>,
|
||||
|
||||
/// Point in time the authorization grant session expires.
|
||||
pub authorize_expires_at: Option<SystemTime>,
|
||||
|
||||
/// Associated User Id registration.
|
||||
pub user_id: Option<OwnedUserId>,
|
||||
|
||||
/// Last userinfo response persisted here.
|
||||
pub user_info: Option<UserInfo>,
|
||||
}
|
||||
|
||||
/// Number of characters generated for our code_verifier. The code_verifier is a
|
||||
/// random string which must be between 43 and 128 characters.
|
||||
pub const CODE_VERIFIER_LENGTH: usize = 64;
|
||||
|
||||
/// Number of characters we will generate for the Session ID.
|
||||
pub const SESSION_ID_LENGTH: usize = 32;
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub(super) fn build(args: &crate::Args<'_>, providers: Arc<Providers>) -> Self {
|
||||
Self {
|
||||
_services: args.services.clone(),
|
||||
providers,
|
||||
db: Data {
|
||||
oauthid_session: args.db["oauthid_session"].clone(),
|
||||
oauthuniqid_oauthid: args.db["oauthuniqid_oauthid"].clone(),
|
||||
userid_oauthid: args.db["userid_oauthid"].clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub fn users(&self) -> impl Stream<Item = OwnedUserId> + Send {
|
||||
self.db
|
||||
.userid_oauthid
|
||||
.keys()
|
||||
.expect_ok()
|
||||
.map(UserId::to_owned)
|
||||
}
|
||||
|
||||
/// Delete database state for the session.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn delete(&self, sess_id: &str) {
|
||||
let Ok(session) = self.get(sess_id).await else {
|
||||
return;
|
||||
};
|
||||
|
||||
// Check the user_id still points to this sess_id before deleting. If not, the
|
||||
// association was updated to a newer session.
|
||||
if let Some(user_id) = session.user_id.as_deref() {
|
||||
if let Ok(assoc_id) = self.get_sess_id_by_user(user_id).await {
|
||||
if assoc_id == sess_id {
|
||||
self.db.userid_oauthid.remove(user_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check the unique identity still points to this sess_id before deleting. If
|
||||
// not, the association was updated to a newer session.
|
||||
if let Some(idp_id) = session.idp_id.as_ref() {
|
||||
if let Ok(provider) = self.providers.get(idp_id).await {
|
||||
if let Ok(unique_id) = unique_id((&provider, &session)) {
|
||||
if let Ok(assoc_id) = self.get_sess_id_by_unique_id(&unique_id).await {
|
||||
if assoc_id == sess_id {
|
||||
self.db.oauthuniqid_oauthid.remove(&unique_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.db.oauthid_session.remove(sess_id);
|
||||
}
|
||||
|
||||
/// Create or overwrite database state for the session.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "info", skip(self))]
|
||||
pub async fn put(&self, sess_id: &str, session: &Session) {
|
||||
self.db
|
||||
.oauthid_session
|
||||
.raw_put(sess_id, Cbor(session));
|
||||
|
||||
if let Some(user_id) = session.user_id.as_deref() {
|
||||
self.db.userid_oauthid.insert(user_id, sess_id);
|
||||
}
|
||||
|
||||
if let Some(idp_id) = session.idp_id.as_ref() {
|
||||
if let Ok(provider) = self.providers.get(idp_id).await {
|
||||
if let Ok(unique_id) = unique_id((&provider, session)) {
|
||||
self.db
|
||||
.oauthuniqid_oauthid
|
||||
.insert(&unique_id, sess_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its associated `sub`, in case
|
||||
/// `sess_id` is not known.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_by_unique_id(&self, unique_id: &str) -> Result<Session> {
|
||||
self.get_sess_id_by_unique_id(unique_id)
|
||||
.and_then(async |sess_id| self.get(&sess_id).await)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its associated `user_id`, in case
|
||||
/// `sess_id` is not known.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_by_user(&self, user_id: &UserId) -> Result<Session> {
|
||||
self.get_sess_id_by_user(user_id)
|
||||
.and_then(async |sess_id| self.get(&sess_id).await)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Fetch database state for a session from its `sess_id`.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get(&self, sess_id: &str) -> Result<Session> {
|
||||
self.db
|
||||
.oauthid_session
|
||||
.get(sess_id)
|
||||
.await
|
||||
.deserialized::<Cbor<_>>()
|
||||
.map(at!(0))
|
||||
}
|
||||
|
||||
/// Resolve the `sess_id` from an associated `user_id`.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_sess_id_by_user(&self, user_id: &UserId) -> Result<String> {
|
||||
self.db
|
||||
.userid_oauthid
|
||||
.get(user_id)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
/// Resolve the `sess_id` from an associated provider subject id.
|
||||
#[implement(Sessions)]
|
||||
#[tracing::instrument(level = "debug", skip(self), ret(level = "debug"))]
|
||||
pub async fn get_sess_id_by_unique_id(&self, unique_id: &str) -> Result<String> {
|
||||
self.db
|
||||
.oauthuniqid_oauthid
|
||||
.get(unique_id)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[implement(Sessions)]
|
||||
pub async fn provider(&self, session: &Session) -> Result<Provider> {
|
||||
let Some(idp_id) = session.idp_id.as_deref() else {
|
||||
return Err!(Request(NotFound("No provider for this session")));
|
||||
};
|
||||
|
||||
self.providers.get(idp_id).await
|
||||
}
|
||||
39
src/service/oauth/user_info.rs
Normal file
39
src/service/oauth/user_info.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Selection of userinfo response claims.
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct UserInfo {
|
||||
/// Unique identifier number or login username. Usually a number on most
|
||||
/// services. We consider a concatenation of the `iss` and `sub` to be a
|
||||
/// universally unique identifier for some user/identity; we index that in
|
||||
/// `oauthidpsub_oauthid`.
|
||||
///
|
||||
/// Considered for user mxid only if none of the better fields are defined.
|
||||
/// `login` alias intended for github.
|
||||
#[serde(alias = "login")]
|
||||
pub sub: String,
|
||||
|
||||
/// The login username we first consider when defined.
|
||||
pub preferred_username: Option<String>,
|
||||
|
||||
/// The login username considered if none preferred.
|
||||
pub nickname: Option<String>,
|
||||
|
||||
/// Full name.
|
||||
pub name: Option<String>,
|
||||
|
||||
/// First name.
|
||||
pub given_name: Option<String>,
|
||||
|
||||
/// Last name.
|
||||
pub family_name: Option<String>,
|
||||
|
||||
/// Email address (`email` scope).
|
||||
pub email: Option<String>,
|
||||
|
||||
/// URL to pfp (github/gitlab)
|
||||
pub avatar_url: Option<String>,
|
||||
|
||||
/// URL to pfp (google)
|
||||
pub picture: Option<String>,
|
||||
}
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
|
||||
key_backups,
|
||||
manager::Manager,
|
||||
media, membership, presence, pusher, resolver, rooms, sending, server_keys,
|
||||
media, membership, oauth, presence, pusher, resolver, rooms, sending, server_keys,
|
||||
service::{Args, Service},
|
||||
sync, transaction_ids, uiaa, users,
|
||||
};
|
||||
@@ -58,6 +58,7 @@ pub struct Services {
|
||||
pub users: Arc<users::Service>,
|
||||
pub membership: Arc<membership::Service>,
|
||||
pub deactivate: Arc<deactivate::Service>,
|
||||
pub oauth: Arc<oauth::Service>,
|
||||
|
||||
manager: Mutex<Option<Arc<Manager>>>,
|
||||
pub server: Arc<Server>,
|
||||
@@ -115,6 +116,7 @@ pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
|
||||
users: users::Service::build(&args)?,
|
||||
membership: membership::Service::build(&args)?,
|
||||
deactivate: deactivate::Service::build(&args)?,
|
||||
oauth: oauth::Service::build(&args)?,
|
||||
|
||||
manager: Mutex::new(None),
|
||||
server,
|
||||
@@ -173,6 +175,7 @@ pub(crate) fn services(&self) -> impl Iterator<Item = Arc<dyn Service>> + Send {
|
||||
cast!(self.users),
|
||||
cast!(self.membership),
|
||||
cast!(self.deactivate),
|
||||
cast!(self.oauth),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ use ruma::{
|
||||
events::{GlobalAccountDataEventType, ignored_user_list::IgnoredUserListEvent},
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug_warn, err, is_equal_to,
|
||||
Err, Result, debug_warn, err,
|
||||
pdu::PduBuilder,
|
||||
trace,
|
||||
utils::{self, ReadyExt, stream::TryIgnore},
|
||||
utils::{self, ReadyExt, result::LogErr, stream::TryIgnore},
|
||||
warn,
|
||||
};
|
||||
use tuwunel_database::{Deserialized, Json, Map};
|
||||
@@ -132,6 +132,16 @@ impl Service {
|
||||
|
||||
/// Deactivate account
|
||||
pub async fn deactivate_account(&self, user_id: &UserId) -> Result {
|
||||
// Revoke any SSO authorizations
|
||||
if let Ok((provider, session)) = self.services.oauth.get_user(user_id).await {
|
||||
self.services
|
||||
.oauth
|
||||
.revoke_token((&provider, &session))
|
||||
.await
|
||||
.log_err()
|
||||
.ok();
|
||||
}
|
||||
|
||||
// Remove all associated devices
|
||||
self.all_device_ids(user_id)
|
||||
.for_each(|device_id| self.remove_device(user_id, device_id))
|
||||
@@ -227,18 +237,15 @@ impl Service {
|
||||
// Cannot change the password of a LDAP user. There are two special cases :
|
||||
// - a `None` password can be used to deactivate a LDAP user
|
||||
// - a "*" password is used as the default password of an active LDAP user
|
||||
if cfg!(feature = "ldap")
|
||||
&& password.is_some()
|
||||
//
|
||||
// The above now applies to all non-password origin users including SSO
|
||||
// users. Note that users with no origin are also password-origin users.
|
||||
if password.is_some()
|
||||
&& password != Some("*")
|
||||
&& self
|
||||
.db
|
||||
.userid_origin
|
||||
.get(user_id)
|
||||
.await
|
||||
.deserialized::<String>()
|
||||
.is_ok_and(is_equal_to!("ldap"))
|
||||
&& let Ok(origin) = self.origin(user_id).await
|
||||
&& origin != "password"
|
||||
{
|
||||
return Err!(Request(InvalidParam("Cannot change password of a LDAP user")));
|
||||
return Err!(Request(InvalidParam("Cannot change password of an {origin:?} user.")));
|
||||
}
|
||||
|
||||
password
|
||||
|
||||
Reference in New Issue
Block a user