diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 0ed79023..77d8d701 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -3,17 +3,16 @@ use std::borrow::Borrow; use axum::extract::State; -use futures::{FutureExt, StreamExt, TryStreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join3}; use ruma::{ - CanonicalJsonValue, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, - ServerName, + OwnedEventId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, api::federation::membership::create_join_event, events::{ StateEventType, room::member::{MembershipState, RoomMemberEventContent}, }, }; -use serde_json::value::{RawValue as RawJsonValue, to_raw_value}; +use serde_json::value::RawValue as RawJsonValue; use tuwunel_core::{ Err, Result, at, err, matrix::event::gen_event_id_canonical_json, @@ -212,20 +211,6 @@ async fn create_join_event( .collect() .await; - let state = state_ids - .iter() - .try_stream() - .broad_and_then(|event_id| services.timeline.get_pdu_json(event_id)) - .broad_and_then(|pdu| { - services - .federation - .format_pdu_into(pdu, Some(&room_version_id)) - .map(Ok) - }) - .try_collect() - .boxed() - .await?; - let starting_events = state_ids.iter().map(Borrow::borrow); let auth_chain = services .auth_chain @@ -237,20 +222,34 @@ async fn create_join_event( .format_pdu_into(pdu, Some(&room_version_id)) .map(Ok) }) - .try_collect() - .boxed() - .await?; + .try_collect(); + + let state = state_ids + .iter() + .try_stream() + .broad_and_then(|event_id| services.timeline.get_pdu_json(event_id)) + .broad_and_then(|pdu| { + services + .federation + .format_pdu_into(pdu, Some(&room_version_id)) + .map(Ok) + }) + .try_collect(); + + let event = services + .federation + .format_pdu_into(value, Some(&room_version_id)) + .map(Some) + .map(Ok); + + let (auth_chain, state, event) = try_join3(auth_chain, state, event).await?; services .sending .send_pdu_room(room_id, &pdu_id) .await?; - Ok(create_join_event::v1::RoomState { - auth_chain, - state, - event: to_raw_value(&CanonicalJsonValue::Object(value)).ok(), - }) + Ok(create_join_event::v1::RoomState { auth_chain, state, event }) } /// # `PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}` @@ -271,6 +270,7 @@ pub(crate) async fn create_join_event_v1_route( body.origin(), &body.room_id, ); + return Err!(Request(Forbidden("Server is banned on this homeserver."))); } @@ -286,17 +286,18 @@ pub(crate) async fn create_join_event_v1_route( body.origin(), &body.room_id, ); + return Err!(Request(Forbidden(warn!( "Room ID server name {server} is banned on this homeserver." )))); } } - let room_state = create_join_event(&services, body.origin(), &body.room_id, &body.pdu) - .boxed() - .await?; - - Ok(create_join_event::v1::Response { room_state }) + Ok(create_join_event::v1::Response { + room_state: create_join_event(&services, body.origin(), &body.room_id, &body.pdu) + .boxed() + .await?, + }) } /// # `PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}` @@ -326,6 +327,7 @@ pub(crate) async fn create_join_event_v2_route( body.origin(), &body.room_id, ); + return Err!(Request(Forbidden(warn!( "Room ID server name {server} is banned on this homeserver." )))); @@ -337,13 +339,13 @@ pub(crate) async fn create_join_event_v2_route( .boxed() .await?; - let room_state = create_join_event::v2::RoomState { - members_omitted: false, - auth_chain, - state, - event, - servers_in_room: None, - }; - - Ok(create_join_event::v2::Response { room_state }) + Ok(create_join_event::v2::Response { + room_state: create_join_event::v2::RoomState { + members_omitted: false, + auth_chain, + state, + event, + servers_in_room: None, + }, + }) }