diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 30f725ed..f8a5bf35 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -1,10 +1,16 @@ -use std::{collections::BTreeMap, net::IpAddr, time::Instant}; +use std::{ + collections::BTreeMap, + net::IpAddr, + sync::atomic::{AtomicBool, Ordering}, + time::{Duration, Instant}, +}; use axum::extract::State; use axum_client_ip::InsecureClientIp; use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ - CanonicalJsonObject, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, ServerName, UserId, + CanonicalJsonObject, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, ServerName, + TransactionId, UserId, api::{ client::error::ErrorKind, federation::transactions::{ @@ -23,7 +29,7 @@ use ruma::{ use tuwunel_core::{ Err, Error, Result, debug, debug::INFO_SPAN_LEVEL, - debug_warn, err, error, + debug_warn, defer, err, error, itertools::Itertools, result::LogErr, trace, @@ -110,7 +116,16 @@ pub(crate) async fn send_transaction_message_route( "Parsed txn", ); - let results = handle(&services, &client, body.origin(), txn_start_time, pdus, edus).await?; + let results = handle( + &services, + &client, + body.origin(), + &body.transaction_id, + txn_start_time, + pdus, + edus, + ) + .await?; debug!( pdus = body.pdus.len(), @@ -139,22 +154,23 @@ async fn handle( services: &Services, client: &IpAddr, origin: &ServerName, + txn_id: &TransactionId, started: Instant, pdus: impl Stream + Send, edus: impl Stream + Send, ) -> Result { - let results = handle_pdus(services, client, origin, started, pdus).await?; + let results = handle_pdus(services, client, origin, txn_id, started, pdus).await?; - handle_edus(services, client, origin, edus).await?; + handle_edus(services, client, origin, txn_id, edus).await?; Ok(results) } -#[tracing::instrument(name = "pdus", level = "debug", skip_all)] async fn handle_pdus( services: &Services, client: &IpAddr, origin: &ServerName, + txn_id: &TransactionId, started: Instant, pdus: impl Stream + Send, ) -> Result { @@ -175,7 +191,7 @@ async fn handle_pdus( .into_iter() .try_stream() .broad_and_then(async |(room_id, pdus): (_, Vec<_>)| { - handle_pdus_room(services, client, origin, &started, room_id, pdus.into_iter()) + handle_room(services, client, origin, txn_id, started, room_id, pdus.into_iter()) .map_ok(ResolvedMap::into_iter) .map_ok(IterStream::try_stream) .await @@ -193,11 +209,12 @@ async fn handle_pdus( skip_all, fields(%room_id) )] -async fn handle_pdus_room( +async fn handle_room( services: &Services, _client: &IpAddr, origin: &ServerName, - txn_start_time: &Instant, + txn_id: &TransactionId, + txn_start_time: Instant, ref room_id: OwnedRoomId, pdus: impl Iterator + Send, ) -> Result { @@ -209,39 +226,83 @@ async fn handle_pdus_room( pdus.enumerate() .try_stream() - .and_then(async |(ri, (ti, (room_id, event_id, value)))| { - services.server.check_running()?; - let pdu_start_time = Instant::now(); - let result = services - .event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, true) - .map_ok(|_| ()) - .await; - - debug!( - %event_id, ri, ti, - pdu_elapsed = ?pdu_start_time.elapsed(), - txn_elapsed = ?txn_start_time.elapsed(), - "Finished PDU", - ); - - Ok((event_id, result)) + .and_then(|(ri, (ti, (room_id, event_id, value)))| { + handle_pdu( + services, + (origin, txn_id, txn_start_time, ti), + (room_id, event_id, ri), + value, + ) }) .try_collect() .boxed() .await } +#[tracing::instrument( + name = "pdu", + level = INFO_SPAN_LEVEL, + skip_all, + fields(%event_id, %ti, %ri) +)] +async fn handle_pdu( + services: &Services, + (origin, txn_id, txn_start_time, ti): (&ServerName, &TransactionId, Instant, usize), + (ref room_id, event_id, ri): (OwnedRoomId, OwnedEventId, usize), + value: CanonicalJsonObject, +) -> Result<(OwnedEventId, Result)> { + services.server.check_running()?; + + let pdu_start_time = Instant::now(); + let completed: AtomicBool = Default::default(); + defer! {{ + if completed.load(Ordering::Acquire) { + return; + } + + if pdu_start_time.elapsed() >= Duration::from_secs(services.config.client_request_timeout) { + error!( + %origin, %txn_id, %room_id, %event_id, %ri, %ti, + elapsed = ?pdu_start_time.elapsed(), + "Incoming transaction processing timed out.", + ); + } else { + debug_warn!( + %origin, %txn_id, %room_id, %event_id, %ri, %ti, + elapsed = ?pdu_start_time.elapsed(), + "Incoming transaction processing interrupted.", + ); + } + }} + + let result = services + .event_handler + .handle_incoming_pdu(origin, room_id, &event_id, value, true) + .map_ok(|_| ()) + .await; + + completed.store(true, Ordering::Release); + debug!( + %event_id, ri, ti, + pdu_elapsed = ?pdu_start_time.elapsed(), + txn_elapsed = ?txn_start_time.elapsed(), + "Finished PDU", + ); + + Ok((event_id.clone(), result)) +} + #[tracing::instrument(name = "edus", level = "debug", skip_all)] async fn handle_edus( services: &Services, client: &IpAddr, origin: &ServerName, + txn_id: &TransactionId, edus: impl Stream + Send, ) -> Result { edus.enumerate() .for_each_concurrent(automatic_width(), |(i, edu)| { - handle_edu(services, client, origin, i, edu) + handle_edu(services, client, origin, txn_id, i, edu) }) .await; @@ -258,6 +319,7 @@ async fn handle_edu( services: &Services, client: &IpAddr, origin: &ServerName, + _txn_id: &TransactionId, i: usize, edu: Edu, ) {