Use form-urlencoded bodies for server-to-server oauth requests. (fixes #249)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-02 04:11:37 +00:00
parent fbf66f565a
commit 2a7455b5c9

View File

@@ -5,7 +5,10 @@ pub mod user_info;
use std::sync::Arc; use std::sync::Arc;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode}; use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64encode};
use reqwest::{Method, header::ACCEPT}; use reqwest::{
Method,
header::{ACCEPT, CONTENT_TYPE},
};
use ruma::UserId; use ruma::UserId;
use serde::Serialize; use serde::Serialize;
use serde_json::Value as JsonValue; use serde_json::Value as JsonValue;
@@ -49,12 +52,15 @@ pub async fn request_userinfo(
&self, &self,
(provider, session): (&Provider, &Session), (provider, session): (&Provider, &Session),
) -> Result<UserInfo> { ) -> Result<UserInfo> {
#[derive(Debug, Serialize)]
struct Query;
let url = provider let url = provider
.userinfo_url .userinfo_url
.clone() .clone()
.ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?; .ok_or_else(|| err!(Config("userinfo_url", "Missing userinfo URL in config")))?;
self.request((Some(provider), Some(session)), Method::GET, url) self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
.await .await
.and_then(|value| serde_json::from_value(value).map_err(Into::into)) .and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err() .log_err()
@@ -66,6 +72,9 @@ pub async fn request_tokeninfo(
&self, &self,
(provider, session): (&Provider, &Session), (provider, session): (&Provider, &Session),
) -> Result<UserInfo> { ) -> Result<UserInfo> {
#[derive(Debug, Serialize)]
struct Query;
let url = provider let url = provider
.introspection_url .introspection_url
.clone() .clone()
@@ -73,7 +82,7 @@ pub async fn request_tokeninfo(
err!(Config("introspection_url", "Missing introspection URL in config")) err!(Config("introspection_url", "Missing introspection URL in config"))
})?; })?;
self.request((Some(provider), Some(session)), Method::GET, url) self.request((Some(provider), Some(session)), Method::GET, url, Option::<Query>::None)
.await .await
.and_then(|value| serde_json::from_value(value).map_err(Into::into)) .and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err() .log_err()
@@ -93,17 +102,12 @@ pub async fn revoke_token(&self, (provider, session): (&Provider, &Session)) ->
client_secret: &provider.client_secret, client_secret: &provider.client_secret,
}; };
let query = serde_html_form::to_string(&query)?;
let url = provider let url = provider
.revocation_url .revocation_url
.clone() .clone()
.map(|mut url| {
url.set_query(Some(&query));
url
})
.ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?; .ok_or_else(|| err!(Config("revocation_url", "Missing revocation URL in config")))?;
self.request((Some(provider), Some(session)), Method::POST, url) self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
.await .await
.log_err() .log_err()
.map(|_| ()) .map(|_| ())
@@ -135,17 +139,12 @@ pub async fn request_token(
redirect_uri: provider.callback_url.as_ref().map(Url::as_str), redirect_uri: provider.callback_url.as_ref().map(Url::as_str),
}; };
let query = serde_html_form::to_string(&query)?;
let url = provider let url = provider
.token_url .token_url
.clone() .clone()
.map(|mut url| {
url.set_query(Some(&query));
url
})
.ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?; .ok_or_else(|| err!(Config("token_url", "Missing token URL in config")))?;
self.request((Some(provider), Some(session)), Method::POST, url) self.request((Some(provider), Some(session)), Method::POST, url, Some(query))
.await .await
.and_then(|value| serde_json::from_value(value).map_err(Into::into)) .and_then(|value| serde_json::from_value(value).map_err(Into::into))
.log_err() .log_err()
@@ -160,14 +159,18 @@ pub async fn request_token(
name = "request", name = "request",
level = "debug", level = "debug",
ret(level = "trace"), ret(level = "trace"),
skip(self) skip(self, body)
)] )]
pub async fn request( pub async fn request<Body>(
&self, &self,
(provider, session): (Option<&Provider>, Option<&Session>), (provider, session): (Option<&Provider>, Option<&Session>),
method: Method, method: Method,
url: Url, url: Url,
) -> Result<JsonValue> { body: Option<Body>,
) -> Result<JsonValue>
where
Body: Serialize,
{
let mut request = self let mut request = self
.services .services
.client .client
@@ -175,6 +178,12 @@ pub async fn request(
.request(method, url) .request(method, url)
.header(ACCEPT, "application/json"); .header(ACCEPT, "application/json");
if let Some(body) = body.map(serde_html_form::to_string).transpose()? {
request = request
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(body);
}
if let Some(session) = session { if let Some(session) = session {
if let Some(access_token) = session.access_token.clone() { if let Some(access_token) = session.access_token.clone() {
request = request.bearer_auth(access_token); request = request.bearer_auth(access_token);