From cd8648dce5f581e4d6160fc8e888a73b282e6a60 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 16 Jun 2025 03:43:22 +0000 Subject: [PATCH] Auth related cleanups. Cleanup; additional error macros. Signed-off-by: Jason Volk --- src/api/client/session/mod.rs | 24 ++++++---- src/api/router/auth.rs | 31 ++++++------ src/service/uiaa/mod.rs | 88 ++++++++++++----------------------- 3 files changed, 57 insertions(+), 86 deletions(-) diff --git a/src/api/client/session/mod.rs b/src/api/client/session/mod.rs index 450ec43f..12a7ac87 100644 --- a/src/api/client/session/mod.rs +++ b/src/api/client/session/mod.rs @@ -81,7 +81,7 @@ pub(crate) async fn login_route( }; // Generate a new token for the device - let token = utils::random_string(TOKEN_LENGTH); + let access_token = utils::random_string(TOKEN_LENGTH); // Generate new device id if the user didn't specify one let device_id = body @@ -102,7 +102,7 @@ pub(crate) async fn login_route( .create_device( &user_id, &device_id, - &token, + &access_token, body.initial_device_display_name.clone(), Some(client.to_string()), ) @@ -110,28 +110,32 @@ pub(crate) async fn login_route( } else { services .users - .set_token(&user_id, &device_id, &token) + .set_token(&user_id, &device_id, &access_token) .await?; } + info!("{user_id} logged in"); + + let home_server = services.server.name.clone().into(); + // send client well-known if specified so the client knows to reconfigure itself - let client_discovery_info: Option = services + let well_known: Option = services .config .well_known .client .as_ref() - .map(|server| DiscoveryInfo::new(HomeserverInfo::new(server.to_string()))); - - info!("{user_id} logged in"); + .map(ToString::to_string) + .map(HomeserverInfo::new) + .map(DiscoveryInfo::new); #[allow(deprecated)] Ok(login::v3::Response { user_id, - access_token: token, + access_token, device_id, - well_known: client_discovery_info, + home_server, + well_known, expires_in: None, - home_server: Some(services.config.server_name.clone()), refresh_token: None, }) } diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index b999cce4..2c4553d9 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -50,6 +50,7 @@ pub(super) async fn auth( ) -> Result { let bearer: Option>> = request.parts.extract().await.unwrap_or(None); + let token = match &bearer { | Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), | None => request.query.access_token.as_deref(), @@ -81,10 +82,9 @@ pub(super) async fn auth( // already }, | Token::None | Token::Invalid => { - return Err(Error::BadRequest( - ErrorKind::MissingToken, - "Missing or invalid access token.", - )); + return Err!(Request(MissingToken( + "Missing or invalid access token." + ))); }, } } @@ -105,10 +105,9 @@ pub(super) async fn auth( // already }, | Token::None | Token::Invalid => { - return Err(Error::BadRequest( - ErrorKind::MissingToken, - "Missing or invalid access token.", - )); + return Err!(Request(MissingToken( + "Missing or invalid access token." + ))); }, } } @@ -139,10 +138,10 @@ pub(super) async fn auth( appservice_info: None, }) } else { - Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")) + Err!(Request(MissingToken("Missing access token."))) } }, - | _ => Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")), + | _ => Err!(Request(MissingToken("Missing access token."))), }, | ( AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, @@ -165,14 +164,10 @@ pub(super) async fn auth( appservice_info: None, }), | (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => - Err(Error::BadRequest( - ErrorKind::Unauthorized, - "Only server signatures should be used on this endpoint.", - )), - | (AuthScheme::AppserviceToken, Token::User(_)) => Err(Error::BadRequest( - ErrorKind::Unauthorized, - "Only appservice access tokens should be used on this endpoint.", - )), + Err!(Request(Unauthorized("Only server signatures should be used on this endpoint."))), + | (AuthScheme::AppserviceToken, Token::User(_)) => Err!(Request(Unauthorized( + "Only appservice access tokens should be used on this endpoint." + ))), | (AuthScheme::None, Token::Invalid) => { // OpenID federation endpoint uses a query param with the same name, drop this // once query params for user auth are removed from the spec. This is diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index e33b3f89..c180211d 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -6,12 +6,12 @@ use std::{ use ruma::{ CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId, api::client::{ - error::ErrorKind, + error::{ErrorKind, StandardErrorBody}, uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, }, }; use tuwunel_core::{ - Err, Error, Result, err, error, implement, utils, + Err, Result, err, error, implement, utils, utils::{hash, string::EMPTY}, }; use tuwunel_database::{Deserialized, Json, Map}; @@ -67,14 +67,15 @@ pub async fn read_tokens(&self) -> Result> { .as_ref() { match std::fs::read_to_string(file) { + | Err(e) => error!("Failed to read the registration token file: {e}"), | Ok(text) => { text.split_ascii_whitespace().for_each(|token| { tokens.insert(token.to_owned()); }); }, - | Err(e) => error!("Failed to read the registration token file: {e}"), } } + if let Some(token) = &self.services.config.registration_token { tokens.insert(token.to_owned()); } @@ -93,25 +94,14 @@ pub fn create( ) { // TODO: better session error handling (why is uiaainfo.session optional in // ruma?) - self.set_uiaa_request( - user_id, - device_id, - uiaainfo - .session - .as_ref() - .expect("session should be set"), - json_body, - ); + let session = uiaainfo + .session + .as_ref() + .expect("session should be set"); - self.update_uiaa_session( - user_id, - device_id, - uiaainfo - .session - .as_ref() - .expect("session should be set"), - Some(uiaainfo), - ); + self.set_uiaa_request(user_id, device_id, session, json_body); + + self.update_uiaa_session(user_id, device_id, session, Some(uiaainfo)); } #[implement(Service)] @@ -148,40 +138,35 @@ pub async fn try_auth( } else if let Some(username) = user { username } else { - return Err(Error::BadRequest( - ErrorKind::Unrecognized, - "Identifier type not recognized.", - )); + return Err!(Request(Unrecognized("Identifier type not recognized."))); }; #[cfg(not(feature = "element_hacks"))] let Some(UserIdentifier::UserIdOrLocalpart(username)) = identifier else { - return Err(Error::BadRequest( - ErrorKind::Unrecognized, - "Identifier type not recognized.", - )); + return Err!(Request(Unrecognized("Identifier type not recognized."))); }; let user_id_from_username = UserId::parse_with_server_name( username.clone(), self.services.globals.server_name(), ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; + .map_err(|_| err!(Request(InvalidParam("User ID is invalid."))))?; // Check if the access token being used matches the credentials used for UIAA if user_id.localpart() != user_id_from_username.localpart() { return Err!(Request(Forbidden("User ID and access token mismatch."))); } - let user_id = user_id_from_username; // Check if password is correct + let user_id = user_id_from_username; if let Ok(hash) = self.services.users.password_hash(&user_id).await { let hash_matches = hash::verify_password(password, &hash).is_ok(); if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + uiaainfo.auth_error = Some(StandardErrorBody { kind: ErrorKind::forbidden(), message: "Invalid username or password.".to_owned(), }); + return Ok((false, uiaainfo)); } } @@ -196,10 +181,11 @@ pub async fn try_auth( .completed .push(AuthType::RegistrationToken); } else { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + uiaainfo.auth_error = Some(StandardErrorBody { kind: ErrorKind::forbidden(), message: "Invalid registration token.".to_owned(), }); + return Ok((false, uiaainfo)); } }, @@ -221,30 +207,19 @@ pub async fn try_auth( completed = true; } + let session = uiaainfo + .session + .as_ref() + .expect("session is always set"); + if !completed { - self.update_uiaa_session( - user_id, - device_id, - uiaainfo - .session - .as_ref() - .expect("session is always set"), - Some(&uiaainfo), - ); + self.update_uiaa_session(user_id, device_id, session, Some(&uiaainfo)); return Ok((false, uiaainfo)); } // UIAA was successful! Remove this session and return true - self.update_uiaa_session( - user_id, - device_id, - uiaainfo - .session - .as_ref() - .expect("session is always set"), - None, - ); + self.update_uiaa_session(user_id, device_id, session, None); Ok((true, uiaainfo)) } @@ -258,6 +233,7 @@ fn set_uiaa_request( request: &CanonicalJsonValue, ) { let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); + self.userdevicesessionid_uiaarequest .write() .expect("locked for writing") @@ -271,13 +247,8 @@ pub fn get_uiaa_request( device_id: Option<&DeviceId>, session: &str, ) -> Option { - let key = ( - user_id.to_owned(), - device_id - .unwrap_or_else(|| EMPTY.into()) - .to_owned(), - session.to_owned(), - ); + let device_id = device_id.unwrap_or_else(|| EMPTY.into()); + let key = (user_id.to_owned(), device_id.to_owned(), session.to_owned()); self.userdevicesessionid_uiaarequest .read() @@ -313,6 +284,7 @@ async fn get_uiaa_session( session: &str, ) -> Result { let key = (user_id, device_id, session); + self.db .userdevicesessionid_uiaainfo .qry(&key)