Hoist room_version query to callers of get_auth_chain.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -57,11 +57,17 @@ pub(super) async fn get_auth_chain(&self, event_id: OwnedEventId) -> Result {
|
|||||||
let room_id = <&RoomId>::try_from(room_id_str)
|
let room_id = <&RoomId>::try_from(room_id_str)
|
||||||
.map_err(|_| err!(Database("Invalid room id field in event in database")))?;
|
.map_err(|_| err!(Database("Invalid room id field in event in database")))?;
|
||||||
|
|
||||||
|
let room_version = self
|
||||||
|
.services
|
||||||
|
.state
|
||||||
|
.get_room_version(room_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let count = self
|
let count = self
|
||||||
.services
|
.services
|
||||||
.auth_chain
|
.auth_chain
|
||||||
.event_ids_iter(room_id, once(event_id.as_ref()))
|
.event_ids_iter(room_id, &room_version, once(event_id.as_ref()))
|
||||||
.ready_filter_map(Result::ok)
|
.ready_filter_map(Result::ok)
|
||||||
.count()
|
.count()
|
||||||
.await;
|
.await;
|
||||||
|
|||||||
@@ -46,22 +46,18 @@ pub(crate) async fn get_event_authorization_route(
|
|||||||
let room_id = <&RoomId>::try_from(room_id_str)
|
let room_id = <&RoomId>::try_from(room_id_str)
|
||||||
.map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
|
.map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
|
||||||
|
|
||||||
let room_version = services
|
let room_version = services.state.get_room_version(room_id).await?;
|
||||||
.state
|
|
||||||
.get_room_version(room_id)
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
let auth_chain = services
|
let auth_chain = services
|
||||||
.auth_chain
|
.auth_chain
|
||||||
.event_ids_iter(room_id, once(body.event_id.borrow()))
|
.event_ids_iter(room_id, &room_version, once(body.event_id.borrow()))
|
||||||
.ready_filter_map(Result::ok)
|
.ready_filter_map(Result::ok)
|
||||||
.broad_filter_map(async |id| {
|
.broad_filter_map(async |id| {
|
||||||
let pdu = services.timeline.get_pdu_json(&id).await.ok()?;
|
let pdu = services.timeline.get_pdu_json(&id).await.ok()?;
|
||||||
|
|
||||||
let pdu = services
|
let pdu = services
|
||||||
.federation
|
.federation
|
||||||
.format_pdu_into(pdu, room_version.as_ref())
|
.format_pdu_into(pdu, Some(&room_version))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
Some(pdu)
|
Some(pdu)
|
||||||
|
|||||||
@@ -50,9 +50,9 @@ async fn create_join_event(
|
|||||||
|
|
||||||
// We do not add the event_id field to the pdu here because of signature and
|
// We do not add the event_id field to the pdu here because of signature and
|
||||||
// hashes checks
|
// hashes checks
|
||||||
let room_version_id = services.state.get_room_version(room_id).await?;
|
let room_version = services.state.get_room_version(room_id).await?;
|
||||||
|
|
||||||
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
|
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version) else {
|
||||||
// Event could not be converted to canonical json
|
// Event could not be converted to canonical json
|
||||||
return Err!(Request(BadJson("Could not convert event to canonical json.")));
|
return Err!(Request(BadJson("Could not convert event to canonical json.")));
|
||||||
};
|
};
|
||||||
@@ -136,9 +136,9 @@ async fn create_join_event(
|
|||||||
if let Some(authorising_user) = content.join_authorized_via_users_server {
|
if let Some(authorising_user) = content.join_authorized_via_users_server {
|
||||||
use ruma::RoomVersionId::*;
|
use ruma::RoomVersionId::*;
|
||||||
|
|
||||||
if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6 | V7) {
|
if matches!(room_version, V1 | V2 | V3 | V4 | V5 | V6 | V7) {
|
||||||
return Err!(Request(InvalidParam(
|
return Err!(Request(InvalidParam(
|
||||||
"Room version {room_version_id} does not support restricted rooms but \
|
"Room version {room_version} does not support restricted rooms but \
|
||||||
join_authorised_via_users_server ({authorising_user}) was found in the event."
|
join_authorised_via_users_server ({authorising_user}) was found in the event."
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
@@ -161,13 +161,8 @@ async fn create_join_event(
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if !super::user_can_perform_restricted_join(
|
if !super::user_can_perform_restricted_join(services, &state_key, room_id, &room_version)
|
||||||
services,
|
.await?
|
||||||
&state_key,
|
|
||||||
room_id,
|
|
||||||
&room_version_id,
|
|
||||||
)
|
|
||||||
.await?
|
|
||||||
{
|
{
|
||||||
return Err!(Request(UnableToAuthorizeJoin(
|
return Err!(Request(UnableToAuthorizeJoin(
|
||||||
"Joining user did not pass restricted room's rules."
|
"Joining user did not pass restricted room's rules."
|
||||||
@@ -177,7 +172,7 @@ async fn create_join_event(
|
|||||||
|
|
||||||
services
|
services
|
||||||
.server_keys
|
.server_keys
|
||||||
.hash_and_sign_event(&mut value, &room_version_id)
|
.hash_and_sign_event(&mut value, &room_version)
|
||||||
.map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?;
|
.map_err(|e| err!(Request(InvalidParam(warn!("Failed to sign send_join event: {e}")))))?;
|
||||||
|
|
||||||
let origin: OwnedServerName = serde_json::from_value(
|
let origin: OwnedServerName = serde_json::from_value(
|
||||||
@@ -216,7 +211,7 @@ async fn create_join_event(
|
|||||||
// Join event for new server.
|
// Join event for new server.
|
||||||
let event = services
|
let event = services
|
||||||
.federation
|
.federation
|
||||||
.format_pdu_into(value, Some(&room_version_id))
|
.format_pdu_into(value, Some(&room_version))
|
||||||
.map(Some)
|
.map(Some)
|
||||||
.map(Ok);
|
.map(Ok);
|
||||||
|
|
||||||
@@ -229,13 +224,13 @@ async fn create_join_event(
|
|||||||
let into_federation_format = |pdu| {
|
let into_federation_format = |pdu| {
|
||||||
services
|
services
|
||||||
.federation
|
.federation
|
||||||
.format_pdu_into(pdu, Some(&room_version_id))
|
.format_pdu_into(pdu, Some(&room_version))
|
||||||
.map(Ok)
|
.map(Ok)
|
||||||
};
|
};
|
||||||
|
|
||||||
let auth_chain = services
|
let auth_chain = services
|
||||||
.auth_chain
|
.auth_chain
|
||||||
.event_ids_iter(room_id, auth_heads)
|
.event_ids_iter(room_id, &room_version, auth_heads)
|
||||||
.broad_and_then(async |event_id| {
|
.broad_and_then(async |event_id| {
|
||||||
services
|
services
|
||||||
.timeline
|
.timeline
|
||||||
|
|||||||
@@ -1,17 +1,11 @@
|
|||||||
use std::{borrow::Borrow, iter::once};
|
use std::{borrow::Borrow, iter::once};
|
||||||
|
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use futures::{
|
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
|
||||||
FutureExt, StreamExt, TryFutureExt, TryStreamExt,
|
|
||||||
future::{join, try_join},
|
|
||||||
};
|
|
||||||
use ruma::{OwnedEventId, api::federation::event::get_room_state};
|
use ruma::{OwnedEventId, api::federation::event::get_room_state};
|
||||||
use tuwunel_core::{
|
use tuwunel_core::{
|
||||||
Result, at, err,
|
Result, at, err,
|
||||||
utils::{
|
utils::stream::{IterStream, TryBroadbandExt},
|
||||||
future::TryExtExt,
|
|
||||||
stream::{IterStream, TryBroadbandExt},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::AccessCheck;
|
use super::AccessCheck;
|
||||||
@@ -43,25 +37,23 @@ pub(crate) async fn get_room_state_route(
|
|||||||
.state_accessor
|
.state_accessor
|
||||||
.state_full_ids(shortstatehash)
|
.state_full_ids(shortstatehash)
|
||||||
.map(at!(1))
|
.map(at!(1))
|
||||||
.collect::<Vec<OwnedEventId>>();
|
.collect::<Vec<OwnedEventId>>()
|
||||||
|
.map(Ok);
|
||||||
|
|
||||||
let room_version = services
|
let room_version = services.state.get_room_version(&body.room_id);
|
||||||
.state
|
|
||||||
.get_room_version(&body.room_id)
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
let (room_version, state_ids) = join(room_version, state_ids).await;
|
let (room_version, state_ids) = try_join(room_version, state_ids).await?;
|
||||||
|
|
||||||
let into_federation_format = |pdu| {
|
let into_federation_format = |pdu| {
|
||||||
services
|
services
|
||||||
.federation
|
.federation
|
||||||
.format_pdu_into(pdu, room_version.as_ref())
|
.format_pdu_into(pdu, Some(&room_version))
|
||||||
.map(Ok)
|
.map(Ok)
|
||||||
};
|
};
|
||||||
|
|
||||||
let auth_chain = services
|
let auth_chain = services
|
||||||
.auth_chain
|
.auth_chain
|
||||||
.event_ids_iter(&body.room_id, once(body.event_id.borrow()))
|
.event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow()))
|
||||||
.broad_and_then(async |id| {
|
.broad_and_then(async |id| {
|
||||||
services
|
services
|
||||||
.timeline
|
.timeline
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::{borrow::Borrow, iter::once};
|
use std::{borrow::Borrow, iter::once};
|
||||||
|
|
||||||
use axum::extract::State;
|
use axum::extract::State;
|
||||||
use futures::{FutureExt, StreamExt, TryStreamExt, future::try_join};
|
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join};
|
||||||
use ruma::{OwnedEventId, api::federation::event::get_room_state_ids};
|
use ruma::{OwnedEventId, api::federation::event::get_room_state_ids};
|
||||||
use tuwunel_core::{Result, at, err};
|
use tuwunel_core::{Result, at, err};
|
||||||
|
|
||||||
@@ -28,12 +28,15 @@ pub(crate) async fn get_room_state_ids_route(
|
|||||||
let shortstatehash = services
|
let shortstatehash = services
|
||||||
.state
|
.state
|
||||||
.pdu_shortstatehash(&body.event_id)
|
.pdu_shortstatehash(&body.event_id)
|
||||||
.await
|
.map_err(|_| err!(Request(NotFound("Pdu state not found."))));
|
||||||
.map_err(|_| err!(Request(NotFound("Pdu state not found."))))?;
|
|
||||||
|
let room_version = services.state.get_room_version(&body.room_id);
|
||||||
|
|
||||||
|
let (shortstatehash, room_version) = try_join(shortstatehash, room_version).await?;
|
||||||
|
|
||||||
let auth_chain_ids = services
|
let auth_chain_ids = services
|
||||||
.auth_chain
|
.auth_chain
|
||||||
.event_ids_iter(&body.room_id, once(body.event_id.borrow()))
|
.event_ids_iter(&body.room_id, &room_version, once(body.event_id.borrow()))
|
||||||
.try_collect();
|
.try_collect();
|
||||||
|
|
||||||
let pdu_ids = services
|
let pdu_ids = services
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ use futures::{
|
|||||||
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
|
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut,
|
||||||
stream::{FuturesUnordered, unfold},
|
stream::{FuturesUnordered, unfold},
|
||||||
};
|
};
|
||||||
use ruma::{EventId, OwnedEventId, RoomId, room_version_rules::RoomVersionRules};
|
use ruma::{EventId, OwnedEventId, RoomId, RoomVersionId, room_version_rules::RoomVersionRules};
|
||||||
use tuwunel_core::{
|
use tuwunel_core::{
|
||||||
Err, Result, at, debug, debug_error, err, implement,
|
Err, Result, at, debug, debug_error, err, implement,
|
||||||
matrix::{Event, PduEvent},
|
matrix::{Event, PduEvent, room_version},
|
||||||
pdu::AuthEvents,
|
pdu::AuthEvents,
|
||||||
trace, utils,
|
trace, utils,
|
||||||
utils::{
|
utils::{
|
||||||
@@ -60,12 +60,13 @@ impl crate::Service for Service {
|
|||||||
pub fn event_ids_iter<'a, I>(
|
pub fn event_ids_iter<'a, I>(
|
||||||
&'a self,
|
&'a self,
|
||||||
room_id: &'a RoomId,
|
room_id: &'a RoomId,
|
||||||
|
room_version: &'a RoomVersionId,
|
||||||
starting_events: I,
|
starting_events: I,
|
||||||
) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
|
) -> impl Stream<Item = Result<OwnedEventId>> + Send + 'a
|
||||||
where
|
where
|
||||||
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
|
||||||
{
|
{
|
||||||
self.get_auth_chain(room_id, starting_events)
|
self.get_auth_chain(room_id, room_version, starting_events)
|
||||||
.map_ok(|chain| {
|
.map_ok(|chain| {
|
||||||
self.services
|
self.services
|
||||||
.short
|
.short
|
||||||
@@ -85,6 +86,7 @@ where
|
|||||||
pub async fn get_auth_chain<'a, I>(
|
pub async fn get_auth_chain<'a, I>(
|
||||||
&'a self,
|
&'a self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
|
room_version: &RoomVersionId,
|
||||||
starting_events: I,
|
starting_events: I,
|
||||||
) -> Result<Vec<ShortEventId>>
|
) -> Result<Vec<ShortEventId>>
|
||||||
where
|
where
|
||||||
@@ -94,12 +96,7 @@ where
|
|||||||
const BUCKET: Bucket<'_> = BTreeSet::new();
|
const BUCKET: Bucket<'_> = BTreeSet::new();
|
||||||
|
|
||||||
let started = Instant::now();
|
let started = Instant::now();
|
||||||
let room_rules = self
|
let room_rules = room_version::rules(room_version)?;
|
||||||
.services
|
|
||||||
.state
|
|
||||||
.get_room_version_rules(room_id)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let starting_ids = self
|
let starting_ids = self
|
||||||
.services
|
.services
|
||||||
.short
|
.short
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ use crate::rooms::state_compressor::CompressedState;
|
|||||||
pub async fn resolve_state(
|
pub async fn resolve_state(
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
room_version_id: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
incoming_state: HashMap<u64, OwnedEventId>,
|
incoming_state: HashMap<u64, OwnedEventId>,
|
||||||
) -> Result<Arc<CompressedState>> {
|
) -> Result<Arc<CompressedState>> {
|
||||||
trace!("Loading current room state ids");
|
trace!("Loading current room state ids");
|
||||||
@@ -43,7 +43,7 @@ pub async fn resolve_state(
|
|||||||
.wide_and_then(|state| {
|
.wide_and_then(|state| {
|
||||||
self.services
|
self.services
|
||||||
.auth_chain
|
.auth_chain
|
||||||
.event_ids_iter(room_id, state.values().map(Borrow::borrow))
|
.event_ids_iter(room_id, room_version, state.values().map(Borrow::borrow))
|
||||||
.try_collect::<AuthSet<OwnedEventId>>()
|
.try_collect::<AuthSet<OwnedEventId>>()
|
||||||
})
|
})
|
||||||
.ready_filter_map(Result::ok);
|
.ready_filter_map(Result::ok);
|
||||||
@@ -64,7 +64,7 @@ pub async fn resolve_state(
|
|||||||
|
|
||||||
trace!("Resolving state");
|
trace!("Resolving state");
|
||||||
let state = self
|
let state = self
|
||||||
.state_resolution(room_version_id, fork_states, auth_chain_sets)
|
.state_resolution(room_version, fork_states, auth_chain_sets)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
trace!("State resolution done.");
|
trace!("State resolution done.");
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ pub(super) async fn state_at_incoming_resolved<Pdu>(
|
|||||||
&self,
|
&self,
|
||||||
incoming_pdu: &Pdu,
|
incoming_pdu: &Pdu,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
room_version_id: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
) -> Result<Option<HashMap<u64, OwnedEventId>>>
|
) -> Result<Option<HashMap<u64, OwnedEventId>>>
|
||||||
where
|
where
|
||||||
Pdu: Event,
|
Pdu: Event,
|
||||||
@@ -141,7 +141,7 @@ where
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.try_stream()
|
.try_stream()
|
||||||
.wide_and_then(|(sstatehash, prev_event)| {
|
.wide_and_then(|(sstatehash, prev_event)| {
|
||||||
self.state_at_incoming_fork(room_id, sstatehash, prev_event)
|
self.state_at_incoming_fork(room_id, room_version, sstatehash, prev_event)
|
||||||
})
|
})
|
||||||
.try_collect()
|
.try_collect()
|
||||||
.map_ok(Vec::into_iter)
|
.map_ok(Vec::into_iter)
|
||||||
@@ -152,7 +152,7 @@ where
|
|||||||
|
|
||||||
trace!("Resolving state");
|
trace!("Resolving state");
|
||||||
let Ok(new_state) = self
|
let Ok(new_state) = self
|
||||||
.state_resolution(room_version_id, fork_states, auth_chain_sets)
|
.state_resolution(room_version, fork_states, auth_chain_sets)
|
||||||
.inspect_ok(|_| trace!("State resolution done."))
|
.inspect_ok(|_| trace!("State resolution done."))
|
||||||
.await
|
.await
|
||||||
else {
|
else {
|
||||||
@@ -189,6 +189,7 @@ where
|
|||||||
async fn state_at_incoming_fork<Pdu>(
|
async fn state_at_incoming_fork<Pdu>(
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
|
room_version: &RoomVersionId,
|
||||||
sstatehash: ShortStateHash,
|
sstatehash: ShortStateHash,
|
||||||
prev_event: Pdu,
|
prev_event: Pdu,
|
||||||
) -> Result<(StateMap<OwnedEventId>, AuthSet<OwnedEventId>)>
|
) -> Result<(StateMap<OwnedEventId>, AuthSet<OwnedEventId>)>
|
||||||
@@ -228,7 +229,7 @@ where
|
|||||||
let auth_chain = self
|
let auth_chain = self
|
||||||
.services
|
.services
|
||||||
.auth_chain
|
.auth_chain
|
||||||
.event_ids_iter(room_id, starting_events)
|
.event_ids_iter(room_id, room_version, starting_events)
|
||||||
.try_collect();
|
.try_collect();
|
||||||
|
|
||||||
let fork_state = leaf_state_after_event
|
let fork_state = leaf_state_after_event
|
||||||
|
|||||||
Reference in New Issue
Block a user