diff --git a/src/api/server/query.rs b/src/api/server/query.rs index 8e600d75..d82a3ab5 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,17 +1,15 @@ -use std::collections::BTreeMap; - use axum::extract::State; -use futures::StreamExt; +use futures::{ + StreamExt, + future::{join, join5}, +}; use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ OwnedServerName, - api::{ - client::error::ErrorKind, - federation::query::{get_profile_information, get_room_information}, - }, + api::federation::query::{get_profile_information, get_room_information}, }; -use tuwunel_core::{Error, Result, err}; +use tuwunel_core::{Err, Result, err, utils::future::TryExtExt}; use crate::Ruma; @@ -60,48 +58,38 @@ pub(crate) async fn get_profile_information_route( State(services): State, body: Ruma, ) -> Result { + use get_profile_information::v1::Response; + if !services .server .config .allow_inbound_profile_lookup_federation_requests { - return Err(Error::BadRequest( - ErrorKind::forbidden(), + return Err!(Request(Forbidden( "Profile lookup over federation is not allowed on this homeserver.", - )); + ))); } if !services .globals .server_is_ours(body.user_id.server_name()) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not belong to this server.", - )); + return Err!(Request(InvalidParam("User does not belong to this server.",))); } - let mut displayname = None; - let mut avatar_url = None; - let mut blurhash = None; - let mut tz = None; - let mut custom_profile_fields = BTreeMap::new(); - match &body.field { - | Some(ProfileField::DisplayName) => { - displayname = services - .users - .displayname(&body.user_id) - .await - .ok(); - }, - | Some(ProfileField::AvatarUrl) => { - avatar_url = services - .users - .avatar_url(&body.user_id) - .await - .ok(); - blurhash = services.users.blurhash(&body.user_id).await.ok(); + | Some(ProfileField::AvatarUrl | ProfileField::DisplayName) => { + let avatar_url = services.users.avatar_url(&body.user_id).ok(); + + let displayname = services.users.displayname(&body.user_id).ok(); + + let (avatar_url, displayname) = join(avatar_url, displayname).await; + + Ok(Response { + avatar_url, + displayname, + ..Response::default() + }) }, | Some(custom_field) => { if let Ok(value) = services @@ -109,39 +97,39 @@ pub(crate) async fn get_profile_information_route( .profile_key(&body.user_id, custom_field.as_str()) .await { - custom_profile_fields.insert(custom_field.to_string(), value); + Ok(Response { + custom_profile_fields: [(custom_field.to_string(), value)].into(), + ..Response::default() + }) + } else { + Ok(Response::default()) } }, | None => { - displayname = services - .users - .displayname(&body.user_id) - .await - .ok(); - avatar_url = services - .users - .avatar_url(&body.user_id) - .await - .ok(); - blurhash = services.users.blurhash(&body.user_id).await.ok(); - tz = services.users.timezone(&body.user_id).await.ok(); - custom_profile_fields = services + let avatar_url = services.users.avatar_url(&body.user_id).ok(); + + let blurhash = services.users.blurhash(&body.user_id).ok(); + + let displayname = services.users.displayname(&body.user_id).ok(); + + let tz = services.users.timezone(&body.user_id).ok(); + + let custom_profile_fields = services .users .all_profile_keys(&body.user_id) - .collect() - .await; + .collect(); + + let (avatar_url, blurhash, custom_profile_fields, displayname, tz) = + join5(avatar_url, blurhash, custom_profile_fields, displayname, tz).await; + + Ok(Response { + avatar_url, + blurhash, + custom_profile_fields, + displayname, + tz, + ..Response::default() + }) }, } - - // services.users.timezone will collect the MSC4175 timezone key if it exists - custom_profile_fields.remove("us.cloke.msc4175.tz"); - custom_profile_fields.remove("m.tz"); - - Ok(get_profile_information::v1::Response { - displayname, - avatar_url, - blurhash, - tz, - custom_profile_fields, - }) }