Hoist room_version query to callers of get_auth_chain.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-21 16:53:56 +00:00
parent afcb2315ee
commit 48aa6035f6
8 changed files with 49 additions and 59 deletions

View File

@@ -57,11 +57,17 @@ pub(super) async fn get_auth_chain(&self, event_id: OwnedEventId) -> Result {
let room_id = <&RoomId>::try_from(room_id_str) let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| err!(Database("Invalid room id field in event in database")))?; .map_err(|_| err!(Database("Invalid room id field in event in database")))?;
let room_version = self
.services
.state
.get_room_version(room_id)
.await?;
let start = Instant::now(); let start = Instant::now();
let count = self let count = self
.services .services
.auth_chain .auth_chain
.event_ids_iter(room_id, once(event_id.as_ref())) .event_ids_iter(room_id, &room_version, once(event_id.as_ref()))
.ready_filter_map(Result::ok) .ready_filter_map(Result::ok)
.count() .count()
.await; .await;

View File

@@ -46,22 +46,18 @@ pub(crate) async fn get_event_authorization_route(
let room_id = <&RoomId>::try_from(room_id_str) let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; .map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
let room_version = services let room_version = services.state.get_room_version(room_id).await?;
.state
.get_room_version(room_id)
.await
.ok();
let auth_chain = services let auth_chain = services
.auth_chain .auth_chain
.event_ids_iter(room_id, once(body.event_id.borrow())) .event_ids_iter(room_id, &room_version, once(body.event_id.borrow()))
.ready_filter_map(Result::ok) .ready_filter_map(Result::ok)
.broad_filter_map(async |id| { .broad_filter_map(async |id| {
let pdu = services.timeline.get_pdu_json(&id).await.ok()?; let pdu = services.timeline.get_pdu_json(&id).await.ok()?;
let pdu = services let pdu = services
.federation .federation
.format_pdu_into(pdu, room_version.as_ref()) .format_pdu_into(pdu, Some(&room_version))
.await; .await;
Some(pdu) Some(pdu)

View File

@@ -50,9 +50,9 @@ async fn create_join_event(
// We do not add the event_id field to the pdu here because of signature and // We do not add the event_id field to the pdu here because of signature and
// hashes checks // hashes checks
let room_version_id = services.state.get_room_version(room_id).await?; let room_version = services.state.get_room_version(room_id).await?;
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version) else {
// Event could not be converted to canonical json // Event could not be converted to canonical json
return Err!(Request(BadJson("Could not convert event to canonical json."))); return Err!(Request(BadJson("Could not convert event to canonical json.")));
}; };
@@ -136,9 +136,9 @@ async fn create_join_event(
if let Some(authorising_user) = content.join_authorized_via_users_server { if let Some(authorising_user) = content.join_authorized_via_users_server {
use ruma::RoomVersionId::*; use ruma::RoomVersionId::*;
if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6 | V7) { if matches!(room_version, V1 | V2 | V3 | V4 | V5 | V6 | V7) {
return Err!(Request(InvalidParam( return Err!(Request(InvalidParam(
"Room version {room_version_id} does not support restricted rooms but \ "Room version {room_version} does not support restricted rooms but \
join_authorised_via_users_server ({authorising_user}) was found in the event." join_authorised_via_users_server ({authorising_user}) was found in the event."
))); )));
} }
@@ -161,13 +161,8 @@ async fn create_join_event(
))); )));
} }
if !super::user_can_perform_restricted_join( if !super::user_can_perform_restricted_join(services, &state_key, room_id, &room_version)
services, .await?
&state_key,
room_id,
&room_version_id,
)
.await?
{ {
return Err!(Request(UnableToAuthorizeJoin( return Err!(Request(UnableToAuthorizeJoin(
"Joining user did not pass restricted room's rules." "Joining user did not pass restricted room's rules."
@@ -177,7 +172,7 @@ async fn create_join_event(
services services
.server_keys .server_keys
.hash_and_sign_event(&mut value, &room_version_id) .hash_and_sign_event(&mut value, &room_version)
.map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?; .map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?;
let origin: OwnedServerName = serde_json::from_value( let origin: OwnedServerName = serde_json::from_value(
@@ -216,7 +211,7 @@ async fn create_join_event(
// Join event for new server. // Join event for new server.
let event = services let event = services
.federation .federation
.format_pdu_into(value, Some(&room_version_id)) .format_pdu_into(value, Some(&room_version))
.map(Some) .map(Some)
.map(Ok); .map(Ok);
@@ -229,13 +224,13 @@ async fn create_join_event(
let into_federation_format = |pdu| { let into_federation_format = |pdu| {
services services
.federation .federation
.format_pdu_into(pdu, Some(&room_version_id)) .format_pdu_into(pdu, Some(&room_version))
.map(Ok) .map(Ok)
}; };
let auth_chain = services let auth_chain = services
.auth_chain .auth_chain
.event_ids_iter(room_id, auth_heads) .event_ids_iter(room_id, &room_version, auth_heads)
.broad_and_then(async |event_id| { .broad_and_then(async |event_id| {
services services
.timeline .timeline

View File

@@ -1,17 +1,11 @@
use std::{borrow::Borrow, iter::once}; use std::{borrow::Borrow, iter::once};
use axum::extract::State; use axum::extract::State;
use futures::{ use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
FutureExt, StreamExt, TryFutureExt, TryStreamExt,
future::{join, try_join},
};
use ruma::{OwnedEventId, api::federation::event::get_room_state}; use ruma::{OwnedEventId, api::federation::event::get_room_state};
use tuwunel_core::{ use tuwunel_core::{
Result, at, err, Result, at, err,
utils::{ utils::stream::{IterStream, TryBroadbandExt},
future::TryExtExt,
stream::{IterStream, TryBroadbandExt},
},
}; };
use super::AccessCheck; use super::AccessCheck;
@@ -43,25 +37,23 @@ pub(crate) async fn get_room_state_route(
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.map(at!(1)) .map(at!(1))
.collect::<Vec<OwnedEventId>>(); .collect::<Vec<OwnedEventId>>()
.map(Ok);
let room_version = services let room_version = services.state.get_room_version(&body.room_id);
.state
.get_room_version(&body.room_id)
.ok();
let (room_version, state_ids) = join(room_version, state_ids).await; let (room_version, state_ids) = try_join(room_version, state_ids).await?;
let into_federation_format = |pdu| { let into_federation_format = |pdu| {
services services
.federation .federation
.format_pdu_into(pdu, room_version.as_ref()) .format_pdu_into(pdu, Some(&room_version))
.map(Ok) .map(Ok)
}; };
let auth_chain = services let auth_chain = services
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, once(body.event_id.borrow())) .event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow()))
.broad_and_then(async |id| { .broad_and_then(async |id| {
services services
.timeline .timeline

View File

@@ -1,7 +1,7 @@
use std::{borrow::Borrow, iter::once}; use std::{borrow::Borrow, iter::once};
use axum::extract::State; use axum::extract::State;
use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join}; use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
use ruma::{OwnedEventId, api::federation::event::get_room_state_ids}; use ruma::{OwnedEventId, api::federation::event::get_room_state_ids};
use tuwunel_core::{Result, at, err}; use tuwunel_core::{Result, at, err};
@@ -28,12 +28,15 @@ pub(crate) async fn get_room_state_ids_route(
let shortstatehash = services let shortstatehash = services
.state .state
.pdu_shortstatehash(&body.event_id) .pdu_shortstatehash(&body.event_id)
.await .map_err(|_| err!(Request(NotFound("Pdu state not found."))));
.map_err(|_| err!(Request(NotFound("Pdu state not found."))))?;
let room_version = services.state.get_room_version(&body.room_id);
let (shortstatehash, room_version) = try_join(shortstatehash, room_version).await?;
let auth_chain_ids = services let auth_chain_ids = services
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, once(body.event_id.borrow())) .event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow()))
.try_collect(); .try_collect();
let pdu_ids = services let pdu_ids = services

View File

@@ -11,10 +11,10 @@ use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
stream::{FuturesUnordered, unfold}, stream::{FuturesUnordered, unfold},
}; };
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules}; use ruma::{EventId, OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
use tuwunel_core::{ use tuwunel_core::{
Err, Result, at, debug, debug_error, err, implement, Err, Result, at, debug, debug_error, err, implement,
matrix::{Event, PduEvent}, matrix::{Event, PduEvent, room_version},
pdu::AuthEvents, pdu::AuthEvents,
trace, utils, trace, utils,
utils::{ utils::{
@@ -60,12 +60,13 @@ impl crate::Service for Service {
pub fn event_ids_iter<'a, I>( pub fn event_ids_iter<'a, I>(
&'a self, &'a self,
room_id: &'a RoomId, room_id: &'a RoomId,
room_version: &'a RoomVersionId,
starting_events: I, starting_events: I,
) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a ) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
where where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a, I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{ {
self.get_auth_chain(room_id, starting_events) self.get_auth_chain(room_id, room_version, starting_events)
.map_ok(|chain| { .map_ok(|chain| {
self.services self.services
.short .short
@@ -85,6 +86,7 @@ where
pub async fn get_auth_chain<'a, I>( pub async fn get_auth_chain<'a, I>(
&'a self, &'a self,
room_id: &RoomId, room_id: &RoomId,
room_version: &RoomVersionId,
starting_events: I, starting_events: I,
) -> Result<Vec<ShortEventId>> ) -> Result<Vec<ShortEventId>>
where where
@@ -94,12 +96,7 @@ where
const BUCKET: Bucket<'_> = BTreeSet::new(); const BUCKET: Bucket<'_> = BTreeSet::new();
let started = Instant::now(); let started = Instant::now();
let room_rules = self let room_rules = room_version::rules(room_version)?;
.services
.state
.get_room_version_rules(room_id)
.await?;
let starting_ids = self let starting_ids = self
.services .services
.short .short

View File

@@ -17,7 +17,7 @@ use crate::rooms::state_compressor::CompressedState;
pub async fn resolve_state( pub async fn resolve_state(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
room_version_id: &RoomVersionId, room_version: &RoomVersionId,
incoming_state: HashMap<u64, OwnedEventId>, incoming_state: HashMap<u64, OwnedEventId>,
) -> Result<Arc<CompressedState>> { ) -> Result<Arc<CompressedState>> {
trace!("Loading current room state ids"); trace!("Loading current room state ids");
@@ -43,7 +43,7 @@ pub async fn resolve_state(
.wide_and_then(|state| { .wide_and_then(|state| {
self.services self.services
.auth_chain .auth_chain
.event_ids_iter(room_id, state.values().map(Borrow::borrow)) .event_ids_iter(room_id, room_version, state.values().map(Borrow::borrow))
.try_collect::<AuthSet<OwnedEventId>>() .try_collect::<AuthSet<OwnedEventId>>()
}) })
.ready_filter_map(Result::ok); .ready_filter_map(Result::ok);
@@ -64,7 +64,7 @@ pub async fn resolve_state(
trace!("Resolving state"); trace!("Resolving state");
let state = self let state = self
.state_resolution(room_version_id, fork_states, auth_chain_sets) .state_resolution(room_version, fork_states, auth_chain_sets)
.await?; .await?;
trace!("State resolution done."); trace!("State resolution done.");

View File

@@ -104,7 +104,7 @@ pub(super) async fn state_at_incoming_resolved<Pdu>(
&self, &self,
incoming_pdu: &Pdu, incoming_pdu: &Pdu,
room_id: &RoomId, room_id: &RoomId,
room_version_id: &RoomVersionId, room_version: &RoomVersionId,
) -> Result<Option<HashMap<u64, OwnedEventId>>> ) -> Result<Option<HashMap<u64, OwnedEventId>>>
where where
Pdu: Event, Pdu: Event,
@@ -141,7 +141,7 @@ where
.into_iter() .into_iter()
.try_stream() .try_stream()
.wide_and_then(|(sstatehash, prev_event)| { .wide_and_then(|(sstatehash, prev_event)| {
self.state_at_incoming_fork(room_id, sstatehash, prev_event) self.state_at_incoming_fork(room_id, room_version, sstatehash, prev_event)
}) })
.try_collect() .try_collect()
.map_ok(Vec::into_iter) .map_ok(Vec::into_iter)
@@ -152,7 +152,7 @@ where
trace!("Resolving state"); trace!("Resolving state");
let Ok(new_state) = self let Ok(new_state) = self
.state_resolution(room_version_id, fork_states, auth_chain_sets) .state_resolution(room_version, fork_states, auth_chain_sets)
.inspect_ok(|_| trace!("State resolution done.")) .inspect_ok(|_| trace!("State resolution done."))
.await .await
else { else {
@@ -189,6 +189,7 @@ where
async fn state_at_incoming_fork<Pdu>( async fn state_at_incoming_fork<Pdu>(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
room_version: &RoomVersionId,
sstatehash: ShortStateHash, sstatehash: ShortStateHash,
prev_event: Pdu, prev_event: Pdu,
) -> Result<(StateMap<OwnedEventId>, AuthSet<OwnedEventId>)> ) -> Result<(StateMap<OwnedEventId>, AuthSet<OwnedEventId>)>
@@ -228,7 +229,7 @@ where
let auth_chain = self let auth_chain = self
.services .services
.auth_chain .auth_chain
.event_ids_iter(room_id, starting_events) .event_ids_iter(room_id, room_version, starting_events)
.try_collect(); .try_collect();
let fork_state = leaf_state_after_event let fork_state = leaf_state_after_event