feat: implement OIDC server for next-gen auth (MSC2965/2964/2966/2967)
Implements a built-in OIDC authorization server that allows Matrix clients
like Element X to authenticate via OIDC, delegating user authentication
to upstream identity providers (e.g. Kanidm) through the existing SSO flow.
## Endpoints
- GET /_matrix/client/unstable/org.matrix.msc2965/auth_issuer
- GET /.well-known/openid-configuration
- POST /_tuwunel/oidc/registration (Dynamic Client Registration)
- GET /_tuwunel/oidc/authorize → SSO redirect → _complete bridge
- POST /_tuwunel/oidc/token (auth code exchange + refresh)
- POST /_tuwunel/oidc/revoke
- GET /_tuwunel/oidc/jwks
- GET /_tuwunel/oidc/userinfo
- GET /_tuwunel/oidc/account (placeholder)
## Spec compliance fixes
- OAuth error responses use RFC 6749 §5.2 format ({"error": "...", "error_description": "..."})
- PKCE code_verifier validation per RFC 7636 §4.1
- Scope token matching uses exact whitespace-delimited comparison per RFC 6749 §3.3
- Typed ProviderMetadata struct for the discovery document
- DCR includes policy_uri, tos_uri, software_id, software_version per RFC 7591
Refs: #246, #266
This commit is contained in:
@@ -16,6 +16,7 @@ pub(super) mod media_legacy;
|
||||
pub(super) mod membership;
|
||||
pub(super) mod message;
|
||||
pub(super) mod openid;
|
||||
pub(super) mod oidc;
|
||||
pub(super) mod presence;
|
||||
pub(super) mod profile;
|
||||
pub(super) mod push;
|
||||
@@ -63,6 +64,7 @@ pub(super) use media_legacy::*;
|
||||
pub(super) use membership::*;
|
||||
pub(super) use message::*;
|
||||
pub(super) use openid::*;
|
||||
pub(super) use oidc::*;
|
||||
pub(super) use presence::*;
|
||||
pub(super) use profile::*;
|
||||
pub(super) use push::*;
|
||||
|
||||
222
src/api/client/oidc.rs
Normal file
222
src/api/client/oidc.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
use std::time::SystemTime;
|
||||
|
||||
use axum::{Json, extract::State, response::{IntoResponse, Redirect}};
|
||||
use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}};
|
||||
use http::StatusCode;
|
||||
use ruma::OwnedDeviceId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{Err, Result, err, info, utils};
|
||||
use tuwunel_service::{oauth::oidc_server::{DcrRequest, IdTokenClaims, OidcAuthRequest, OidcServer, ProviderMetadata}, users::device::generate_refresh_token};
|
||||
|
||||
const OIDC_REQ_ID_LENGTH: usize = 32;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AuthIssuerResponse { issuer: String }
|
||||
|
||||
pub(crate) async fn auth_issuer_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
|
||||
let issuer = oidc_issuer_url(&services)?;
|
||||
Ok(Json(AuthIssuerResponse { issuer }))
|
||||
}
|
||||
|
||||
pub(crate) async fn openid_configuration_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
|
||||
Ok(Json(oidc_metadata(&services)?))
|
||||
}
|
||||
|
||||
fn oidc_metadata(services: &tuwunel_service::Services) -> Result<ProviderMetadata> {
|
||||
let issuer = oidc_issuer_url(services)?;
|
||||
let base = issuer.trim_end_matches('/').to_owned();
|
||||
|
||||
Ok(ProviderMetadata {
|
||||
issuer,
|
||||
authorization_endpoint: format!("{base}/_tuwunel/oidc/authorize"),
|
||||
token_endpoint: format!("{base}/_tuwunel/oidc/token"),
|
||||
registration_endpoint: Some(format!("{base}/_tuwunel/oidc/registration")),
|
||||
revocation_endpoint: Some(format!("{base}/_tuwunel/oidc/revoke")),
|
||||
jwks_uri: format!("{base}/_tuwunel/oidc/jwks"),
|
||||
userinfo_endpoint: Some(format!("{base}/_tuwunel/oidc/userinfo")),
|
||||
account_management_uri: Some(format!("{base}/_tuwunel/oidc/account")),
|
||||
account_management_actions_supported: Some(vec!["org.matrix.profile".to_owned(), "org.matrix.sessions_list".to_owned(), "org.matrix.session_view".to_owned(), "org.matrix.session_end".to_owned(), "org.matrix.cross_signing_reset".to_owned()]),
|
||||
response_types_supported: vec!["code".to_owned()],
|
||||
response_modes_supported: Some(vec!["query".to_owned(), "fragment".to_owned()]),
|
||||
grant_types_supported: Some(vec!["authorization_code".to_owned(), "refresh_token".to_owned()]),
|
||||
code_challenge_methods_supported: Some(vec!["S256".to_owned()]),
|
||||
token_endpoint_auth_methods_supported: Some(vec!["none".to_owned(), "client_secret_basic".to_owned(), "client_secret_post".to_owned()]),
|
||||
scopes_supported: Some(vec!["openid".to_owned(), "urn:matrix:org.matrix.msc2967.client:api:*".to_owned(), "urn:matrix:org.matrix.msc2967.client:device:*".to_owned()]),
|
||||
subject_types_supported: Some(vec!["public".to_owned()]),
|
||||
id_token_signing_alg_values_supported: Some(vec!["ES256".to_owned()]),
|
||||
prompt_values_supported: Some(vec!["create".to_owned()]),
|
||||
claim_types_supported: Some(vec!["normal".to_owned()]),
|
||||
claims_supported: Some(vec!["iss".to_owned(), "sub".to_owned(), "aud".to_owned(), "exp".to_owned(), "iat".to_owned(), "nonce".to_owned()]),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn registration_route(State(services): State<crate::State>, Json(body): Json<DcrRequest>) -> Result<impl IntoResponse> {
|
||||
let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); };
|
||||
|
||||
if body.redirect_uris.is_empty() { return Err!(Request(InvalidParam("redirect_uris must not be empty"))); }
|
||||
|
||||
let reg = oidc.register_client(body)?;
|
||||
info!("OIDC client registered: {} ({})", reg.client_id, reg.client_name.as_deref().unwrap_or("unnamed"));
|
||||
|
||||
Ok((StatusCode::CREATED, Json(serde_json::json!({"client_id": reg.client_id, "client_id_issued_at": reg.registered_at, "redirect_uris": reg.redirect_uris, "client_name": reg.client_name, "client_uri": reg.client_uri, "logo_uri": reg.logo_uri, "contacts": reg.contacts, "token_endpoint_auth_method": reg.token_endpoint_auth_method, "grant_types": reg.grant_types, "response_types": reg.response_types, "application_type": reg.application_type}))))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct AuthorizeParams {
|
||||
client_id: String, redirect_uri: String, response_type: String, scope: String,
|
||||
state: Option<String>, nonce: Option<String>, code_challenge: Option<String>,
|
||||
code_challenge_method: Option<String>, #[serde(default, rename = "prompt")] _prompt: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) async fn authorize_route(State(services): State<crate::State>, request: axum::extract::Request) -> Result<impl IntoResponse> {
|
||||
let params: AuthorizeParams = serde_html_form::from_str(request.uri().query().unwrap_or_default())?;
|
||||
let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); };
|
||||
|
||||
if params.response_type != "code" { return Err!(Request(InvalidParam("Only response_type=code is supported"))); }
|
||||
|
||||
oidc.validate_redirect_uri(¶ms.client_id, ¶ms.redirect_uri).await?;
|
||||
|
||||
if !scope_contains_token(¶ms.scope, "openid") { return Err!(Request(InvalidParam("openid scope is required"))); }
|
||||
|
||||
let req_id = utils::random_string(OIDC_REQ_ID_LENGTH);
|
||||
let now = SystemTime::now();
|
||||
|
||||
oidc.store_auth_request(&req_id, &OidcAuthRequest {
|
||||
client_id: params.client_id, redirect_uri: params.redirect_uri, scope: params.scope,
|
||||
state: params.state, nonce: params.nonce, code_challenge: params.code_challenge,
|
||||
code_challenge_method: params.code_challenge_method, created_at: now,
|
||||
expires_at: now.checked_add(OidcServer::auth_request_lifetime()).unwrap_or(now),
|
||||
});
|
||||
|
||||
let default_idp = services.config.identity_provider.values().find(|idp| idp.default).or_else(|| services.config.identity_provider.values().next()).ok_or_else(|| err!(Config("identity_provider", "No identity provider configured")))?;
|
||||
let idp_id = default_idp.id();
|
||||
|
||||
let base = oidc_issuer_url(&services)?;
|
||||
let base = base.trim_end_matches('/');
|
||||
|
||||
let mut complete_url = url::Url::parse(&format!("{base}/_tuwunel/oidc/_complete")).map_err(|_| err!(error!("Failed to build complete URL")))?;
|
||||
complete_url.query_pairs_mut().append_pair("oidc_req_id", &req_id);
|
||||
|
||||
let mut sso_url = url::Url::parse(&format!("{base}/_matrix/client/v3/login/sso/redirect/{idp_id}")).map_err(|_| err!(error!("Failed to build SSO URL")))?;
|
||||
sso_url.query_pairs_mut().append_pair("redirectUrl", complete_url.as_str());
|
||||
|
||||
Ok(Redirect::temporary(sso_url.as_str()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct CompleteParams { oidc_req_id: String, #[serde(rename = "loginToken")] login_token: String }
|
||||
|
||||
pub(crate) async fn complete_route(State(services): State<crate::State>, request: axum::extract::Request) -> Result<impl IntoResponse> {
|
||||
let params: CompleteParams = serde_html_form::from_str(request.uri().query().unwrap_or_default())?;
|
||||
let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); };
|
||||
|
||||
let user_id = services.users.find_from_login_token(¶ms.login_token).await.map_err(|_| err!(Request(Forbidden("Invalid or expired login token"))))?;
|
||||
let auth_req = oidc.take_auth_request(¶ms.oidc_req_id).await?;
|
||||
let code = oidc.create_auth_code(&auth_req, user_id);
|
||||
|
||||
let mut redirect_url = url::Url::parse(&auth_req.redirect_uri).map_err(|_| err!(Request(InvalidParam("Invalid redirect_uri"))))?;
|
||||
redirect_url.query_pairs_mut().append_pair("code", &code);
|
||||
if let Some(state) = &auth_req.state { redirect_url.query_pairs_mut().append_pair("state", state); }
|
||||
|
||||
Ok(Redirect::temporary(redirect_url.as_str()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct TokenRequest {
|
||||
grant_type: String, code: Option<String>, redirect_uri: Option<String>, client_id: Option<String>,
|
||||
code_verifier: Option<String>, refresh_token: Option<String>, #[serde(rename = "scope")] _scope: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) async fn token_route(State(services): State<crate::State>, axum::extract::Form(body): axum::extract::Form<TokenRequest>) -> impl IntoResponse {
|
||||
match body.grant_type.as_str() {
|
||||
| "authorization_code" => token_authorization_code(&services, &body).await.unwrap_or_else(|e| oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string())),
|
||||
| "refresh_token" => token_refresh(&services, &body).await.unwrap_or_else(|e| oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string())),
|
||||
| _ => oauth_error(StatusCode::BAD_REQUEST, "unsupported_grant_type", "Unsupported grant_type"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn token_authorization_code(services: &tuwunel_service::Services, body: &TokenRequest) -> Result<http::Response<axum::body::Body>> {
|
||||
let code = body.code.as_deref().ok_or_else(|| err!(Request(InvalidParam("code is required"))))?;
|
||||
let redirect_uri = body.redirect_uri.as_deref().ok_or_else(|| err!(Request(InvalidParam("redirect_uri is required"))))?;
|
||||
let client_id = body.client_id.as_deref().ok_or_else(|| err!(Request(InvalidParam("client_id is required"))))?;
|
||||
|
||||
let oidc = get_oidc_server(services)?;
|
||||
let session = oidc.exchange_auth_code(code, client_id, redirect_uri, body.code_verifier.as_deref()).await?;
|
||||
|
||||
let user_id = &session.user_id;
|
||||
let (access_token, expires_in) = services.users.generate_access_token(true);
|
||||
let refresh_token = generate_refresh_token();
|
||||
|
||||
let device_id: Option<OwnedDeviceId> = extract_device_id(&session.scope).map(OwnedDeviceId::from);
|
||||
let device_id = services.users.create_device(user_id, device_id.as_deref(), (Some(&access_token), expires_in), Some(&refresh_token), Some("OIDC Client"), None).await?;
|
||||
|
||||
info!("{user_id} logged in via OIDC (device {device_id})");
|
||||
|
||||
let id_token = if session.scope.contains("openid") {
|
||||
let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap_or_default().as_secs();
|
||||
let issuer = oidc_issuer_url(services)?;
|
||||
let claims = IdTokenClaims { iss: issuer, sub: user_id.to_string(), aud: client_id.to_owned(), exp: now.saturating_add(3600), iat: now, nonce: session.nonce, at_hash: Some(OidcServer::at_hash(&access_token)) };
|
||||
Some(oidc.sign_id_token(&claims)?)
|
||||
} else { None };
|
||||
|
||||
let mut response = serde_json::json!({"access_token": access_token, "token_type": "Bearer", "scope": session.scope, "refresh_token": refresh_token});
|
||||
if let Some(expires_in) = expires_in { response["expires_in"] = serde_json::json!(expires_in.as_secs()); }
|
||||
if let Some(id_token) = id_token { response["id_token"] = serde_json::json!(id_token); }
|
||||
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
|
||||
async fn token_refresh(services: &tuwunel_service::Services, body: &TokenRequest) -> Result<http::Response<axum::body::Body>> {
|
||||
let refresh_token = body.refresh_token.as_deref().ok_or_else(|| err!(Request(InvalidParam("refresh_token is required"))))?;
|
||||
let (user_id, device_id, _) = services.users.find_from_token(refresh_token).await.map_err(|_| err!(Request(Forbidden("Invalid refresh token"))))?;
|
||||
|
||||
let (new_access_token, expires_in) = services.users.generate_access_token(true);
|
||||
let new_refresh_token = generate_refresh_token();
|
||||
services.users.set_access_token(&user_id, &device_id, &new_access_token, expires_in, Some(&new_refresh_token)).await?;
|
||||
|
||||
let mut response = serde_json::json!({"access_token": new_access_token, "token_type": "Bearer", "refresh_token": new_refresh_token});
|
||||
if let Some(expires_in) = expires_in { response["expires_in"] = serde_json::json!(expires_in.as_secs()); }
|
||||
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct RevokeRequest { token: String, #[serde(default, rename = "token_type_hint")] _token_type_hint: Option<String> }
|
||||
|
||||
pub(crate) async fn revoke_route(State(services): State<crate::State>, axum::extract::Form(body): axum::extract::Form<RevokeRequest>) -> Result<impl IntoResponse> {
|
||||
if let Ok((user_id, device_id, _)) = services.users.find_from_token(&body.token).await { services.users.remove_device(&user_id, &device_id).await; }
|
||||
Ok(Json(serde_json::json!({})))
|
||||
}
|
||||
|
||||
pub(crate) async fn jwks_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
|
||||
let oidc = get_oidc_server(&services)?;
|
||||
Ok(Json(oidc.jwks()))
|
||||
}
|
||||
|
||||
pub(crate) async fn userinfo_route(State(services): State<crate::State>, TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>) -> Result<impl IntoResponse> {
|
||||
let token = bearer.token();
|
||||
let Ok((user_id, _device_id, _expires)) = services.users.find_from_token(token).await else { return Err!(Request(Unauthorized("Invalid access token"))); };
|
||||
let displayname = services.users.displayname(&user_id).await.ok();
|
||||
let avatar_url = services.users.avatar_url(&user_id).await.ok();
|
||||
Ok(Json(serde_json::json!({"sub": user_id.to_string(), "name": displayname, "picture": avatar_url})))
|
||||
}
|
||||
|
||||
pub(crate) async fn account_route() -> impl IntoResponse {
|
||||
axum::response::Html("<html><body><h1>Account Management</h1><p>Account management is not yet implemented. Please use your identity provider to manage your account.</p></body></html>")
|
||||
}
|
||||
|
||||
fn oauth_error(status: StatusCode, error: &str, description: &str) -> http::Response<axum::body::Body> {
|
||||
(status, Json(serde_json::json!({"error": error, "error_description": description}))).into_response()
|
||||
}
|
||||
|
||||
fn scope_contains_token(scope: &str, token: &str) -> bool { scope.split_whitespace().any(|t| t == token) }
|
||||
|
||||
fn get_oidc_server(services: &tuwunel_service::Services) -> Result<&OidcServer> {
|
||||
services.oauth.oidc_server.as_deref().ok_or_else(|| err!(Request(NotFound("OIDC server not configured"))))
|
||||
}
|
||||
|
||||
fn oidc_issuer_url(services: &tuwunel_service::Services) -> Result<String> {
|
||||
services.config.well_known.client.as_ref().map(|url| { let s = url.to_string(); if s.ends_with('/') { s } else { s + "/" } }).ok_or_else(|| err!(Config("well_known.client", "well_known.client must be set for OIDC server")))
|
||||
}
|
||||
|
||||
fn extract_device_id(scope: &str) -> Option<String> { scope.split_whitespace().find_map(|s| s.strip_prefix("urn:matrix:org.matrix.msc2967.client:device:")).map(ToOwned::to_owned) }
|
||||
@@ -51,7 +51,7 @@ static VERSIONS: [&str; 17] = [
|
||||
"v1.15", /* custom profile fields */
|
||||
];
|
||||
|
||||
static UNSTABLE_FEATURES: [&str; 18] = [
|
||||
static UNSTABLE_FEATURES: [&str; 22] = [
|
||||
"org.matrix.e2e_cross_signing",
|
||||
// private read receipts (https://github.com/matrix-org/matrix-spec-proposals/pull/2285)
|
||||
"org.matrix.msc2285.stable",
|
||||
@@ -86,4 +86,12 @@ static UNSTABLE_FEATURES: [&str; 18] = [
|
||||
"org.matrix.simplified_msc3575",
|
||||
// Allow room moderators to view redacted event content (https://github.com/matrix-org/matrix-spec-proposals/pull/2815)
|
||||
"fi.mau.msc2815",
|
||||
// OIDC-native auth: authorization code grant (https://github.com/matrix-org/matrix-spec-proposals/pull/2964)
|
||||
"org.matrix.msc2964",
|
||||
// OIDC-native auth: auth issuer discovery (https://github.com/matrix-org/matrix-spec-proposals/pull/2965)
|
||||
"org.matrix.msc2965",
|
||||
// OIDC-native auth: dynamic client registration (https://github.com/matrix-org/matrix-spec-proposals/pull/2966)
|
||||
"org.matrix.msc2966",
|
||||
// OIDC-native auth: API scopes (https://github.com/matrix-org/matrix-spec-proposals/pull/2967)
|
||||
"org.matrix.msc2967",
|
||||
];
|
||||
|
||||
@@ -197,6 +197,20 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
||||
.ruma_route(&client::well_known_support)
|
||||
.ruma_route(&client::well_known_client)
|
||||
.route("/_tuwunel/server_version", get(client::tuwunel_server_version))
|
||||
// OIDC server endpoints (next-gen auth, MSC2965/2964/2966/2967)
|
||||
.route("/_matrix/client/unstable/org.matrix.msc2965/auth_issuer", get(client::auth_issuer_route))
|
||||
.route("/_matrix/client/v1/auth_issuer", get(client::auth_issuer_route))
|
||||
.route("/_matrix/client/unstable/org.matrix.msc2965/auth_metadata", get(client::openid_configuration_route))
|
||||
.route("/_matrix/client/v1/auth_metadata", get(client::openid_configuration_route))
|
||||
.route("/.well-known/openid-configuration", get(client::openid_configuration_route))
|
||||
.route("/_tuwunel/oidc/registration", post(client::registration_route))
|
||||
.route("/_tuwunel/oidc/authorize", get(client::authorize_route))
|
||||
.route("/_tuwunel/oidc/_complete", get(client::complete_route))
|
||||
.route("/_tuwunel/oidc/token", post(client::token_route))
|
||||
.route("/_tuwunel/oidc/revoke", post(client::revoke_route))
|
||||
.route("/_tuwunel/oidc/jwks", get(client::jwks_route))
|
||||
.route("/_tuwunel/oidc/userinfo", get(client::userinfo_route))
|
||||
.route("/_tuwunel/oidc/account", get(client::account_route))
|
||||
.ruma_route(&client::room_initial_sync_route)
|
||||
.route("/client/server.json", get(client::syncv3_client_server_json));
|
||||
|
||||
|
||||
@@ -145,6 +145,22 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "oauthuniqid_oauthid",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidc_signingkey",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidcclientid_registration",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidccode_authsession",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidcreqid_authrequest",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "onetimekeyid_onetimekeys",
|
||||
..descriptor::RANDOM_SMALL
|
||||
|
||||
@@ -104,6 +104,7 @@ lru-cache.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
reqwest.workspace = true
|
||||
ring.workspace = true
|
||||
ruma.workspace = true
|
||||
rustls.workspace = true
|
||||
rustyline-async.workspace = true
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod oidc_server;
|
||||
pub mod providers;
|
||||
pub mod sessions;
|
||||
pub mod user_info;
|
||||
@@ -14,13 +15,14 @@ use ruma::UserId;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tuwunel_core::{
|
||||
Err, Result, err, implement,
|
||||
Err, Result, err, implement, info,
|
||||
utils::{hash::sha256, result::LogErr, stream::ReadyExt},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use self::{providers::Providers, sessions::Sessions};
|
||||
use self::{oidc_server::OidcServer, providers::Providers, sessions::Sessions};
|
||||
pub use self::{
|
||||
oidc_server::ProviderMetadata,
|
||||
providers::{Provider, ProviderId},
|
||||
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
|
||||
user_info::UserInfo,
|
||||
@@ -31,16 +33,30 @@ pub struct Service {
|
||||
services: SelfServices,
|
||||
pub providers: Arc<Providers>,
|
||||
pub sessions: Arc<Sessions>,
|
||||
/// OIDC server (authorization server) for next-gen Matrix auth.
|
||||
/// Only initialized when identity providers are configured.
|
||||
pub oidc_server: Option<Arc<OidcServer>>,
|
||||
}
|
||||
|
||||
impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let providers = Arc::new(Providers::build(args));
|
||||
let sessions = Arc::new(Sessions::build(args, providers.clone()));
|
||||
|
||||
let oidc_server = if !args.server.config.identity_provider.is_empty()
|
||||
|| args.server.config.well_known.client.is_some()
|
||||
{
|
||||
info!("Initializing OIDC server for next-gen auth (MSC2965)");
|
||||
Some(Arc::new(OidcServer::build(args)?))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
services: args.services.clone(),
|
||||
sessions,
|
||||
providers,
|
||||
oidc_server,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
199
src/service/oauth/oidc_server.rs
Normal file
199
src/service/oauth/oidc_server.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use std::{sync::Arc, time::{Duration, SystemTime}};
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
|
||||
use ring::{rand::SystemRandom, signature::{self, EcdsaKeyPair, KeyPair}};
|
||||
use ruma::OwnedUserId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{Err, Result, err, info, jwt, utils};
|
||||
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||
|
||||
const AUTH_CODE_LENGTH: usize = 64;
|
||||
const OIDC_CLIENT_ID_LENGTH: usize = 32;
|
||||
const AUTH_CODE_LIFETIME: Duration = Duration::from_secs(600);
|
||||
const AUTH_REQUEST_LIFETIME: Duration = Duration::from_secs(600);
|
||||
const SIGNING_KEY_DB_KEY: &str = "oidc_signing_key";
|
||||
|
||||
pub struct OidcServer { db: Data, signing_key_der: Vec<u8>, jwk: serde_json::Value, key_id: String }
|
||||
|
||||
struct Data { oidc_signingkey: Arc<Map>, oidcclientid_registration: Arc<Map>, oidccode_authsession: Arc<Map>, oidcreqid_authrequest: Arc<Map> }
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DcrRequest {
|
||||
pub redirect_uris: Vec<String>, pub client_name: Option<String>, pub client_uri: Option<String>,
|
||||
pub logo_uri: Option<String>, #[serde(default)] pub contacts: Vec<String>,
|
||||
pub token_endpoint_auth_method: Option<String>, pub grant_types: Option<Vec<String>>,
|
||||
pub response_types: Option<Vec<String>>, pub application_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct OidcClientRegistration {
|
||||
pub client_id: String, pub redirect_uris: Vec<String>, pub client_name: Option<String>,
|
||||
pub client_uri: Option<String>, pub logo_uri: Option<String>, pub contacts: Vec<String>,
|
||||
pub token_endpoint_auth_method: String, pub grant_types: Vec<String>, pub response_types: Vec<String>,
|
||||
pub application_type: Option<String>, pub registered_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthCodeSession {
|
||||
pub code: String, pub client_id: String, pub redirect_uri: String, pub scope: String,
|
||||
pub state: Option<String>, pub nonce: Option<String>, pub code_challenge: Option<String>,
|
||||
pub code_challenge_method: Option<String>, pub user_id: OwnedUserId,
|
||||
pub created_at: SystemTime, pub expires_at: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct OidcAuthRequest {
|
||||
pub client_id: String, pub redirect_uri: String, pub scope: String, pub state: Option<String>,
|
||||
pub nonce: Option<String>, pub code_challenge: Option<String>, pub code_challenge_method: Option<String>,
|
||||
pub created_at: SystemTime, pub expires_at: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SigningKeyData { key_der: Vec<u8>, key_id: String }
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProviderMetadata {
|
||||
pub issuer: String, pub authorization_endpoint: String, pub token_endpoint: String,
|
||||
pub registration_endpoint: Option<String>, pub revocation_endpoint: Option<String>, pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>, pub account_management_uri: Option<String>,
|
||||
pub account_management_actions_supported: Option<Vec<String>>, pub response_types_supported: Vec<String>,
|
||||
pub response_modes_supported: Option<Vec<String>>, pub grant_types_supported: Option<Vec<String>>,
|
||||
pub code_challenge_methods_supported: Option<Vec<String>>, pub token_endpoint_auth_methods_supported: Option<Vec<String>>,
|
||||
pub scopes_supported: Option<Vec<String>>, pub subject_types_supported: Option<Vec<String>>,
|
||||
pub id_token_signing_alg_values_supported: Option<Vec<String>>, pub prompt_values_supported: Option<Vec<String>>,
|
||||
pub claim_types_supported: Option<Vec<String>>, pub claims_supported: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct IdTokenClaims {
|
||||
pub iss: String, pub sub: String, pub aud: String, pub exp: u64, pub iat: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] pub nonce: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] pub at_hash: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderMetadata { #[must_use] pub fn into_json(self) -> serde_json::Value { serde_json::to_value(self).unwrap() } }
|
||||
|
||||
impl OidcServer {
|
||||
pub(crate) fn build(args: &crate::Args<'_>) -> Result<Self> {
|
||||
let db = Data {
|
||||
oidc_signingkey: args.db["oidc_signingkey"].clone(),
|
||||
oidcclientid_registration: args.db["oidcclientid_registration"].clone(),
|
||||
oidccode_authsession: args.db["oidccode_authsession"].clone(),
|
||||
oidcreqid_authrequest: args.db["oidcreqid_authrequest"].clone(),
|
||||
};
|
||||
|
||||
let (signing_key_der, key_id) = match db.oidc_signingkey.get_blocking(SIGNING_KEY_DB_KEY).and_then(|handle| handle.deserialized::<Cbor<SigningKeyData>>().map(|cbor| cbor.0)) {
|
||||
| Ok(data) => { info!("Loaded existing OIDC signing key (kid={})", data.key_id); (data.key_der, data.key_id) },
|
||||
| Err(_) => {
|
||||
let (key_der, key_id) = Self::generate_signing_key()?;
|
||||
info!("Generated new OIDC signing key (kid={key_id})");
|
||||
let data = SigningKeyData { key_der: key_der.clone(), key_id: key_id.clone() };
|
||||
db.oidc_signingkey.raw_put(SIGNING_KEY_DB_KEY, Cbor(&data));
|
||||
(key_der, key_id)
|
||||
},
|
||||
};
|
||||
|
||||
let jwk = Self::build_jwk(&signing_key_der, &key_id)?;
|
||||
Ok(Self { db, signing_key_der, jwk, key_id })
|
||||
}
|
||||
|
||||
fn generate_signing_key() -> Result<(Vec<u8>, String)> {
|
||||
let rng = SystemRandom::new();
|
||||
let alg = &signature::ECDSA_P256_SHA256_FIXED_SIGNING;
|
||||
let pkcs8 = EcdsaKeyPair::generate_pkcs8(alg, &rng).map_err(|e| err!(error!("Failed to generate ECDSA key: {e}")))?;
|
||||
let key_id = utils::random_string(16);
|
||||
Ok((pkcs8.as_ref().to_vec(), key_id))
|
||||
}
|
||||
|
||||
fn build_jwk(signing_key_der: &[u8], key_id: &str) -> Result<serde_json::Value> {
|
||||
let rng = SystemRandom::new();
|
||||
let alg = &signature::ECDSA_P256_SHA256_FIXED_SIGNING;
|
||||
let key_pair = EcdsaKeyPair::from_pkcs8(alg, signing_key_der, &rng).map_err(|e| err!(error!("Failed to load ECDSA key: {e}")))?;
|
||||
let public_bytes = key_pair.public_key().as_ref();
|
||||
let x = b64.encode(&public_bytes[1..33]);
|
||||
let y = b64.encode(&public_bytes[33..65]);
|
||||
Ok(serde_json::json!({"kty": "EC", "crv": "P-256", "use": "sig", "alg": "ES256", "kid": key_id, "x": x, "y": y}))
|
||||
}
|
||||
|
||||
pub fn register_client(&self, request: DcrRequest) -> Result<OidcClientRegistration> {
|
||||
let client_id = utils::random_string(OIDC_CLIENT_ID_LENGTH);
|
||||
let auth_method = request.token_endpoint_auth_method.unwrap_or_else(|| "none".to_owned());
|
||||
let registration = OidcClientRegistration {
|
||||
client_id: client_id.clone(), redirect_uris: request.redirect_uris, client_name: request.client_name,
|
||||
client_uri: request.client_uri, logo_uri: request.logo_uri, contacts: request.contacts,
|
||||
token_endpoint_auth_method: auth_method,
|
||||
grant_types: request.grant_types.unwrap_or_else(|| vec!["authorization_code".to_owned(), "refresh_token".to_owned()]),
|
||||
response_types: request.response_types.unwrap_or_else(|| vec!["code".to_owned()]),
|
||||
application_type: request.application_type,
|
||||
registered_at: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap_or_default().as_secs(),
|
||||
};
|
||||
self.db.oidcclientid_registration.raw_put(&*client_id, Cbor(®istration));
|
||||
Ok(registration)
|
||||
}
|
||||
|
||||
pub async fn get_client(&self, client_id: &str) -> Result<OidcClientRegistration> {
|
||||
self.db.oidcclientid_registration.get(client_id).await.deserialized::<Cbor<_>>().map(|cbor: Cbor<OidcClientRegistration>| cbor.0).map_err(|_| err!(Request(NotFound("Unknown client_id"))))
|
||||
}
|
||||
|
||||
pub async fn validate_redirect_uri(&self, client_id: &str, redirect_uri: &str) -> Result {
|
||||
let client = self.get_client(client_id).await?;
|
||||
if client.redirect_uris.iter().any(|uri| uri == redirect_uri) { Ok(()) } else { Err!(Request(InvalidParam("redirect_uri not registered for this client"))) }
|
||||
}
|
||||
|
||||
pub fn store_auth_request(&self, req_id: &str, request: &OidcAuthRequest) { self.db.oidcreqid_authrequest.raw_put(req_id, Cbor(request)); }
|
||||
|
||||
pub async fn take_auth_request(&self, req_id: &str) -> Result<OidcAuthRequest> {
|
||||
let request: OidcAuthRequest = self.db.oidcreqid_authrequest.get(req_id).await.deserialized::<Cbor<_>>().map(|cbor: Cbor<OidcAuthRequest>| cbor.0).map_err(|_| err!(Request(NotFound("Unknown or expired authorization request"))))?;
|
||||
self.db.oidcreqid_authrequest.remove(req_id);
|
||||
if SystemTime::now() > request.expires_at { return Err!(Request(NotFound("Authorization request has expired"))); }
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn create_auth_code(&self, auth_req: &OidcAuthRequest, user_id: OwnedUserId) -> String {
|
||||
let code = utils::random_string(AUTH_CODE_LENGTH);
|
||||
let now = SystemTime::now();
|
||||
let session = AuthCodeSession {
|
||||
code: code.clone(), client_id: auth_req.client_id.clone(), redirect_uri: auth_req.redirect_uri.clone(),
|
||||
scope: auth_req.scope.clone(), state: auth_req.state.clone(), nonce: auth_req.nonce.clone(),
|
||||
code_challenge: auth_req.code_challenge.clone(), code_challenge_method: auth_req.code_challenge_method.clone(),
|
||||
user_id, created_at: now, expires_at: now.checked_add(AUTH_CODE_LIFETIME).unwrap_or(now),
|
||||
};
|
||||
self.db.oidccode_authsession.raw_put(&*code, Cbor(&session));
|
||||
code
|
||||
}
|
||||
|
||||
pub async fn exchange_auth_code(&self, code: &str, client_id: &str, redirect_uri: &str, code_verifier: Option<&str>) -> Result<AuthCodeSession> {
|
||||
let session: AuthCodeSession = self.db.oidccode_authsession.get(code).await.deserialized::<Cbor<_>>().map(|cbor: Cbor<AuthCodeSession>| cbor.0).map_err(|_| err!(Request(Forbidden("Invalid or expired authorization code"))))?;
|
||||
self.db.oidccode_authsession.remove(code);
|
||||
if SystemTime::now() > session.expires_at { return Err!(Request(Forbidden("Authorization code has expired"))); }
|
||||
if session.client_id != client_id { return Err!(Request(Forbidden("client_id mismatch"))); }
|
||||
if session.redirect_uri != redirect_uri { return Err!(Request(Forbidden("redirect_uri mismatch"))); }
|
||||
|
||||
if let Some(challenge) = &session.code_challenge {
|
||||
let Some(verifier) = code_verifier else { return Err!(Request(Forbidden("code_verifier required for PKCE"))); };
|
||||
let method = session.code_challenge_method.as_deref().unwrap_or("S256");
|
||||
let computed = match method {
|
||||
| "S256" => { let hash = utils::hash::sha256::hash(verifier.as_bytes()); b64.encode(hash) },
|
||||
| "plain" => verifier.to_owned(),
|
||||
| _ => return Err!(Request(InvalidParam("Unsupported code_challenge_method"))),
|
||||
};
|
||||
if computed != *challenge { return Err!(Request(Forbidden("PKCE verification failed"))); }
|
||||
}
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub fn sign_id_token(&self, claims: &IdTokenClaims) -> Result<String> {
|
||||
let mut header = jwt::Header::new(jwt::Algorithm::ES256);
|
||||
header.kid = Some(self.key_id.clone());
|
||||
let key = jwt::EncodingKey::from_ec_der(&self.signing_key_der);
|
||||
jwt::encode(&header, claims, &key).map_err(|e| err!(error!("Failed to sign ID token: {e}")))
|
||||
}
|
||||
|
||||
#[must_use] pub fn jwks(&self) -> serde_json::Value { serde_json::json!({"keys": [self.jwk.clone()]}) }
|
||||
|
||||
#[must_use] pub fn at_hash(access_token: &str) -> String { let hash = utils::hash::sha256::hash(access_token.as_bytes()); b64.encode(&hash[..16]) }
|
||||
|
||||
#[must_use] pub fn auth_request_lifetime() -> Duration { AUTH_REQUEST_LIFETIME }
|
||||
}
|
||||
Reference in New Issue
Block a user