diff --git a/src/service/federation/execute.rs b/src/service/federation/execute.rs index 69c765ee..8f7f347c 100644 --- a/src/service/federation/execute.rs +++ b/src/service/federation/execute.rs @@ -7,8 +7,8 @@ use reqwest::{Client, Method, Request, Response, Url}; use ruma::{ CanonicalJsonName, CanonicalJsonObject, CanonicalJsonValue, ServerName, ServerSigningKeyId, api::{ - EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, - SupportedVersions, client::error::Error as RumaError, + AuthScheme, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, + SendAccessToken, SupportedVersions, client::error::Error as RumaError, federation::authentication::XMatrix, }, serde::Base64, @@ -48,10 +48,10 @@ where #[implement(super::Service)] #[tracing::instrument( - name = "fed", - level = INFO_SPAN_LEVEL, - skip(self, client, request), - )] + name = "fed", + level = INFO_SPAN_LEVEL, + skip(self, client, request), +)] pub async fn execute_on( &self, client: &Client, @@ -81,17 +81,16 @@ where .get_actual_dest(dest) .await?; - let request = into_http_request::(&actual, request)?; - let request = self.prepare(dest, request)?; - self.perform::(dest, &actual, request, client) + let request = self.prepare(&actual, dest, request)?; + self.perform::(&actual, dest, request, client) .await } #[implement(super::Service)] async fn perform( &self, - dest: &ServerName, actual: &ActualDest, + dest: &ServerName, request: Request, client: &Client, ) -> Result @@ -103,7 +102,7 @@ where debug!(?method, ?url, "Sending request"); match client.execute(request).await { - | Ok(response) => handle_response::(dest, actual, &method, &url, response).await, + | Ok(response) => handle_response::(actual, dest, &method, &url, response).await, | Err(error) => Err(self .handle_error(dest, actual, &method, &url, error) .expect_err("always returns error")), @@ -111,9 +110,11 @@ where } #[implement(super::Service)] -fn prepare(&self, dest: &ServerName, mut request: http::Request>) -> Result { - self.sign_request(&mut request, dest); - +fn prepare(&self, actual: &ActualDest, dest: &ServerName, request: T) -> Result +where + T: OutgoingRequest + Send, +{ + let request = self.to_http_request::(actual, dest, request)?; let request = Request::try_from(request)?; self.validate_url(request.url())?; self.services.server.check_running()?; @@ -134,8 +135,8 @@ fn validate_url(&self, url: &Url) -> Result { } async fn handle_response( - dest: &ServerName, actual: &ActualDest, + dest: &ServerName, method: &Method, url: &Url, response: Response, @@ -230,6 +231,34 @@ fn handle_error( Err(e.into()) } +#[implement(super::Service)] +fn to_http_request( + &self, + actual: &ActualDest, + dest: &ServerName, + request: T, +) -> Result>> +where + T: OutgoingRequest + Send, +{ + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; + const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); + let supported = SupportedVersions { + versions: VERSIONS.into(), + features: Default::default(), + }; + + let mut request = request + .try_into_http_request::>(actual.string().as_str(), SATIR, &supported) + .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; + + if matches!(T::METADATA.authentication, AuthScheme::ServerSignatures) { + self.sign_request(&mut request, dest); + } + + Ok(request) +} + #[implement(super::Service)] fn sign_request(&self, http_request: &mut http::Request>, dest: &ServerName) { type Member = (CanonicalJsonName, Value); @@ -300,21 +329,3 @@ fn sign_request(&self, http_request: &mut http::Request>, dest: &ServerN debug_assert!(authorization.is_none(), "Authorization header already present"); } - -fn into_http_request(actual: &ActualDest, request: T) -> Result>> -where - T: OutgoingRequest + Send, -{ - const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_11]; - const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); - let supported = SupportedVersions { - versions: VERSIONS.into(), - features: Default::default(), - }; - - let http_request = request - .try_into_http_request::>(actual.string().as_str(), SATIR, &supported) - .map_err(|e| err!(BadServerResponse("Invalid destination: {e:?}")))?; - - Ok(http_request) -}