From 48aa6035f66ebb5e9c43db644d3c0143146f4061 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 21 Jan 2026 16:53:56 +0000 Subject: [PATCH] Hoist room_version query to callers of get_auth_chain. Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 8 +++++- src/api/server/event_auth.rs | 10 +++----- src/api/server/send_join.rs | 25 ++++++++----------- src/api/server/state.rs | 24 ++++++------------ src/api/server/state_ids.rs | 11 +++++--- src/service/rooms/auth_chain/mod.rs | 15 +++++------ .../rooms/event_handler/resolve_state.rs | 6 ++--- .../rooms/event_handler/state_at_incoming.rs | 9 ++++--- 8 files changed, 49 insertions(+), 59 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index c1a11011..4976761b 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -57,11 +57,17 @@ pub(super) async fn get_auth_chain(&self, event_id: OwnedEventId) -> Result { let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| err!(Database("Invalid room id field in event in database")))?; + let room_version = self + .services + .state + .get_room_version(room_id) + .await?; + let start = Instant::now(); let count = self .services .auth_chain - .event_ids_iter(room_id, once(event_id.as_ref())) + .event_ids_iter(room_id, &room_version, once(event_id.as_ref())) .ready_filter_map(Result::ok) .count() .await; diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index fd6037a6..fa7f5c7e 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -46,22 +46,18 @@ pub(crate) async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - let room_version = services - .state - .get_room_version(room_id) - .await - .ok(); + let room_version = services.state.get_room_version(room_id).await?; let auth_chain = services .auth_chain - .event_ids_iter(room_id, once(body.event_id.borrow())) + .event_ids_iter(room_id, &room_version, once(body.event_id.borrow())) .ready_filter_map(Result::ok) .broad_filter_map(async |id| { let pdu = services.timeline.get_pdu_json(&id).await.ok()?; let pdu = services .federation - .format_pdu_into(pdu, room_version.as_ref()) + .format_pdu_into(pdu, Some(&room_version)) .await; Some(pdu) diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index d3127974..517534c1 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -50,9 +50,9 @@ async fn create_join_event( // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services.state.get_room_version(room_id).await?; + let room_version = services.state.get_room_version(room_id).await?; - let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { + let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version) else { // Event could not be converted to canonical json return Err!(Request(BadJson("Could not convert event to canonical json."))); }; @@ -136,9 +136,9 @@ async fn create_join_event( if let Some(authorising_user) = content.join_authorized_via_users_server { use ruma::RoomVersionId::*; - if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6 | V7) { + if matches!(room_version, V1 | V2 | V3 | V4 | V5 | V6 | V7) { return Err!(Request(InvalidParam( - "Room version {room_version_id} does not support restricted rooms but \ + "Room version {room_version} does not support restricted rooms but \ join_authorised_via_users_server ({authorising_user}) was found in the event." ))); } @@ -161,13 +161,8 @@ async fn create_join_event( ))); } - if !super::user_can_perform_restricted_join( - services, - &state_key, - room_id, - &room_version_id, - ) - .await? + if !super::user_can_perform_restricted_join(services, &state_key, room_id, &room_version) + .await? { return Err!(Request(UnableToAuthorizeJoin( "Joining user did not pass restricted room's rules." @@ -177,7 +172,7 @@ async fn create_join_event( services .server_keys - .hash_and_sign_event(&mut value, &room_version_id) + .hash_and_sign_event(&mut value, &room_version) .map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?; let origin: OwnedServerName = serde_json::from_value( @@ -216,7 +211,7 @@ async fn create_join_event( // Join event for new server. let event = services .federation - .format_pdu_into(value, Some(&room_version_id)) + .format_pdu_into(value, Some(&room_version)) .map(Some) .map(Ok); @@ -229,13 +224,13 @@ async fn create_join_event( let into_federation_format = |pdu| { services .federation - .format_pdu_into(pdu, Some(&room_version_id)) + .format_pdu_into(pdu, Some(&room_version)) .map(Ok) }; let auth_chain = services .auth_chain - .event_ids_iter(room_id, auth_heads) + .event_ids_iter(room_id, &room_version, auth_heads) .broad_and_then(async |event_id| { services .timeline diff --git a/src/api/server/state.rs b/src/api/server/state.rs index d7ba63bc..a88199ac 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,17 +1,11 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; -use futures::{ - FutureExt, StreamExt, TryFutureExt, TryStreamExt, - future::{join, try_join}, -}; +use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join}; use ruma::{OwnedEventId, api::federation::event::get_room_state}; use tuwunel_core::{ Result, at, err, - utils::{ - future::TryExtExt, - stream::{IterStream, TryBroadbandExt}, - }, + utils::stream::{IterStream, TryBroadbandExt}, }; use super::AccessCheck; @@ -43,25 +37,23 @@ pub(crate) async fn get_room_state_route( .state_accessor .state_full_ids(shortstatehash) .map(at!(1)) - .collect::>(); + .collect::>() + .map(Ok); - let room_version = services - .state - .get_room_version(&body.room_id) - .ok(); + let room_version = services.state.get_room_version(&body.room_id); - let (room_version, state_ids) = join(room_version, state_ids).await; + let (room_version, state_ids) = try_join(room_version, state_ids).await?; let into_federation_format = |pdu| { services .federation - .format_pdu_into(pdu, room_version.as_ref()) + .format_pdu_into(pdu, Some(&room_version)) .map(Ok) }; let auth_chain = services .auth_chain - .event_ids_iter(&body.room_id, once(body.event_id.borrow())) + .event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow())) .broad_and_then(async |id| { services .timeline diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 093ad02a..4003c27f 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; -use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join}; +use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join}; use ruma::{OwnedEventId, api::federation::event::get_room_state_ids}; use tuwunel_core::{Result, at, err}; @@ -28,12 +28,15 @@ pub(crate) async fn get_room_state_ids_route( let shortstatehash = services .state .pdu_shortstatehash(&body.event_id) - .await - .map_err(|_| err!(Request(NotFound("Pdu state not found."))))?; + .map_err(|_| err!(Request(NotFound("Pdu state not found.")))); + + let room_version = services.state.get_room_version(&body.room_id); + + let (shortstatehash, room_version) = try_join(shortstatehash, room_version).await?; let auth_chain_ids = services .auth_chain - .event_ids_iter(&body.room_id, once(body.event_id.borrow())) + .event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow())) .try_collect(); let pdu_ids = services diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 7a3db79d..f8982d23 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -11,10 +11,10 @@ use futures::{ FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, stream::{FuturesUnordered, unfold}, }; -use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules}; +use ruma::{EventId, OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules}; use tuwunel_core::{ Err, Result, at, debug, debug_error, err, implement, - matrix::{Event, PduEvent}, + matrix::{Event, PduEvent, room_version}, pdu::AuthEvents, trace, utils, utils::{ @@ -60,12 +60,13 @@ impl crate::Service for Service { pub fn event_ids_iter<'a, I>( &'a self, room_id: &'a RoomId, + room_version: &'a RoomVersionId, starting_events: I, ) -> impl Stream> + Send + 'a where I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, { - self.get_auth_chain(room_id, starting_events) + self.get_auth_chain(room_id, room_version, starting_events) .map_ok(|chain| { self.services .short @@ -85,6 +86,7 @@ where pub async fn get_auth_chain<'a, I>( &'a self, room_id: &RoomId, + room_version: &RoomVersionId, starting_events: I, ) -> Result> where @@ -94,12 +96,7 @@ where const BUCKET: Bucket<'_> = BTreeSet::new(); let started = Instant::now(); - let room_rules = self - .services - .state - .get_room_version_rules(room_id) - .await?; - + let room_rules = room_version::rules(room_version)?; let starting_ids = self .services .short diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index 19e28c6d..7cefcf80 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -17,7 +17,7 @@ use crate::rooms::state_compressor::CompressedState; pub async fn resolve_state( &self, room_id: &RoomId, - room_version_id: &RoomVersionId, + room_version: &RoomVersionId, incoming_state: HashMap, ) -> Result> { trace!("Loading current room state ids"); @@ -43,7 +43,7 @@ pub async fn resolve_state( .wide_and_then(|state| { self.services .auth_chain - .event_ids_iter(room_id, state.values().map(Borrow::borrow)) + .event_ids_iter(room_id, room_version, state.values().map(Borrow::borrow)) .try_collect::>() }) .ready_filter_map(Result::ok); @@ -64,7 +64,7 @@ pub async fn resolve_state( trace!("Resolving state"); let state = self - .state_resolution(room_version_id, fork_states, auth_chain_sets) + .state_resolution(room_version, fork_states, auth_chain_sets) .await?; trace!("State resolution done."); diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index a94f0c18..96fef6ee 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -104,7 +104,7 @@ pub(super) async fn state_at_incoming_resolved( &self, incoming_pdu: &Pdu, room_id: &RoomId, - room_version_id: &RoomVersionId, + room_version: &RoomVersionId, ) -> Result>> where Pdu: Event, @@ -141,7 +141,7 @@ where .into_iter() .try_stream() .wide_and_then(|(sstatehash, prev_event)| { - self.state_at_incoming_fork(room_id, sstatehash, prev_event) + self.state_at_incoming_fork(room_id, room_version, sstatehash, prev_event) }) .try_collect() .map_ok(Vec::into_iter) @@ -152,7 +152,7 @@ where trace!("Resolving state"); let Ok(new_state) = self - .state_resolution(room_version_id, fork_states, auth_chain_sets) + .state_resolution(room_version, fork_states, auth_chain_sets) .inspect_ok(|_| trace!("State resolution done.")) .await else { @@ -189,6 +189,7 @@ where async fn state_at_incoming_fork( &self, room_id: &RoomId, + room_version: &RoomVersionId, sstatehash: ShortStateHash, prev_event: Pdu, ) -> Result<(StateMap, AuthSet)> @@ -228,7 +229,7 @@ where let auth_chain = self .services .auth_chain - .event_ids_iter(room_id, starting_events) + .event_ids_iter(room_id, room_version, starting_events) .try_collect(); let fork_state = leaf_state_after_event