Hoist room_version query to callers of get_auth_chain.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -57,11 +57,17 @@ pub(super) async fn get_auth_chain(&self, event_id: OwnedEventId) -> Result {
|
||||
let room_id = <&RoomId>::try_from(room_id_str)
|
||||
.map_err(|_| err!(Database("Invalid room id field in event in database")))?;
|
||||
|
||||
let room_version = self
|
||||
.services
|
||||
.state
|
||||
.get_room_version(room_id)
|
||||
.await?;
|
||||
|
||||
let start = Instant::now();
|
||||
let count = self
|
||||
.services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, once(event_id.as_ref()))
|
||||
.event_ids_iter(room_id, &room_version, once(event_id.as_ref()))
|
||||
.ready_filter_map(Result::ok)
|
||||
.count()
|
||||
.await;
|
||||
|
||||
@@ -46,22 +46,18 @@ pub(crate) async fn get_event_authorization_route(
|
||||
let room_id = <&RoomId>::try_from(room_id_str)
|
||||
.map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
|
||||
|
||||
let room_version = services
|
||||
.state
|
||||
.get_room_version(room_id)
|
||||
.await
|
||||
.ok();
|
||||
let room_version = services.state.get_room_version(room_id).await?;
|
||||
|
||||
let auth_chain = services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, once(body.event_id.borrow()))
|
||||
.event_ids_iter(room_id, &room_version, once(body.event_id.borrow()))
|
||||
.ready_filter_map(Result::ok)
|
||||
.broad_filter_map(async |id| {
|
||||
let pdu = services.timeline.get_pdu_json(&id).await.ok()?;
|
||||
|
||||
let pdu = services
|
||||
.federation
|
||||
.format_pdu_into(pdu, room_version.as_ref())
|
||||
.format_pdu_into(pdu, Some(&room_version))
|
||||
.await;
|
||||
|
||||
Some(pdu)
|
||||
|
||||
@@ -50,9 +50,9 @@ async fn create_join_event(
|
||||
|
||||
// We do not add the event_id field to the pdu here because of signature and
|
||||
// hashes checks
|
||||
let room_version_id = services.state.get_room_version(room_id).await?;
|
||||
let room_version = services.state.get_room_version(room_id).await?;
|
||||
|
||||
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
|
||||
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version) else {
|
||||
// Event could not be converted to canonical json
|
||||
return Err!(Request(BadJson("Could not convert event to canonical json.")));
|
||||
};
|
||||
@@ -136,9 +136,9 @@ async fn create_join_event(
|
||||
if let Some(authorising_user) = content.join_authorized_via_users_server {
|
||||
use ruma::RoomVersionId::*;
|
||||
|
||||
if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6 | V7) {
|
||||
if matches!(room_version, V1 | V2 | V3 | V4 | V5 | V6 | V7) {
|
||||
return Err!(Request(InvalidParam(
|
||||
"Room version {room_version_id} does not support restricted rooms but \
|
||||
"Room version {room_version} does not support restricted rooms but \
|
||||
join_authorised_via_users_server ({authorising_user}) was found in the event."
|
||||
)));
|
||||
}
|
||||
@@ -161,13 +161,8 @@ async fn create_join_event(
|
||||
)));
|
||||
}
|
||||
|
||||
if !super::user_can_perform_restricted_join(
|
||||
services,
|
||||
&state_key,
|
||||
room_id,
|
||||
&room_version_id,
|
||||
)
|
||||
.await?
|
||||
if !super::user_can_perform_restricted_join(services, &state_key, room_id, &room_version)
|
||||
.await?
|
||||
{
|
||||
return Err!(Request(UnableToAuthorizeJoin(
|
||||
"Joining user did not pass restricted room's rules."
|
||||
@@ -177,7 +172,7 @@ async fn create_join_event(
|
||||
|
||||
services
|
||||
.server_keys
|
||||
.hash_and_sign_event(&mut value, &room_version_id)
|
||||
.hash_and_sign_event(&mut value, &room_version)
|
||||
.map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?;
|
||||
|
||||
let origin: OwnedServerName = serde_json::from_value(
|
||||
@@ -216,7 +211,7 @@ async fn create_join_event(
|
||||
// Join event for new server.
|
||||
let event = services
|
||||
.federation
|
||||
.format_pdu_into(value, Some(&room_version_id))
|
||||
.format_pdu_into(value, Some(&room_version))
|
||||
.map(Some)
|
||||
.map(Ok);
|
||||
|
||||
@@ -229,13 +224,13 @@ async fn create_join_event(
|
||||
let into_federation_format = |pdu| {
|
||||
services
|
||||
.federation
|
||||
.format_pdu_into(pdu, Some(&room_version_id))
|
||||
.format_pdu_into(pdu, Some(&room_version))
|
||||
.map(Ok)
|
||||
};
|
||||
|
||||
let auth_chain = services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, auth_heads)
|
||||
.event_ids_iter(room_id, &room_version, auth_heads)
|
||||
.broad_and_then(async |event_id| {
|
||||
services
|
||||
.timeline
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
use std::{borrow::Borrow, iter::once};
|
||||
|
||||
use axum::extract::State;
|
||||
use futures::{
|
||||
FutureExt, StreamExt, TryFutureExt, TryStreamExt,
|
||||
future::{join, try_join},
|
||||
};
|
||||
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
|
||||
use ruma::{OwnedEventId, api::federation::event::get_room_state};
|
||||
use tuwunel_core::{
|
||||
Result, at, err,
|
||||
utils::{
|
||||
future::TryExtExt,
|
||||
stream::{IterStream, TryBroadbandExt},
|
||||
},
|
||||
utils::stream::{IterStream, TryBroadbandExt},
|
||||
};
|
||||
|
||||
use super::AccessCheck;
|
||||
@@ -43,25 +37,23 @@ pub(crate) async fn get_room_state_route(
|
||||
.state_accessor
|
||||
.state_full_ids(shortstatehash)
|
||||
.map(at!(1))
|
||||
.collect::<Vec<OwnedEventId>>();
|
||||
.collect::<Vec<OwnedEventId>>()
|
||||
.map(Ok);
|
||||
|
||||
let room_version = services
|
||||
.state
|
||||
.get_room_version(&body.room_id)
|
||||
.ok();
|
||||
let room_version = services.state.get_room_version(&body.room_id);
|
||||
|
||||
let (room_version, state_ids) = join(room_version, state_ids).await;
|
||||
let (room_version, state_ids) = try_join(room_version, state_ids).await?;
|
||||
|
||||
let into_federation_format = |pdu| {
|
||||
services
|
||||
.federation
|
||||
.format_pdu_into(pdu, room_version.as_ref())
|
||||
.format_pdu_into(pdu, Some(&room_version))
|
||||
.map(Ok)
|
||||
};
|
||||
|
||||
let auth_chain = services
|
||||
.auth_chain
|
||||
.event_ids_iter(&body.room_id, once(body.event_id.borrow()))
|
||||
.event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow()))
|
||||
.broad_and_then(async |id| {
|
||||
services
|
||||
.timeline
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::{borrow::Borrow, iter::once};
|
||||
|
||||
use axum::extract::State;
|
||||
use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join};
|
||||
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
|
||||
use ruma::{OwnedEventId, api::federation::event::get_room_state_ids};
|
||||
use tuwunel_core::{Result, at, err};
|
||||
|
||||
@@ -28,12 +28,15 @@ pub(crate) async fn get_room_state_ids_route(
|
||||
let shortstatehash = services
|
||||
.state
|
||||
.pdu_shortstatehash(&body.event_id)
|
||||
.await
|
||||
.map_err(|_| err!(Request(NotFound("Pdu state not found."))))?;
|
||||
.map_err(|_| err!(Request(NotFound("Pdu state not found."))));
|
||||
|
||||
let room_version = services.state.get_room_version(&body.room_id);
|
||||
|
||||
let (shortstatehash, room_version) = try_join(shortstatehash, room_version).await?;
|
||||
|
||||
let auth_chain_ids = services
|
||||
.auth_chain
|
||||
.event_ids_iter(&body.room_id, once(body.event_id.borrow()))
|
||||
.event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow()))
|
||||
.try_collect();
|
||||
|
||||
let pdu_ids = services
|
||||
|
||||
@@ -11,10 +11,10 @@ use futures::{
|
||||
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
|
||||
stream::{FuturesUnordered, unfold},
|
||||
};
|
||||
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
|
||||
use ruma::{EventId, OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
|
||||
use tuwunel_core::{
|
||||
Err, Result, at, debug, debug_error, err, implement,
|
||||
matrix::{Event, PduEvent},
|
||||
matrix::{Event, PduEvent, room_version},
|
||||
pdu::AuthEvents,
|
||||
trace, utils,
|
||||
utils::{
|
||||
@@ -60,12 +60,13 @@ impl crate::Service for Service {
|
||||
pub fn event_ids_iter<'a, I>(
|
||||
&'a self,
|
||||
room_id: &'a RoomId,
|
||||
room_version: &'a RoomVersionId,
|
||||
starting_events: I,
|
||||
) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
|
||||
where
|
||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||
{
|
||||
self.get_auth_chain(room_id, starting_events)
|
||||
self.get_auth_chain(room_id, room_version, starting_events)
|
||||
.map_ok(|chain| {
|
||||
self.services
|
||||
.short
|
||||
@@ -85,6 +86,7 @@ where
|
||||
pub async fn get_auth_chain<'a, I>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
room_version: &RoomVersionId,
|
||||
starting_events: I,
|
||||
) -> Result<Vec<ShortEventId>>
|
||||
where
|
||||
@@ -94,12 +96,7 @@ where
|
||||
const BUCKET: Bucket<'_> = BTreeSet::new();
|
||||
|
||||
let started = Instant::now();
|
||||
let room_rules = self
|
||||
.services
|
||||
.state
|
||||
.get_room_version_rules(room_id)
|
||||
.await?;
|
||||
|
||||
let room_rules = room_version::rules(room_version)?;
|
||||
let starting_ids = self
|
||||
.services
|
||||
.short
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::rooms::state_compressor::CompressedState;
|
||||
pub async fn resolve_state(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
room_version_id: &RoomVersionId,
|
||||
room_version: &RoomVersionId,
|
||||
incoming_state: HashMap<u64, OwnedEventId>,
|
||||
) -> Result<Arc<CompressedState>> {
|
||||
trace!("Loading current room state ids");
|
||||
@@ -43,7 +43,7 @@ pub async fn resolve_state(
|
||||
.wide_and_then(|state| {
|
||||
self.services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, state.values().map(Borrow::borrow))
|
||||
.event_ids_iter(room_id, room_version, state.values().map(Borrow::borrow))
|
||||
.try_collect::<AuthSet<OwnedEventId>>()
|
||||
})
|
||||
.ready_filter_map(Result::ok);
|
||||
@@ -64,7 +64,7 @@ pub async fn resolve_state(
|
||||
|
||||
trace!("Resolving state");
|
||||
let state = self
|
||||
.state_resolution(room_version_id, fork_states, auth_chain_sets)
|
||||
.state_resolution(room_version, fork_states, auth_chain_sets)
|
||||
.await?;
|
||||
|
||||
trace!("State resolution done.");
|
||||
|
||||
@@ -104,7 +104,7 @@ pub(super) async fn state_at_incoming_resolved<Pdu>(
|
||||
&self,
|
||||
incoming_pdu: &Pdu,
|
||||
room_id: &RoomId,
|
||||
room_version_id: &RoomVersionId,
|
||||
room_version: &RoomVersionId,
|
||||
) -> Result<Option<HashMap<u64, OwnedEventId>>>
|
||||
where
|
||||
Pdu: Event,
|
||||
@@ -141,7 +141,7 @@ where
|
||||
.into_iter()
|
||||
.try_stream()
|
||||
.wide_and_then(|(sstatehash, prev_event)| {
|
||||
self.state_at_incoming_fork(room_id, sstatehash, prev_event)
|
||||
self.state_at_incoming_fork(room_id, room_version, sstatehash, prev_event)
|
||||
})
|
||||
.try_collect()
|
||||
.map_ok(Vec::into_iter)
|
||||
@@ -152,7 +152,7 @@ where
|
||||
|
||||
trace!("Resolving state");
|
||||
let Ok(new_state) = self
|
||||
.state_resolution(room_version_id, fork_states, auth_chain_sets)
|
||||
.state_resolution(room_version, fork_states, auth_chain_sets)
|
||||
.inspect_ok(|_| trace!("State resolution done."))
|
||||
.await
|
||||
else {
|
||||
@@ -189,6 +189,7 @@ where
|
||||
async fn state_at_incoming_fork<Pdu>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
room_version: &RoomVersionId,
|
||||
sstatehash: ShortStateHash,
|
||||
prev_event: Pdu,
|
||||
) -> Result<(StateMap<OwnedEventId>, AuthSet<OwnedEventId>)>
|
||||
@@ -228,7 +229,7 @@ where
|
||||
let auth_chain = self
|
||||
.services
|
||||
.auth_chain
|
||||
.event_ids_iter(room_id, starting_events)
|
||||
.event_ids_iter(room_id, room_version, starting_events)
|
||||
.try_collect();
|
||||
|
||||
let fork_state = leaf_state_after_event
|
||||
|
||||
Reference in New Issue
Block a user