diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index ddd61d49..7db6f963 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -234,39 +234,19 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu( // 14. Check if the event passes auth based on the "current state" of the room, // if not soft fail it - if soft_fail { - debug!("Soft failing event"); - let extremities = extremities.iter().map(Borrow::borrow); - - self.services - .timeline - .append_incoming_pdu( - &incoming_pdu, - val, - extremities, - state_ids_compressed, - soft_fail, - &state_lock, - ) - .await?; - - // Soft fail, we keep the event as an outlier but don't add it to the timeline - self.services - .pdu_metadata - .mark_event_soft_failed(incoming_pdu.event_id()); - - warn!("Event was soft failed: {:?}", incoming_pdu.event_id()); - return Err!(Request(InvalidParam("Event has been soft failed"))); - } - + // // Now that the event has passed all auth it is added into the timeline. // We use the `state_at_event` instead of `state_after` so we accurately // represent the state for this event. trace!("Appending pdu to timeline"); + + // Incoming event will be referenced in prev_events unless soft-failed. + let incoming_extremity = once(incoming_pdu.event_id()).filter(|_| !soft_fail); + let extremities = extremities .iter() .map(Borrow::borrow) - .chain(once(incoming_pdu.event_id())); + .chain(incoming_extremity); let pdu_id = self .services @@ -281,7 +261,21 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu( ) .await?; - // Event has passed all auth/stateres checks + if soft_fail { + self.services + .pdu_metadata + .mark_event_soft_failed(incoming_pdu.event_id()); + + drop(state_lock); + warn!( + elapsed = ?timer.elapsed(), + "Event was soft failed: {:?}", + incoming_pdu.event_id() + ); + + return Err!(Request(InvalidParam("Event has been soft failed"))); + } + drop(state_lock); debug_info!( elapsed = ?timer.elapsed(),