diff --git a/src/api/server/send.rs b/src/api/server/send.rs index f8a5bf35..4db08758 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -32,12 +32,13 @@ use tuwunel_core::{ debug_warn, defer, err, error, itertools::Itertools, result::LogErr, + smallvec::SmallVec, trace, utils::{ - IterStream, ReadyExt, debug::str_truncated, + future::TryExtExt, millis_since_unix_epoch, - stream::{BroadbandExt, TryBroadbandExt, automatic_width}, + stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, automatic_width}, }, warn, }; @@ -49,6 +50,9 @@ use tuwunel_service::{ use crate::Ruma; type ResolvedMap = BTreeMap; +type RoomsPdus = SmallVec<[RoomPdus; 1]>; +type RoomPdus = (OwnedRoomId, TxnPdus); +type TxnPdus = SmallVec<[(usize, Pdu); 1]>; type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject); /// # `PUT /_matrix/federation/v1/send/{txnId}` @@ -99,22 +103,27 @@ pub(crate) async fn send_transaction_message_route( .pdus .iter() .stream() - .broad_then(|pdu| services.event_handler.parse_incoming_pdu(pdu)) - .inspect_err(|e| debug_warn!("Could not parse PDU: {e}")) - .ready_filter_map(Result::ok); + .enumerate() + .broad_filter_map(|(i, pdu)| { + services + .event_handler + .parse_incoming_pdu(pdu) + .inspect_err(move |e| debug_warn!("Could not parse PDU[{i}]: {e}")) + .map_ok(move |pdu| (i, pdu)) + .ok() + }); let edus = body .edus .iter() - .map(|edu| edu.json().get()) - .map(serde_json::from_str) - .filter_map(Result::ok) - .stream(); - - trace!( - elapsed = ?txn_start_time.elapsed(), - "Parsed txn", - ); + .stream() + .enumerate() + .ready_filter_map(|(i, edu)| { + serde_json::from_str(edu.json().get()) + .inspect_err(|e| debug_warn!("Could not parse EDU[{i}]: {e}")) + .map(|edu| (i, edu)) + .ok() + }); let results = handle( &services, @@ -156,8 +165,8 @@ async fn handle( origin: &ServerName, txn_id: &TransactionId, started: Instant, - pdus: impl Stream + Send, - edus: impl Stream + Send, + pdus: impl Stream + Send, + edus: impl Stream + Send, ) -> Result { let results = handle_pdus(services, client, origin, txn_id, started, pdus).await?; @@ -172,25 +181,24 @@ async fn handle_pdus( origin: &ServerName, txn_id: &TransactionId, started: Instant, - pdus: impl Stream + Send, + pdus: impl Stream + Send, ) -> Result { - // group pdus by room - let pdus = pdus - .enumerate() - .collect() - .map(|mut pdus: Vec<_>| { - pdus.sort_by(|(_, (room_a, ..)), (_, (room_b, ..))| room_a.cmp(room_b)); + pdus.collect() + .map(Ok) + .map_ok(|pdus: TxnPdus| { pdus.into_iter() + .sorted_by(|(_, (room_a, ..)), (_, (room_b, ..))| room_a.cmp(room_b)) .into_grouping_map_by(|(_, (room_id, ..))| room_id.clone()) .collect() + .into_iter() + .try_stream() }) - .await; - - // we can evaluate rooms concurrently - let results: ResolvedMap = pdus - .into_iter() - .try_stream() - .broad_and_then(async |(room_id, pdus): (_, Vec<_>)| { + .try_flatten_stream() + .try_collect::() + .map_ok(IntoIterator::into_iter) + .map_ok(IterStream::try_stream) + .try_flatten_stream() + .broad_and_then(async |(room_id, pdus)| { handle_room(services, client, origin, txn_id, started, room_id, pdus.into_iter()) .map_ok(ResolvedMap::into_iter) .map_ok(IterStream::try_stream) @@ -198,9 +206,7 @@ async fn handle_pdus( }) .try_flatten() .try_collect() - .await?; - - Ok(results) + .await } #[tracing::instrument( @@ -218,24 +224,24 @@ async fn handle_room( ref room_id: OwnedRoomId, pdus: impl Iterator + Send, ) -> Result { - let _room_lock = services + services .event_handler .mutex_federation .lock(room_id) - .await; - - pdus.enumerate() - .try_stream() - .and_then(|(ri, (ti, (room_id, event_id, value)))| { - handle_pdu( - services, - (origin, txn_id, txn_start_time, ti), - (room_id, event_id, ri), - value, - ) + .then(async |_lock| { + pdus.enumerate() + .try_stream() + .and_then(async |pdu| { + services.server.check_running().map(|()| pdu) // interruption point + }) + .and_then(|(ri, (ti, (room_id, event_id, value)))| { + let meta = (origin, txn_id, txn_start_time, ti); + let pdu = (ri, (room_id, event_id, value)); + handle_pdu(services, meta, pdu).map(Ok) + }) + .try_collect() + .await }) - .try_collect() - .boxed() .await } @@ -248,11 +254,8 @@ async fn handle_room( async fn handle_pdu( services: &Services, (origin, txn_id, txn_start_time, ti): (&ServerName, &TransactionId, Instant, usize), - (ref room_id, event_id, ri): (OwnedRoomId, OwnedEventId, usize), - value: CanonicalJsonObject, -) -> Result<(OwnedEventId, Result)> { - services.server.check_running()?; - + (ri, (ref room_id, event_id, value)): (usize, Pdu), +) -> (OwnedEventId, Result) { let pdu_start_time = Instant::now(); let completed: AtomicBool = Default::default(); defer! {{ @@ -279,6 +282,7 @@ async fn handle_pdu( .event_handler .handle_incoming_pdu(origin, room_id, &event_id, value, true) .map_ok(|_| ()) + .boxed() .await; completed.store(true, Ordering::Release); @@ -289,7 +293,7 @@ async fn handle_pdu( "Finished PDU", ); - Ok((event_id.clone(), result)) + (event_id.clone(), result) } #[tracing::instrument(name = "edus", level = "debug", skip_all)] @@ -298,13 +302,12 @@ async fn handle_edus( client: &IpAddr, origin: &ServerName, txn_id: &TransactionId, - edus: impl Stream + Send, + edus: impl Stream + Send, ) -> Result { - edus.enumerate() - .for_each_concurrent(automatic_width(), |(i, edu)| { - handle_edu(services, client, origin, txn_id, i, edu) - }) - .await; + edus.for_each_concurrent(automatic_width(), |(i, edu)| { + handle_edu(services, client, origin, txn_id, i, edu) + }) + .await; Ok(()) }