diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index 5ff50a3c..b9368fa3 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -7,7 +7,7 @@ use std::{ use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join}; use ruma::{OwnedEventId, RoomId, RoomVersionId}; use tuwunel_core::{ - Result, debug, err, implement, + Result, err, implement, matrix::{Event, StateMap}, trace, utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt}, @@ -26,7 +26,12 @@ pub(super) async fn state_at_incoming_degree_one( where Pdu: Event, { - let prev_event = incoming_pdu + debug_assert!( + incoming_pdu.prev_events().count() == 1, + "Incoming PDU must have one prev_event to make this call" + ); + + let prev_event_id = incoming_pdu .prev_events() .next() .expect("at least one prev_event"); @@ -34,35 +39,35 @@ where let Ok(prev_event_sstatehash) = self .services .state_accessor - .pdu_shortstatehash(prev_event) + .pdu_shortstatehash(prev_event_id) .await else { return Ok(None); }; - let mut state: HashMap<_, _> = self + let prev_event = self + .services + .timeline + .get_pdu(prev_event_id) + .map_err(|e| err!(Database("Could not find prev_event, but we know the state: {e:?}"))); + + let state = self .services .state_accessor .state_full_ids(prev_event_sstatehash) - .collect() - .await; + .collect::>() + .map(Ok); - debug!("Using cached state"); - let prev_pdu = self - .services - .timeline - .get_pdu(prev_event) - .await - .map_err(|e| err!(Database("Could not find prev event, but we know the state: {e:?}")))?; + let (prev_event, mut state) = try_join(prev_event, state).await?; - if let Some(state_key) = &prev_pdu.state_key { + if let Some(state_key) = prev_event.state_key { let shortstatekey = self .services .short - .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) + .get_or_create_shortstatekey(&prev_event.kind.into(), &state_key) .await; - state.insert(shortstatekey, prev_event.to_owned()); + state.insert(shortstatekey, prev_event.event_id); // Now it's the state after the pdu } @@ -82,21 +87,24 @@ pub(super) async fn state_at_incoming_resolved( where Pdu: Event, { + debug_assert!( + incoming_pdu.prev_events().count() > 1, + "Incoming PDU should have more than one prev_event for this codepath" + ); + trace!("Calculating extremity statehashes..."); let Ok(extremity_sstatehashes) = incoming_pdu .prev_events() .try_stream() - .broad_and_then(|prev_eventid| { - self.services - .timeline - .get_pdu(prev_eventid) - .map_ok(move |prev_event| (prev_eventid, prev_event)) - }) - .broad_and_then(|(prev_eventid, prev_event)| { - self.services + .broad_and_then(|prev_event_id| { + let prev_event = self.services.timeline.get_pdu(prev_event_id); + + let sstatehash = self + .services .state_accessor - .pdu_shortstatehash(prev_eventid) - .map_ok(move |sstatehash| (sstatehash, prev_event)) + .pdu_shortstatehash(prev_event_id); + + try_join(sstatehash, prev_event) }) .try_collect::>() .await