diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 6f38f19e..d81b6ea4 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -53,6 +53,12 @@ pub(crate) async fn get_backfill_route( .ready_fold(PduCount::min(), cmp::max) .await; + let room_version = services + .state + .get_room_version(&body.room_id) + .await + .ok(); + Ok(get_backfill::v1::Response { origin_server_ts: MilliSecondsSinceUnixEpoch::now(), @@ -79,7 +85,7 @@ pub(crate) async fn get_backfill_route( .and_then(|pdu| { services .federation - .format_pdu_into(pdu, None) + .format_pdu_into(pdu, room_version.as_ref()) .map(Ok) }) .try_collect() diff --git a/src/api/server/event.rs b/src/api/server/event.rs index ab2b9931..15ea2737 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,5 +1,6 @@ use axum::extract::State; -use ruma::{MilliSecondsSinceUnixEpoch, RoomId, api::federation::event::get_event}; +use futures::{FutureExt, future::try_join}; +use ruma::{MilliSecondsSinceUnixEpoch, OwnedRoomId, api::federation::event::get_event}; use tuwunel_core::{Result, err}; use super::AccessCheck; @@ -21,28 +22,30 @@ pub(crate) async fn get_event_route( .await .map_err(|_| err!(Request(NotFound("Event not found."))))?; - let room_id: &RoomId = event + let room_id: OwnedRoomId = event .get("room_id") .and_then(|val| val.as_str()) .ok_or_else(|| err!(Database("Invalid event in database.")))? .try_into() .map_err(|_| err!(Database("Invalid room_id in event in database.")))?; - AccessCheck { + let access_check = AccessCheck { services: &services, origin: body.origin(), - room_id, + room_id: &room_id, event_id: Some(&body.event_id), - } - .check() - .await?; + }; + + let pdu = services + .federation + .format_pdu_into(event, None) + .map(Ok); + + let ((), pdu) = try_join(access_check.check(), pdu).await?; Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: services - .federation - .format_pdu_into(event, None) - .await, + pdu, }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index b4be7f60..fd6037a6 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -6,7 +6,10 @@ use ruma::{ RoomId, api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, }; -use tuwunel_core::{Error, Result, utils::stream::ReadyExt}; +use tuwunel_core::{ + Error, Result, + utils::stream::{BroadbandExt, ReadyExt}, +}; use super::AccessCheck; use crate::Ruma; @@ -43,12 +46,26 @@ pub(crate) async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; + let room_version = services + .state + .get_room_version(room_id) + .await + .ok(); + let auth_chain = services .auth_chain .event_ids_iter(room_id, once(body.event_id.borrow())) .ready_filter_map(Result::ok) - .filter_map(async |id| services.timeline.get_pdu_json(&id).await.ok()) - .then(|pdu| services.federation.format_pdu_into(pdu, None)) + .broad_filter_map(async |id| { + let pdu = services.timeline.get_pdu_json(&id).await.ok()?; + + let pdu = services + .federation + .format_pdu_into(pdu, room_version.as_ref()) + .await; + + Some(pdu) + }) .collect() .await; diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index 6640ea4f..c406bbb2 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -32,6 +32,12 @@ pub(crate) async fn get_missing_events_route( .unwrap_or(LIMIT_DEFAULT) .min(LIMIT_MAX); + let room_version = services + .state + .get_room_version(&body.room_id) + .await + .ok(); + let mut queued_events = body.latest_events.clone(); // the vec will never have more entries the limit let mut events = Vec::with_capacity(limit); @@ -78,7 +84,7 @@ pub(crate) async fn get_missing_events_route( let event = services .federation - .format_pdu_into(event, None) + .format_pdu_into(event, room_version.as_ref()) .await; queued_events.extend(prev_events); diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 77d8d701..609f93df 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use axum::extract::State; -use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join3}; +use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join4}; use ruma::{ OwnedEventId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, api::federation::membership::create_join_event, @@ -189,6 +189,14 @@ async fn create_join_event( ) .map_err(|e| err!(Request(BadJson("Event has an invalid origin server name: {e}"))))?; + // Prestart state gather here since it doesn't involve the new join event. + let state_ids = services + .state_accessor + .state_full_ids(shortstatehash) + .map(at!(1)) + .collect::>() + .boxed(); + let mutex_lock = services .event_handler .mutex_federation @@ -204,49 +212,51 @@ async fn create_join_event( drop(mutex_lock); - let state_ids: Vec = services - .state_accessor - .state_full_ids(shortstatehash) - .map(at!(1)) - .collect() - .await; - - let starting_events = state_ids.iter().map(Borrow::borrow); - let auth_chain = services - .auth_chain - .event_ids_iter(room_id, starting_events) - .broad_and_then(async |event_id| services.timeline.get_pdu_json(&event_id).await) - .broad_and_then(|pdu| { - services - .federation - .format_pdu_into(pdu, Some(&room_version_id)) - .map(Ok) - }) - .try_collect(); - - let state = state_ids - .iter() - .try_stream() - .broad_and_then(|event_id| services.timeline.get_pdu_json(event_id)) - .broad_and_then(|pdu| { - services - .federation - .format_pdu_into(pdu, Some(&room_version_id)) - .map(Ok) - }) - .try_collect(); - + // Join event for new server. let event = services .federation .format_pdu_into(value, Some(&room_version_id)) .map(Some) .map(Ok); - let (auth_chain, state, event) = try_join3(auth_chain, state, event).await?; + // Join event revealed to existing servers. + let broadcast = services.sending.send_pdu_room(room_id, &pdu_id); - services - .sending - .send_pdu_room(room_id, &pdu_id) + // Wait for state gather which the remaining operations depend on. + let state_ids = state_ids.await; + let auth_heads = state_ids.iter().map(Borrow::borrow); + let into_federation_format = |pdu| { + services + .federation + .format_pdu_into(pdu, Some(&room_version_id)) + .map(Ok) + }; + + let auth_chain = services + .auth_chain + .event_ids_iter(room_id, auth_heads) + .broad_and_then(async |event_id| { + services + .timeline + .get_pdu_json(&event_id) + .and_then(into_federation_format) + .await + }) + .try_collect(); + + let state = state_ids + .iter() + .try_stream() + .broad_and_then(|event_id| { + services + .timeline + .get_pdu_json(event_id) + .and_then(into_federation_format) + }) + .try_collect(); + + let (auth_chain, state, event, ()) = try_join4(auth_chain, state, event, broadcast) + .boxed() .await?; Ok(create_join_event::v1::RoomState { auth_chain, state, event }) diff --git a/src/api/server/send_knock.rs b/src/api/server/send_knock.rs index 1375f9b5..1938e8e3 100644 --- a/src/api/server/send_knock.rs +++ b/src/api/server/send_knock.rs @@ -1,5 +1,5 @@ use axum::extract::State; -use futures::FutureExt; +use futures::{FutureExt, future::try_join}; use ruma::{ OwnedServerName, OwnedUserId, RoomVersionId::*, @@ -174,16 +174,16 @@ pub(crate) async fn create_knock_event_v1_route( drop(mutex_lock); - services + let broadcast = services .sending - .send_pdu_room(&body.room_id, &pdu_id) - .await?; + .send_pdu_room(&body.room_id, &pdu_id); + + let knock_room_state = services.state.summary_stripped(&pdu).map(Ok); + + let (knock_room_state, ()) = try_join(knock_room_state, broadcast).await?; Ok(create_knock_event::v1::Response { - knock_room_state: services - .state - .summary_stripped(&pdu) - .await + knock_room_state: knock_room_state .into_iter() .map(Into::into) .collect(), diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 7e417365..bd4061cb 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,9 +1,18 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; -use futures::{FutureExt, StreamExt, TryStreamExt}; +use futures::{ + FutureExt, StreamExt, TryFutureExt, TryStreamExt, + future::{join, try_join}, +}; use ruma::{OwnedEventId, api::federation::event::get_room_state}; -use tuwunel_core::{Result, at, err, utils::IterStream}; +use tuwunel_core::{ + Result, at, err, + utils::{ + future::TryExtExt, + stream::{IterStream, TryBroadbandExt}, + }, +}; use super::AccessCheck; use crate::Ruma; @@ -30,38 +39,50 @@ pub(crate) async fn get_room_state_route( .await .map_err(|_| err!(Request(NotFound("PDU state not found."))))?; - let state_ids: Vec = services + let state_ids = services .state_accessor .state_full_ids(shortstatehash) .map(at!(1)) - .collect() - .await; + .collect::>(); - let pdus = state_ids - .iter() - .try_stream() - .and_then(|id| services.timeline.get_pdu_json(id)) - .and_then(|pdu| { - services - .federation - .format_pdu_into(pdu, None) - .map(Ok) - }) - .try_collect() - .await?; + let room_version = services + .state + .get_room_version(&body.room_id) + .ok(); + + let (room_version, state_ids) = join(room_version, state_ids).await; + + let into_federation_format = |pdu| { + services + .federation + .format_pdu_into(pdu, room_version.as_ref()) + .map(Ok) + }; let auth_chain = services .auth_chain .event_ids_iter(&body.room_id, once(body.event_id.borrow())) - .and_then(async |id| services.timeline.get_pdu_json(&id).await) - .and_then(|pdu| { + .broad_and_then(async |id| { services - .federation - .format_pdu_into(pdu, None) - .map(Ok) + .timeline + .get_pdu_json(&id) + .and_then(into_federation_format) + .await }) - .try_collect() - .await?; + .try_collect(); + + let pdus = state_ids + .iter() + .try_stream() + .broad_and_then(|id| { + services + .timeline + .get_pdu_json(id) + .and_then(into_federation_format) + }) + .try_collect(); + + let (auth_chain, pdus) = try_join(auth_chain, pdus).await?; Ok(get_room_state::v1::Response { auth_chain, pdus }) } diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 7fda278c..4c579435 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; -use futures::{StreamExt, TryStreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join}; use ruma::{OwnedEventId, api::federation::event::get_room_state_ids}; use tuwunel_core::{Result, at, err}; @@ -31,18 +31,19 @@ pub(crate) async fn get_room_state_ids_route( .await .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; - let pdu_ids: Vec = services - .state_accessor - .state_full_ids(shortstatehash) - .map(at!(1)) - .collect() - .await; - let auth_chain_ids = services .auth_chain .event_ids_iter(&body.room_id, once(body.event_id.borrow())) - .try_collect() - .await?; + .try_collect(); + + let pdu_ids = services + .state_accessor + .state_full_ids(shortstatehash) + .map(at!(1)) + .collect::>() + .map(Ok); + + let (auth_chain_ids, pdu_ids) = try_join(auth_chain_ids, pdu_ids).await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids, pdu_ids }) }