Files
tuwunel/src/api/client/session/jwt.rs

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

118 lines
3.2 KiB
Rust
Raw Normal View History

use std::str::FromStr;
use jwt::{Algorithm, DecodingKey, Validation, decode};
use ruma::{
OwnedUserId, UserId,
api::client::session::login::v3::{Request, Token},
};
use serde::Deserialize;
use tuwunel_core::{Err, Result, at, config::JwtConfig, debug, err, jwt, warn};
use tuwunel_service::Services;
use crate::Ruma;
#[derive(Debug, Deserialize)]
struct Claim {
/// Subject is the localpart of the User MXID
sub: String,
}
pub(super) async fn handle_login(
services: &Services,
_body: &Ruma<Request>,
info: &Token,
) -> Result<OwnedUserId> {
let config = &services.config.jwt;
if !config.enable {
return Err!(Request(Unknown("JWT login is not enabled.")));
}
let claim = validate(config, &info.token)?;
let local = claim.sub.to_lowercase();
let server = &services.server.name;
let user_id = UserId::parse_with_server_name(local, server).map_err(|e| {
err!(Request(InvalidUsername("JWT subject is not a valid user MXID: {e}")))
})?;
if !services.users.exists(&user_id).await {
if !config.register_user {
return Err!(Request(NotFound("User {user_id} is not registered on this server.")));
}
services
.users
.create(&user_id, Some("*"), Some("jwt"))
.await?;
}
Ok(user_id)
}
fn validate(config: &JwtConfig, token: &str) -> Result<Claim> {
let verifier = init_verifier(config)?;
let validator = init_validator(config)?;
decode::<Claim>(token, &verifier, &validator)
.map(|decoded| (decoded.header, decoded.claims))
.inspect(|(head, claim)| debug!(?head, ?claim, "JWT token decoded"))
.map_err(|e| err!(Request(Forbidden("Invalid JWT token: {e}"))))
.map(at!(1))
}
fn init_verifier(config: &JwtConfig) -> Result<DecodingKey> {
let key = &config.key;
let format = config.format.as_str();
Ok(match format {
| "HMAC" => DecodingKey::from_secret(key.as_bytes()),
| "HMACB64" => DecodingKey::from_base64_secret(key.as_str())
.map_err(|e| err!(Config("jwt.key", "JWT key is not valid base64: {e}")))?,
| "ECDSA" => DecodingKey::from_ec_pem(key.as_bytes())
.map_err(|e| err!(Config("jwt.key", "JWT key is not valid PEM: {e}")))?,
| _ => return Err!(Config("jwt.format", "Key format {format:?} is not supported.")),
})
}
fn init_validator(config: &JwtConfig) -> Result<Validation> {
let alg = config.algorithm.as_str();
let alg = Algorithm::from_str(alg).map_err(|e| {
err!(Config("jwt.algorithm", "JWT algorithm is not recognized or configured {e}"))
})?;
let mut validator = Validation::new(alg);
let mut required_spec_claims: Vec<_> = ["sub"].into();
validator.validate_exp = config.validate_exp;
if config.require_exp {
required_spec_claims.push("exp");
}
validator.validate_nbf = config.validate_nbf;
if config.require_nbf {
required_spec_claims.push("nbf");
}
if !config.audience.is_empty() {
required_spec_claims.push("aud");
validator.set_audience(&config.audience);
}
if !config.issuer.is_empty() {
required_spec_claims.push("iss");
validator.set_issuer(&config.issuer);
}
if cfg!(debug_assertions) && !config.validate_signature {
warn!("JWT signature validation is disabled!");
validator.insecure_disable_signature_validation();
}
validator.set_required_spec_claims(&required_spec_claims);
debug!(?validator, "JWT configured");
Ok(validator)
}