Implement SSO/OIDC support. (closes #7)
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -101,6 +101,7 @@ tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
tuwunel-core.workspace = true
|
||||
tuwunel-service.workspace = true
|
||||
url.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -4,6 +4,7 @@ mod ldap;
|
||||
mod logout;
|
||||
mod password;
|
||||
mod refresh;
|
||||
mod sso;
|
||||
mod token;
|
||||
|
||||
use axum::extract::State;
|
||||
@@ -12,8 +13,8 @@ use ruma::api::client::session::{
|
||||
get_login_types::{
|
||||
self,
|
||||
v3::{
|
||||
ApplicationServiceLoginType, JwtLoginType, LoginType, PasswordLoginType,
|
||||
TokenLoginType,
|
||||
ApplicationServiceLoginType, IdentityProvider, JwtLoginType, LoginType,
|
||||
PasswordLoginType, SsoLoginType, TokenLoginType,
|
||||
},
|
||||
},
|
||||
login::{
|
||||
@@ -28,6 +29,7 @@ use self::{ldap::ldap_login, password::password_login};
|
||||
pub(crate) use self::{
|
||||
logout::{logout_all_route, logout_route},
|
||||
refresh::refresh_token_route,
|
||||
sso::{sso_callback_route, sso_login_route, sso_login_with_provider_route},
|
||||
token::login_token_route,
|
||||
};
|
||||
use super::TOKEN_LENGTH;
|
||||
@@ -50,6 +52,20 @@ pub(crate) async fn get_login_types_route(
|
||||
LoginType::Token(TokenLoginType {
|
||||
get_login_token: services.config.login_via_existing_session,
|
||||
}),
|
||||
LoginType::Sso(SsoLoginType {
|
||||
identity_providers: services
|
||||
.config
|
||||
.identity_provider
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|config| IdentityProvider {
|
||||
id: config.id().to_owned(),
|
||||
brand: Some(config.brand.clone().into()),
|
||||
icon: config.icon,
|
||||
name: config.name.unwrap_or(config.brand),
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
]))
|
||||
}
|
||||
|
||||
@@ -89,6 +105,10 @@ pub(crate) async fn login_route(
|
||||
},
|
||||
};
|
||||
|
||||
if !services.users.is_active_local(&user_id).await {
|
||||
return Err!(Request(UserDeactivated("This user has been deactivated.")));
|
||||
}
|
||||
|
||||
// Generate a new token for the device
|
||||
let (access_token, expires_in) = services
|
||||
.users
|
||||
|
||||
601
src/api/client/session/sso.rs
Normal file
601
src/api/client/session/sso.rs
Normal file
@@ -0,0 +1,601 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
|
||||
use futures::{StreamExt, TryFutureExt, future::try_join};
|
||||
use itertools::Itertools;
|
||||
use reqwest::header::{CONTENT_TYPE, HeaderValue};
|
||||
use ruma::{
|
||||
Mxc, OwnedRoomId, OwnedUserId, ServerName, UserId,
|
||||
api::client::session::{sso_callback, sso_login, sso_login_with_provider},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{
|
||||
Err, Result, at,
|
||||
debug::INFO_SPAN_LEVEL,
|
||||
debug_info, debug_warn, err, info, utils,
|
||||
utils::{
|
||||
content_disposition::make_content_disposition,
|
||||
hash::sha256,
|
||||
result::{FlatOk, LogErr},
|
||||
string::{EMPTY, truncate_deterministic},
|
||||
timepoint_from_now, timepoint_has_passed,
|
||||
},
|
||||
warn,
|
||||
};
|
||||
use tuwunel_service::{
|
||||
Services,
|
||||
media::MXC_LENGTH,
|
||||
oauth::{
|
||||
CODE_VERIFIER_LENGTH, Provider, SESSION_ID_LENGTH, Session, UserInfo, unique_id,
|
||||
unique_id_sub,
|
||||
},
|
||||
users::Register,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::TOKEN_LENGTH;
|
||||
use crate::Ruma;
|
||||
|
||||
/// Grant phase query string.
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GrantQuery<'a> {
|
||||
client_id: &'a str,
|
||||
state: &'a str,
|
||||
nonce: &'a str,
|
||||
scope: &'a str,
|
||||
response_type: &'a str,
|
||||
access_type: &'a str,
|
||||
code_challenge_method: &'a str,
|
||||
code_challenge: &'a str,
|
||||
redirect_uri: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct GrantCookie<'a> {
|
||||
client_id: &'a str,
|
||||
state: &'a str,
|
||||
nonce: &'a str,
|
||||
redirect_uri: &'a str,
|
||||
}
|
||||
|
||||
static GRANT_SESSION_COOKIE: &str = "tuwunel_grant_session";
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "sso_login",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%client),
|
||||
)]
|
||||
pub(crate) async fn sso_login_route(
|
||||
State(_services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
_body: Ruma<sso_login::v3::Request>,
|
||||
) -> Result<sso_login::v3::Response> {
|
||||
Err!(Request(NotImplemented(
|
||||
"SSO login without specific provider has not been implemented."
|
||||
)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "sso_login_with_provider",
|
||||
level = "info",
|
||||
skip_all,
|
||||
ret(level = "debug")
|
||||
fields(
|
||||
%client,
|
||||
idp_id = body.body.idp_id,
|
||||
),
|
||||
)]
|
||||
pub(crate) async fn sso_login_with_provider_route(
|
||||
State(services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<sso_login_with_provider::v3::Request>,
|
||||
) -> Result<sso_login_with_provider::v3::Response> {
|
||||
let sso_login_with_provider::v3::Request { idp_id, redirect_url } = body.body;
|
||||
let Ok(redirect_url) = redirect_url.parse::<Url>() else {
|
||||
return Err!(Request(InvalidParam("Invalid redirect_url")));
|
||||
};
|
||||
|
||||
let provider = services.oauth.providers.get(&idp_id).await?;
|
||||
let sess_id = utils::random_string(SESSION_ID_LENGTH);
|
||||
let query_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||
let cookie_nonce = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||
let code_verifier = utils::random_string(CODE_VERIFIER_LENGTH);
|
||||
let code_challenge = b64.encode(sha256::hash(code_verifier.as_bytes()));
|
||||
let callback_uri = provider.callback_url.as_ref().map(Url::as_str);
|
||||
let scope = provider.scope.iter().join(" ");
|
||||
|
||||
let query = GrantQuery {
|
||||
client_id: &provider.client_id,
|
||||
state: &sess_id,
|
||||
nonce: &query_nonce,
|
||||
access_type: "online",
|
||||
response_type: "code",
|
||||
code_challenge_method: "S256",
|
||||
code_challenge: &code_challenge,
|
||||
redirect_uri: callback_uri,
|
||||
scope: scope
|
||||
.is_empty()
|
||||
.then_some("openid email profile")
|
||||
.unwrap_or(scope.as_str()),
|
||||
};
|
||||
|
||||
let location = provider
|
||||
.authorization_url
|
||||
.clone()
|
||||
.map(|mut location| {
|
||||
let query = serde_html_form::to_string(&query).ok();
|
||||
location.set_query(query.as_deref());
|
||||
location
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
err!(Config("authorization_url", "Missing required IdentityProvider config"))
|
||||
})?;
|
||||
|
||||
let cookie_val = GrantCookie {
|
||||
client_id: query.client_id,
|
||||
state: query.state,
|
||||
nonce: &cookie_nonce,
|
||||
redirect_uri: redirect_url.as_str(),
|
||||
};
|
||||
|
||||
let cookie_path = provider
|
||||
.callback_url
|
||||
.as_ref()
|
||||
.map(Url::path)
|
||||
.unwrap_or("/");
|
||||
|
||||
let cookie_max_age = provider
|
||||
.grant_session_duration
|
||||
.map(Duration::from_secs)
|
||||
.expect("Defaulted to Some value during configure_idp()")
|
||||
.try_into()
|
||||
.expect("std::time::Duration to time::Duration conversion failure");
|
||||
|
||||
let cookie = Cookie::build((GRANT_SESSION_COOKIE, serde_html_form::to_string(&cookie_val)?))
|
||||
.path(cookie_path)
|
||||
.max_age(cookie_max_age)
|
||||
.same_site(SameSite::None)
|
||||
.secure(true)
|
||||
.http_only(true)
|
||||
.build()
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
let session = Session {
|
||||
idp_id: Some(idp_id),
|
||||
sess_id: Some(sess_id.clone()),
|
||||
redirect_url: Some(redirect_url),
|
||||
code_verifier: Some(code_verifier),
|
||||
query_nonce: Some(query_nonce),
|
||||
cookie_nonce: Some(cookie_nonce),
|
||||
authorize_expires_at: provider
|
||||
.grant_session_duration
|
||||
.map(Duration::from_secs)
|
||||
.map(timepoint_from_now)
|
||||
.transpose()?,
|
||||
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
services
|
||||
.oauth
|
||||
.sessions
|
||||
.put(&sess_id, &session)
|
||||
.await;
|
||||
|
||||
Ok(sso_login_with_provider::v3::Response {
|
||||
location: location.into(),
|
||||
cookie: Some(cookie),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "sso_callback"
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
%client,
|
||||
cookie = ?body.cookie,
|
||||
body = ?body.body,
|
||||
),
|
||||
)]
|
||||
pub(crate) async fn sso_callback_route(
|
||||
State(services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<sso_callback::unstable::Request>,
|
||||
) -> Result<sso_callback::unstable::Response> {
|
||||
let sess_id = body
|
||||
.body
|
||||
.state
|
||||
.as_deref()
|
||||
.ok_or_else(|| err!(Request(Forbidden("Missing sess_id in callback."))))?;
|
||||
|
||||
let code = body
|
||||
.body
|
||||
.code
|
||||
.as_deref()
|
||||
.ok_or_else(|| err!(Request(Forbidden("Missing code in callback."))))?;
|
||||
|
||||
let session = services
|
||||
.oauth
|
||||
.sessions
|
||||
.get(sess_id)
|
||||
.map_err(|_| err!(Request(Forbidden("Invalid state in callback"))));
|
||||
|
||||
let idp_id = body.body.idp_id.as_str();
|
||||
let provider = services.oauth.providers.get(idp_id);
|
||||
let (provider, session) = try_join(provider, session).await.log_err()?;
|
||||
let client_id = &provider.client_id;
|
||||
|
||||
if session.idp_id.as_deref() != Some(idp_id) {
|
||||
return Err!(Request(Unauthorized("Identity Provider {idp_id} session not recognized.")));
|
||||
}
|
||||
|
||||
if session
|
||||
.authorize_expires_at
|
||||
.map(timepoint_has_passed)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return Err!(Request(Unauthorized("Authorization grant session has expired.")));
|
||||
}
|
||||
|
||||
let cookie = body
|
||||
.cookie
|
||||
.get(GRANT_SESSION_COOKIE)
|
||||
.map(Cookie::value)
|
||||
.map(serde_html_form::from_str::<GrantCookie<'_>>)
|
||||
.transpose()?
|
||||
.ok_or_else(|| err!(Request(Unauthorized("Missing cookie {GRANT_SESSION_COOKIE:?}"))))?;
|
||||
|
||||
if cookie.client_id != client_id {
|
||||
return Err!(Request(Unauthorized("Client ID {client_id:?} cookie mismatch.")));
|
||||
}
|
||||
|
||||
if Some(cookie.nonce) != session.cookie_nonce.as_deref() {
|
||||
return Err!(Request(Unauthorized("Cookie nonce does not match session state.")));
|
||||
}
|
||||
|
||||
if cookie.state != sess_id {
|
||||
return Err!(Request(Unauthorized("Session ID {sess_id:?} cookie mismatch.")));
|
||||
}
|
||||
|
||||
// Request access token.
|
||||
let token_response = services
|
||||
.oauth
|
||||
.request_token((&provider, &session), code)
|
||||
.await?;
|
||||
|
||||
let token_expires_at = token_response
|
||||
.expires_in
|
||||
.map(Duration::from_secs)
|
||||
.map(timepoint_from_now)
|
||||
.transpose()?;
|
||||
|
||||
let refresh_token_expires_at = token_response
|
||||
.refresh_token_expires_in
|
||||
.map(Duration::from_secs)
|
||||
.map(timepoint_from_now)
|
||||
.transpose()?;
|
||||
|
||||
// Update the session with access token results
|
||||
let session = Session {
|
||||
scope: token_response.scope,
|
||||
token_type: token_response.token_type,
|
||||
access_token: token_response.access_token,
|
||||
expires_at: token_expires_at,
|
||||
refresh_token: token_response.refresh_token,
|
||||
refresh_token_expires_at,
|
||||
..session
|
||||
};
|
||||
|
||||
// Request userinfo claims.
|
||||
let userinfo = services
|
||||
.oauth
|
||||
.request_userinfo((&provider, &session))
|
||||
.await?;
|
||||
|
||||
// Check for an existing session from this identity. We want to maintain one
|
||||
// session for each identity and keep the newer one which has up-to-date state
|
||||
// and access.
|
||||
let (user_id, old_sess_id) = match services
|
||||
.oauth
|
||||
.sessions
|
||||
.get_by_unique_id(&unique_id_sub((&provider, &userinfo.sub))?)
|
||||
.await
|
||||
{
|
||||
| Ok(session) => (session.user_id, session.sess_id),
|
||||
| Err(error) if !error.is_not_found() => return Err(error),
|
||||
| Err(_) => (None, None),
|
||||
};
|
||||
|
||||
// Update the session with userinfo
|
||||
let session = Session {
|
||||
user_info: Some(userinfo.clone()),
|
||||
..session
|
||||
};
|
||||
|
||||
// Keep the user_id from the old session as best as possible.
|
||||
let user_id = match user_id {
|
||||
| Some(user_id) => user_id,
|
||||
| None => decide_user_id(&services, &provider, &session, &userinfo).await?,
|
||||
};
|
||||
|
||||
// Update the session with user_id
|
||||
let session = Session {
|
||||
user_id: Some(user_id.clone()),
|
||||
..session
|
||||
};
|
||||
|
||||
// Attempt to register a non-existing user.
|
||||
if !services.users.exists(&user_id).await {
|
||||
register_user(&services, &provider, &session, &userinfo, &user_id).await?;
|
||||
}
|
||||
|
||||
// Commit the updated session.
|
||||
services
|
||||
.oauth
|
||||
.sessions
|
||||
.put(sess_id, &session)
|
||||
.await;
|
||||
|
||||
// Delete any old session.
|
||||
if let Some(old_sess_id) = old_sess_id
|
||||
&& sess_id != old_sess_id
|
||||
{
|
||||
services.oauth.sessions.delete(&old_sess_id).await;
|
||||
}
|
||||
|
||||
if !services.users.is_active_local(&user_id).await {
|
||||
return Err!(Request(UserDeactivated("This user has been deactivated.")));
|
||||
}
|
||||
|
||||
// Allow the user to login to Matrix.
|
||||
let login_token = utils::random_string(TOKEN_LENGTH);
|
||||
let _login_token_expires_in = services
|
||||
.users
|
||||
.create_login_token(&user_id, &login_token);
|
||||
|
||||
let location = session
|
||||
.redirect_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| err!(Request(InvalidParam("Missing redirect URL in session data"))))?
|
||||
.clone()
|
||||
.query_pairs_mut()
|
||||
.append_pair("loginToken", &login_token)
|
||||
.finish()
|
||||
.to_string();
|
||||
|
||||
let cookie = Cookie::build((GRANT_SESSION_COOKIE, EMPTY))
|
||||
.removal()
|
||||
.build()
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
Ok(sso_callback::unstable::Response { location, cookie: Some(cookie) })
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "register",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
fields(user_id, userinfo)
|
||||
)]
|
||||
async fn register_user(
|
||||
services: &Services,
|
||||
provider: &Provider,
|
||||
session: &Session,
|
||||
userinfo: &UserInfo,
|
||||
user_id: &UserId,
|
||||
) -> Result {
|
||||
debug_info!(%user_id, "Creating new user account...");
|
||||
|
||||
services
|
||||
.users
|
||||
.full_register(Register {
|
||||
user_id: Some(user_id),
|
||||
password: Some("*"),
|
||||
origin: Some("sso"),
|
||||
displayname: userinfo.name.as_deref(),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
|
||||
if let Some(avatar_url) = userinfo
|
||||
.avatar_url
|
||||
.as_deref()
|
||||
.or(userinfo.picture.as_deref())
|
||||
{
|
||||
set_avatar(services, provider, session, userinfo, user_id, avatar_url)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
let idp_id = provider.id();
|
||||
let idp_name = provider
|
||||
.name
|
||||
.as_deref()
|
||||
.unwrap_or(provider.brand.as_str());
|
||||
|
||||
// log in conduit admin channel if a non-guest user registered
|
||||
let notice =
|
||||
format!("New user \"{user_id}\" registered on this server via {idp_name} ({idp_id})",);
|
||||
|
||||
info!("{notice}");
|
||||
if services.server.config.admin_room_notices {
|
||||
services.admin.notice(¬ice).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all, fields(user_id, avatar_url))]
|
||||
async fn set_avatar(
|
||||
services: &Services,
|
||||
_provider: &Provider,
|
||||
_session: &Session,
|
||||
_userinfo: &UserInfo,
|
||||
user_id: &UserId,
|
||||
avatar_url: &str,
|
||||
) -> Result {
|
||||
use reqwest::Response;
|
||||
|
||||
let response = services
|
||||
.client
|
||||
.default
|
||||
.get(avatar_url)
|
||||
.send()
|
||||
.await
|
||||
.and_then(Response::error_for_status)?;
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.map(HeaderValue::to_str)
|
||||
.flat_ok()
|
||||
.map(ToOwned::to_owned);
|
||||
|
||||
let mxc = Mxc {
|
||||
server_name: services.globals.server_name(),
|
||||
media_id: &utils::random_string(MXC_LENGTH),
|
||||
};
|
||||
|
||||
let content_disposition = make_content_disposition(None, content_type.as_deref(), None);
|
||||
let bytes = response.bytes().await?;
|
||||
services
|
||||
.media
|
||||
.create(&mxc, Some(user_id), Some(&content_disposition), content_type.as_deref(), &bytes)
|
||||
.await?;
|
||||
|
||||
let all_joined_rooms: Vec<OwnedRoomId> = services
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.map(ToOwned::to_owned)
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let mxc_uri = mxc.to_string().into();
|
||||
services
|
||||
.users
|
||||
.update_avatar_url(user_id, Some(mxc_uri), None, &all_joined_rooms)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
ret(level = "debug")
|
||||
skip_all,
|
||||
fields(user),
|
||||
)]
|
||||
async fn decide_user_id(
|
||||
services: &Services,
|
||||
provider: &Provider,
|
||||
session: &Session,
|
||||
userinfo: &UserInfo,
|
||||
) -> Result<OwnedUserId> {
|
||||
let allowed =
|
||||
|claim: &str| provider.userid_claims.is_empty() || provider.userid_claims.contains(claim);
|
||||
|
||||
let choices = [
|
||||
userinfo
|
||||
.preferred_username
|
||||
.as_deref()
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("preferred_username")),
|
||||
userinfo
|
||||
.nickname
|
||||
.as_deref()
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("nickname")),
|
||||
provider
|
||||
.brand
|
||||
.eq(&"github")
|
||||
.then_some(userinfo.sub.as_str())
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("login")),
|
||||
userinfo
|
||||
.email
|
||||
.as_deref()
|
||||
.and_then(|email| email.split_once('@'))
|
||||
.map(at!(0))
|
||||
.map(str::to_lowercase)
|
||||
.filter(|_| allowed("email")),
|
||||
];
|
||||
|
||||
for choice in choices.into_iter().flatten() {
|
||||
if let Some(user_id) = try_user_id(services, &choice, false).await {
|
||||
return Ok(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(infallible) = unique_id((provider, session))
|
||||
.map(|h| truncate_deterministic(&h, Some(15..23)).to_lowercase())
|
||||
.log_err()
|
||||
{
|
||||
if let Some(user_id) = try_user_id(services, &infallible, true).await {
|
||||
return Ok(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
Err!(Request(UserInUse("User ID is not available.")))
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all, fields(username))]
|
||||
async fn try_user_id(
|
||||
services: &Services,
|
||||
username: &str,
|
||||
may_exist: bool,
|
||||
) -> Option<OwnedUserId> {
|
||||
let server_name = services.globals.server_name();
|
||||
let user_id = parse_user_id(server_name, username)
|
||||
.inspect_err(|e| warn!(?username, "Username invalid: {e}"))
|
||||
.ok()?;
|
||||
|
||||
if services
|
||||
.config
|
||||
.forbidden_usernames
|
||||
.is_match(username)
|
||||
{
|
||||
warn!(?username, "Username forbidden.");
|
||||
return None;
|
||||
}
|
||||
|
||||
if services.users.exists(&user_id).await {
|
||||
debug_warn!(?username, "Username exists.");
|
||||
|
||||
if services
|
||||
.users
|
||||
.origin(&user_id)
|
||||
.await
|
||||
.ok()
|
||||
.is_none_or(|origin| origin != "sso")
|
||||
{
|
||||
debug_warn!(?username, "Username has non-sso origin.");
|
||||
return None;
|
||||
}
|
||||
|
||||
if !may_exist {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(user_id)
|
||||
}
|
||||
|
||||
fn parse_user_id(server_name: &ServerName, username: &str) -> Result<OwnedUserId> {
|
||||
match UserId::parse_with_server_name(username, server_name) {
|
||||
| Err(e) =>
|
||||
Err!(Request(InvalidUsername(debug_error!("Username {username} is not valid: {e}")))),
|
||||
| Ok(user_id) => match user_id.validate_strict() {
|
||||
| Ok(()) => Ok(user_id),
|
||||
| Err(e) => Err!(Request(InvalidUsername(debug_error!(
|
||||
"Username {username} contains disallowed characters or spaces: {e}"
|
||||
)))),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,9 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
||||
.ruma_route(&client::login_route)
|
||||
.ruma_route(&client::login_token_route)
|
||||
.ruma_route(&client::refresh_token_route)
|
||||
.ruma_route(&client::sso_login_route)
|
||||
.ruma_route(&client::sso_login_with_provider_route)
|
||||
.ruma_route(&client::sso_callback_route)
|
||||
.ruma_route(&client::whoami_route)
|
||||
.ruma_route(&client::logout_route)
|
||||
.ruma_route(&client::logout_all_route)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::{fmt::Debug, mem, ops::Deref};
|
||||
|
||||
use axum::{body::Body, extract::FromRequest};
|
||||
use axum_extra::extract::cookie::CookieJar;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use ruma::{
|
||||
CanonicalJsonObject, CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedServerName,
|
||||
@@ -18,6 +19,9 @@ pub(crate) struct Args<T> {
|
||||
/// Request struct body
|
||||
pub(crate) body: T,
|
||||
|
||||
/// Cookies received from the useragent.
|
||||
pub(crate) cookie: CookieJar,
|
||||
|
||||
/// Federation server authentication: X-Matrix origin
|
||||
/// None when not a federation server.
|
||||
pub(crate) origin: Option<OwnedServerName>,
|
||||
@@ -110,6 +114,7 @@ where
|
||||
let auth = auth::auth(services, &mut request, json_body.as_ref(), &T::METADATA).await?;
|
||||
Ok(Self {
|
||||
body: make_body::<T>(services, &mut request, json_body.as_mut(), &auth)?,
|
||||
cookie: request.cookie,
|
||||
origin: auth.origin,
|
||||
sender_user: auth.sender_user,
|
||||
sender_device: auth.sender_device,
|
||||
|
||||
@@ -5,7 +5,7 @@ use ruma::{
|
||||
client::uiaa::{AuthData, AuthFlow, AuthType, Jwt, UiaaInfo},
|
||||
},
|
||||
};
|
||||
use tuwunel_core::{Err, Error, Result, err, utils};
|
||||
use tuwunel_core::{Err, Error, Result, err, is_equal_to, utils};
|
||||
use tuwunel_service::{Services, uiaa::SESSION_ID_LENGTH};
|
||||
|
||||
use crate::{Ruma, client::jwt};
|
||||
@@ -67,9 +67,19 @@ where
|
||||
| Some(ref json) => {
|
||||
let sender_user = body
|
||||
.sender_user
|
||||
.as_ref()
|
||||
.as_deref()
|
||||
.ok_or_else(|| err!(Request(MissingToken("Missing access token."))))?;
|
||||
|
||||
// Skip UIAA for SSO/OIDC users.
|
||||
if services
|
||||
.users
|
||||
.origin(sender_user)
|
||||
.await
|
||||
.is_ok_and(is_equal_to!("sso"))
|
||||
{
|
||||
return Ok(sender_user.to_owned());
|
||||
}
|
||||
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services
|
||||
.uiaa
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::str;
|
||||
|
||||
use axum::{RequestExt, RequestPartsExt, extract::Path};
|
||||
use axum_extra::extract::cookie::CookieJar;
|
||||
use bytes::Bytes;
|
||||
use http::request::Parts;
|
||||
use serde::Deserialize;
|
||||
@@ -17,6 +18,7 @@ pub(super) type UserId = SmallString<[u8; 48]>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct Request {
|
||||
pub(super) cookie: CookieJar,
|
||||
pub(super) path: Path<PathParams>,
|
||||
pub(super) query: QueryParams,
|
||||
pub(super) body: Bytes,
|
||||
@@ -33,6 +35,7 @@ pub(super) async fn from(
|
||||
let limited = request.with_limited_body();
|
||||
let (mut parts, body) = limited.into_parts();
|
||||
|
||||
let cookie: CookieJar = parts.extract().await?;
|
||||
let path: Path<PathParams> = parts.extract().await?;
|
||||
let query = parts.uri.query().unwrap_or_default();
|
||||
let query = serde_html_form::from_str(query)
|
||||
@@ -44,5 +47,5 @@ pub(super) async fn from(
|
||||
.await
|
||||
.map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?;
|
||||
|
||||
Ok(Request { path, query, body, parts })
|
||||
Ok(Request { cookie, path, query, body, parts })
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user