Adjust federation send handler sans applying topological sort.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-03-05 11:28:23 +00:00
parent 0ecdb86aca
commit 513c1184fe

View File

@@ -32,12 +32,13 @@ use tuwunel_core::{
debug_warn, defer, err, error,
itertools::Itertools,
result::LogErr,
smallvec::SmallVec,
trace,
utils::{
IterStream, ReadyExt,
debug::str_truncated,
future::TryExtExt,
millis_since_unix_epoch,
stream::{BroadbandExt, TryBroadbandExt, automatic_width},
stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, automatic_width},
},
warn,
};
@@ -49,6 +50,9 @@ use tuwunel_service::{
use crate::Ruma;
type ResolvedMap = BTreeMap<OwnedEventId, Result>;
type RoomsPdus = SmallVec<[RoomPdus; 1]>;
type RoomPdus = (OwnedRoomId, TxnPdus);
type TxnPdus = SmallVec<[(usize, Pdu); 1]>;
type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
/// # `PUT /_matrix/federation/v1/send/{txnId}`
@@ -99,22 +103,27 @@ pub(crate) async fn send_transaction_message_route(
.pdus
.iter()
.stream()
.broad_then(|pdu| services.event_handler.parse_incoming_pdu(pdu))
.inspect_err(|e| debug_warn!("Could not parse PDU: {e}"))
.ready_filter_map(Result::ok);
.enumerate()
.broad_filter_map(|(i, pdu)| {
services
.event_handler
.parse_incoming_pdu(pdu)
.inspect_err(move |e| debug_warn!("Could not parse PDU[{i}]: {e}"))
.map_ok(move |pdu| (i, pdu))
.ok()
});
let edus = body
.edus
.iter()
.map(|edu| edu.json().get())
.map(serde_json::from_str)
.filter_map(Result::ok)
.stream();
trace!(
elapsed = ?txn_start_time.elapsed(),
"Parsed txn",
);
.stream()
.enumerate()
.ready_filter_map(|(i, edu)| {
serde_json::from_str(edu.json().get())
.inspect_err(|e| debug_warn!("Could not parse EDU[{i}]: {e}"))
.map(|edu| (i, edu))
.ok()
});
let results = handle(
&services,
@@ -156,8 +165,8 @@ async fn handle(
origin: &ServerName,
txn_id: &TransactionId,
started: Instant,
pdus: impl Stream<Item = Pdu> + Send,
edus: impl Stream<Item = Edu> + Send,
pdus: impl Stream<Item = (usize, Pdu)> + Send,
edus: impl Stream<Item = (usize, Edu)> + Send,
) -> Result<ResolvedMap> {
let results = handle_pdus(services, client, origin, txn_id, started, pdus).await?;
@@ -172,25 +181,24 @@ async fn handle_pdus(
origin: &ServerName,
txn_id: &TransactionId,
started: Instant,
pdus: impl Stream<Item = Pdu> + Send,
pdus: impl Stream<Item = (usize, Pdu)> + Send,
) -> Result<ResolvedMap> {
// group pdus by room
let pdus = pdus
.enumerate()
.collect()
.map(|mut pdus: Vec<_>| {
pdus.sort_by(|(_, (room_a, ..)), (_, (room_b, ..))| room_a.cmp(room_b));
pdus.collect()
.map(Ok)
.map_ok(|pdus: TxnPdus| {
pdus.into_iter()
.sorted_by(|(_, (room_a, ..)), (_, (room_b, ..))| room_a.cmp(room_b))
.into_grouping_map_by(|(_, (room_id, ..))| room_id.clone())
.collect()
.into_iter()
.try_stream()
})
.await;
// we can evaluate rooms concurrently
let results: ResolvedMap = pdus
.into_iter()
.try_stream()
.broad_and_then(async |(room_id, pdus): (_, Vec<_>)| {
.try_flatten_stream()
.try_collect::<RoomsPdus>()
.map_ok(IntoIterator::into_iter)
.map_ok(IterStream::try_stream)
.try_flatten_stream()
.broad_and_then(async |(room_id, pdus)| {
handle_room(services, client, origin, txn_id, started, room_id, pdus.into_iter())
.map_ok(ResolvedMap::into_iter)
.map_ok(IterStream::try_stream)
@@ -198,9 +206,7 @@ async fn handle_pdus(
})
.try_flatten()
.try_collect()
.await?;
Ok(results)
.await
}
#[tracing::instrument(
@@ -218,24 +224,24 @@ async fn handle_room(
ref room_id: OwnedRoomId,
pdus: impl Iterator<Item = (usize, Pdu)> + Send,
) -> Result<ResolvedMap> {
let _room_lock = services
services
.event_handler
.mutex_federation
.lock(room_id)
.await;
pdus.enumerate()
.try_stream()
.and_then(|(ri, (ti, (room_id, event_id, value)))| {
handle_pdu(
services,
(origin, txn_id, txn_start_time, ti),
(room_id, event_id, ri),
value,
)
.then(async |_lock| {
pdus.enumerate()
.try_stream()
.and_then(async |pdu| {
services.server.check_running().map(|()| pdu) // interruption point
})
.and_then(|(ri, (ti, (room_id, event_id, value)))| {
let meta = (origin, txn_id, txn_start_time, ti);
let pdu = (ri, (room_id, event_id, value));
handle_pdu(services, meta, pdu).map(Ok)
})
.try_collect()
.await
})
.try_collect()
.boxed()
.await
}
@@ -248,11 +254,8 @@ async fn handle_room(
async fn handle_pdu(
services: &Services,
(origin, txn_id, txn_start_time, ti): (&ServerName, &TransactionId, Instant, usize),
(ref room_id, event_id, ri): (OwnedRoomId, OwnedEventId, usize),
value: CanonicalJsonObject,
) -> Result<(OwnedEventId, Result)> {
services.server.check_running()?;
(ri, (ref room_id, event_id, value)): (usize, Pdu),
) -> (OwnedEventId, Result) {
let pdu_start_time = Instant::now();
let completed: AtomicBool = Default::default();
defer! {{
@@ -279,6 +282,7 @@ async fn handle_pdu(
.event_handler
.handle_incoming_pdu(origin, room_id, &event_id, value, true)
.map_ok(|_| ())
.boxed()
.await;
completed.store(true, Ordering::Release);
@@ -289,7 +293,7 @@ async fn handle_pdu(
"Finished PDU",
);
Ok((event_id.clone(), result))
(event_id.clone(), result)
}
#[tracing::instrument(name = "edus", level = "debug", skip_all)]
@@ -298,13 +302,12 @@ async fn handle_edus(
client: &IpAddr,
origin: &ServerName,
txn_id: &TransactionId,
edus: impl Stream<Item = Edu> + Send,
edus: impl Stream<Item = (usize, Edu)> + Send,
) -> Result {
edus.enumerate()
.for_each_concurrent(automatic_width(), |(i, edu)| {
handle_edu(services, client, origin, txn_id, i, edu)
})
.await;
edus.for_each_concurrent(automatic_width(), |(i, edu)| {
handle_edu(services, client, origin, txn_id, i, edu)
})
.await;
Ok(())
}