Encapsulate incoming pdu formatting and checks within constructor.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-03-11 13:42:17 +00:00
parent a554280559
commit e9864bc4e7
13 changed files with 110 additions and 79 deletions

View File

@@ -553,9 +553,11 @@ pub(super) async fn force_set_room_state_from_server(
}; };
let pdu = if value["type"] == "m.room.create" { let pdu = if value["type"] == "m.room.create" {
PduEvent::from_rid_val(&room_id, &event_id, value.clone()).map_err(invalid_pdu_err)? PduEvent::from_object_and_roomid_and_eventid(&room_id, &event_id, value.clone())
.map_err(invalid_pdu_err)?
} else { } else {
PduEvent::from_id_val(&event_id, value.clone()).map_err(invalid_pdu_err)? PduEvent::from_object_and_eventid(&event_id, value.clone())
.map_err(invalid_pdu_err)?
}; };
if !value.contains_key("room_id") { if !value.contains_key("room_id") {

View File

@@ -1,6 +1,6 @@
mod builder; mod builder;
mod count; mod count;
pub mod format; mod format;
mod hashes; mod hashes;
mod id; mod id;
mod raw_id; mod raw_id;
@@ -13,6 +13,7 @@ use std::cmp::Ordering;
use ruma::{ use ruma::{
CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, UInt, UserId, events::TimelineEventType, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, UInt, UserId, events::TimelineEventType,
room_version_rules::RoomVersionRules,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;
@@ -22,7 +23,10 @@ pub use self::{
Count as PduCount, Id as PduId, Pdu as PduEvent, RawId as RawPduId, Count as PduCount, Id as PduId, Pdu as PduEvent, RawId as RawPduId,
builder::{Builder, Builder as PduBuilder}, builder::{Builder, Builder as PduBuilder},
count::Count, count::Count,
format::check::check_pdu_format, format::{
check::{check_room_id, check_rules},
from_incoming_federation, into_outgoing_federation,
},
hashes::EventHashes as EventHash, hashes::EventHashes as EventHash,
id::Id, id::Id,
raw_id::*, raw_id::*,
@@ -99,28 +103,61 @@ pub const MAX_PREV_EVENTS: usize = 20;
pub const MAX_AUTH_EVENTS: usize = 10; pub const MAX_AUTH_EVENTS: usize = 10;
impl Pdu { impl Pdu {
pub fn from_rid_val( pub fn from_object_and_roomid_and_eventid(
room_id: &RoomId, room_id: &RoomId,
event_id: &EventId, event_id: &EventId,
mut json: CanonicalJsonObject, mut json: CanonicalJsonObject,
) -> Result<Self> { ) -> Result<Self> {
let room_id = CanonicalJsonValue::String(room_id.into()); let room_id = CanonicalJsonValue::String(room_id.into());
json.insert("room_id".into(), room_id); json.insert("room_id".into(), room_id);
Self::from_object_and_eventid(event_id, json)
Self::from_id_val(event_id, json)
} }
pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self> { pub fn from_object_and_eventid(
event_id: &EventId,
mut json: CanonicalJsonObject,
) -> Result<Self> {
let event_id = CanonicalJsonValue::String(event_id.into()); let event_id = CanonicalJsonValue::String(event_id.into());
json.insert("event_id".into(), event_id); json.insert("event_id".into(), event_id);
Self::from_object(json)
Self::from_val(&json)
} }
pub fn from_val(json: &CanonicalJsonObject) -> Result<Self> { pub fn from_object_federation(
serde_json::to_value(json) room_id: &RoomId,
.and_then(serde_json::from_value) event_id: &EventId,
.map_err(Into::into) json: CanonicalJsonObject,
rules: &RoomVersionRules,
) -> Result<(Self, CanonicalJsonObject)> {
let json = from_incoming_federation(room_id, event_id, json, rules);
let pdu = Self::from_object_checked(json.clone(), rules)?;
check_room_id(&pdu, room_id)?;
Ok((pdu, json))
}
pub fn from_object_checked(
json: CanonicalJsonObject,
rules: &RoomVersionRules,
) -> Result<Self> {
check_rules(&json, &rules.event_format)?;
Self::from_object(json)
}
pub fn from_object(json: CanonicalJsonObject) -> Result<Self> {
let json = CanonicalJsonValue::Object(json);
Self::from_value(json)
}
pub fn from_raw_value(json: &RawJsonValue) -> Result<Self> {
let json: CanonicalJsonValue = json.into();
Self::from_value(json)
}
pub fn from_value(json: CanonicalJsonValue) -> Result<Self> {
serde_json::from_value(json.into()).map_err(Into::into)
}
pub fn from_raw_json(json: &RawJsonValue) -> Result<Self> {
Self::deserialize(json).map_err(Into::into)
} }
} }

View File

@@ -5,11 +5,9 @@ use ruma::{
room_version_rules::{EventsReferenceFormatVersion, RoomVersionRules}, room_version_rules::{EventsReferenceFormatVersion, RoomVersionRules},
}; };
use crate::{ use crate::{extract_variant, is_equal_to, matrix::room_version};
Result, extract_variant, is_equal_to,
matrix::{PduEvent, room_version},
};
#[must_use]
pub fn into_outgoing_federation( pub fn into_outgoing_federation(
mut pdu_json: CanonicalJsonObject, mut pdu_json: CanonicalJsonObject,
room_version: &RoomVersionId, room_version: &RoomVersionId,
@@ -68,12 +66,13 @@ fn mutate_outgoing_reference_format(value: &mut CanonicalJsonValue) {
}); });
} }
#[must_use]
pub fn from_incoming_federation( pub fn from_incoming_federation(
room_id: &RoomId, room_id: &RoomId,
event_id: &EventId, event_id: &EventId,
pdu_json: &mut CanonicalJsonObject, mut pdu_json: CanonicalJsonObject,
room_rules: &RoomVersionRules, room_rules: &RoomVersionRules,
) -> Result<PduEvent> { ) -> CanonicalJsonObject {
if matches!(room_rules.events_reference_format, EventsReferenceFormatVersion::V1) { if matches!(room_rules.events_reference_format, EventsReferenceFormatVersion::V1) {
if let Some(value) = pdu_json.get_mut("auth_events") { if let Some(value) = pdu_json.get_mut("auth_events") {
mutate_incoming_reference_format(value); mutate_incoming_reference_format(value);
@@ -95,9 +94,7 @@ pub fn from_incoming_federation(
pdu_json.insert("event_id".into(), CanonicalJsonValue::String(event_id.into())); pdu_json.insert("event_id".into(), CanonicalJsonValue::String(event_id.into()));
} }
check::check_pdu_format(pdu_json, &room_rules.event_format)?; pdu_json
PduEvent::from_val(pdu_json)
} }
fn mutate_incoming_reference_format(value: &mut CanonicalJsonValue) { fn mutate_incoming_reference_format(value: &mut CanonicalJsonValue) {

View File

@@ -4,10 +4,21 @@ use ruma::{
}; };
use serde_json::to_string as to_json_string; use serde_json::to_string as to_json_string;
use crate::{ use super::super::{MAX_AUTH_EVENTS, MAX_PDU_BYTES, MAX_PREV_EVENTS, Pdu};
Err, Result, err, use crate::{Err, Result, err};
matrix::pdu::{MAX_AUTH_EVENTS, MAX_PDU_BYTES, MAX_PREV_EVENTS},
}; pub fn check_room_id(pdu: &Pdu, room_id: &RoomId) -> Result {
if pdu.room_id != room_id {
return Err!(Request(InvalidParam(error!(
pdu_event_id = ?pdu.event_id,
pdu_room_id = ?pdu.room_id,
?room_id,
"Event in wrong room",
))));
}
Ok(())
}
/// Check that the given canonicalized PDU respects the event format of the room /// Check that the given canonicalized PDU respects the event format of the room
/// version and the [size limits] from the Matrix specification. /// version and the [size limits] from the Matrix specification.
@@ -31,7 +42,7 @@ use crate::{
/// ///
/// [size limits]: https://spec.matrix.org/latest/client-server-api/#size-limits /// [size limits]: https://spec.matrix.org/latest/client-server-api/#size-limits
/// [checks performed on receipt of a PDU]: https://spec.matrix.org/latest/server-server-api/#checks-performed-on-receipt-of-a-pdu /// [checks performed on receipt of a PDU]: https://spec.matrix.org/latest/server-server-api/#checks-performed-on-receipt-of-a-pdu
pub fn check_pdu_format(pdu: &CanonicalJsonObject, rules: &EventFormatRules) -> Result { pub fn check_rules(pdu: &CanonicalJsonObject, rules: &EventFormatRules) -> Result {
// Check the PDU size, it must occur on the full PDU with signatures. // Check the PDU size, it must occur on the full PDU with signatures.
let json = to_json_string(&pdu) let json = to_json_string(&pdu)
.map_err(|e| err!(Request(BadJson("Failed to serialize canonical JSON: {e}"))))?; .map_err(|e| err!(Request(BadJson("Failed to serialize canonical JSON: {e}"))))?;
@@ -200,7 +211,7 @@ mod tests {
}; };
use serde_json::{from_value as from_json_value, json}; use serde_json::{from_value as from_json_value, json};
use super::check_pdu_format; use super::check_rules as check_pdu_format;
/// Construct a PDU valid for the event format of room v1. /// Construct a PDU valid for the event format of room v1.
fn pdu_v1() -> CanonicalJsonObject { fn pdu_v1() -> CanonicalJsonObject {

View File

@@ -35,7 +35,7 @@ pub async fn format_pdu_into(
.as_ref() .as_ref()
.or(room_version) .or(room_version)
{ {
pdu_json = pdu::format::into_outgoing_federation(pdu_json, room_version); pdu_json = pdu::into_outgoing_federation(pdu_json, room_version);
} else { } else {
pdu_json.remove("event_id"); pdu_json.remove("event_id");
} }

View File

@@ -29,7 +29,7 @@ use serde_json::value::RawValue as RawJsonValue;
use tuwunel_core::{ use tuwunel_core::{
Err, Result, at, debug, debug_error, debug_info, debug_warn, err, error, implement, info, Err, Result, at, debug, debug_error, debug_info, debug_warn, err, error, implement, info,
matrix::{event::gen_event_id_canonical_json, room_version}, matrix::{event::gen_event_id_canonical_json, room_version},
pdu::{PduBuilder, check_pdu_format, format::from_incoming_federation}, pdu::{Pdu, PduBuilder, check_rules},
trace, trace,
utils::{self, BoolExt, IterStream, ReadyExt, future::TryExtExt, math::Expected, shuffle}, utils::{self, BoolExt, IterStream, ReadyExt, future::TryExtExt, math::Expected, shuffle},
warn, warn,
@@ -310,8 +310,8 @@ pub async fn join_remote(
%shortroomid, %shortroomid,
"Initialized room. Parsing join event..." "Initialized room. Parsing join event..."
); );
let parsed_join_pdu = let (parsed_join_pdu, join_event) =
from_incoming_federation(room_id, &event_id, &mut join_event, &room_version_rules)?; Pdu::from_object_federation(room_id, &event_id, join_event, &room_version_rules)?;
let resp_state = &response.state; let resp_state = &response.state;
let resp_auth = &response.auth_chain; let resp_auth = &response.auth_chain;
@@ -337,12 +337,12 @@ pub async fn join_remote(
}) })
.inspect_err(|e| debug_error!("Invalid send_join state event: {e:?}")) .inspect_err(|e| debug_error!("Invalid send_join state event: {e:?}"))
.ready_filter_map(Result::ok) .ready_filter_map(Result::ok)
.ready_filter_map(|(event_id, mut value)| { .ready_filter_map(|(event_id, value)| {
from_incoming_federation(room_id, &event_id, &mut value, &room_version_rules) Pdu::from_object_federation(room_id, &event_id, value, &room_version_rules)
.inspect_err(|e| { .inspect_err(|e| {
debug_warn!("Invalid PDU in send_join response: {e:?}: {value:#?}"); debug_warn!("Invalid PDU {event_id:?} in send_join response: {e:?}");
}) })
.map(move |pdu| (event_id, pdu, value)) .map(move |(pdu, value)| (event_id, pdu, value))
.ok() .ok()
}) })
.fold(HashMap::new(), async |mut state, (event_id, pdu, value)| { .fold(HashMap::new(), async |mut state, (event_id, pdu, value)| {
@@ -757,7 +757,7 @@ async fn create_join_event(
.server_keys .server_keys
.gen_id_hash_and_sign_event(&mut event, room_version_id)?; .gen_id_hash_and_sign_event(&mut event, room_version_id)?;
check_pdu_format(&event, &room_version_rules.event_format)?; check_rules(&event, &room_version_rules.event_format)?;
Ok((event, event_id, join_authorized_via_users_server)) Ok((event, event_id, join_authorized_via_users_server))
} }

View File

@@ -308,7 +308,7 @@ async fn knock_room_helper_local(
info!("Parsing knock event"); info!("Parsing knock event");
let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
info!("Updating membership locally to knock state with provided stripped state events"); info!("Updating membership locally to knock state with provided stripped state events");
@@ -480,7 +480,7 @@ async fn knock_room_helper_remote(
.await; .await;
info!("Parsing knock event"); info!("Parsing knock event");
let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone()) let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?; .map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
info!("Going through send_knock response knock state events"); info!("Going through send_knock response knock state events");

View File

@@ -16,7 +16,7 @@ use ruma::{
}; };
use tuwunel_core::{ use tuwunel_core::{
Err, Result, debug_info, debug_warn, err, implement, Err, Result, debug_info, debug_warn, err, implement,
matrix::{PduCount, pdu::check_pdu_format, room_version}, matrix::{PduCount, pdu::check_rules, room_version},
pdu::PduBuilder, pdu::PduBuilder,
utils::{ utils::{
self, FutureBoolExt, self, FutureBoolExt,
@@ -353,7 +353,7 @@ async fn remote_leave(
.server_keys .server_keys
.gen_id_hash_and_sign_event(&mut event, &room_version_id)?; .gen_id_hash_and_sign_event(&mut event, &room_version_id)?;
check_pdu_format(&event, &room_version_rules.event_format)?; check_rules(&event, &room_version_rules.event_format)?;
self.services self.services
.federation .federation

View File

@@ -7,11 +7,13 @@ use ruma::{
}; };
use tuwunel_core::{ use tuwunel_core::{
Result, debug_warn, err, implement, Result, debug_warn, err, implement,
matrix::{Event, PduEvent, pdu::MAX_PREV_EVENTS}, matrix::{
Event, PduEvent,
pdu::{MAX_PREV_EVENTS, check_room_id},
},
utils::stream::IterStream, utils::stream::IterStream,
}; };
use super::check_room_id;
use crate::rooms::state_res; use crate::rooms::state_res;
#[implement(super::Service)] #[implement(super::Service)]
@@ -69,7 +71,7 @@ where
continue; continue;
}; };
check_room_id(room_id, &pdu)?; check_room_id(&pdu, room_id)?;
let limit = self.services.server.config.max_fetch_prev_events; let limit = self.services.server.config.max_fetch_prev_events;
if amount > limit { if amount > limit {

View File

@@ -5,13 +5,11 @@ use ruma::{
use tuwunel_core::{ use tuwunel_core::{
Err, Result, debug, debug_info, err, implement, Err, Result, debug, debug_info, err, implement,
matrix::{Event, PduEvent, event::TypeExt, room_version}, matrix::{Event, PduEvent, event::TypeExt, room_version},
pdu::{check_pdu_format, format::from_incoming_federation},
ref_at, trace, ref_at, trace,
utils::{future::TryExtExt, stream::IterStream}, utils::{future::TryExtExt, stream::IterStream},
warn, warn,
}; };
use super::check_room_id;
use crate::rooms::state_res; use crate::rooms::state_res;
#[implement(super::Service)] #[implement(super::Service)]
@@ -41,7 +39,7 @@ pub(super) async fn handle_outlier_pdu(
// anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json // anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json
// 2. Check signatures, otherwise drop // 2. Check signatures, otherwise drop
// 3. check content hash, redact if doesn't match // 3. check content hash, redact if doesn't match
let mut pdu_json = match self let pdu_json = match self
.services .services
.server_keys .server_keys
.verify_event(&pdu_json, Some(room_version)) .verify_event(&pdu_json, Some(room_version))
@@ -57,7 +55,8 @@ pub(super) async fn handle_outlier_pdu(
))); )));
}; };
let Ok(obj) = ruma::canonical_json::redact(pdu_json, &rules.redaction, None) else { let Ok(pdu_json) = ruma::canonical_json::redact(pdu_json, &rules.redaction, None)
else {
return Err!(Request(InvalidParam("Redaction failed"))); return Err!(Request(InvalidParam("Redaction failed")));
}; };
@@ -68,7 +67,7 @@ pub(super) async fn handle_outlier_pdu(
))); )));
} }
obj pdu_json
}, },
| Err(e) => { | Err(e) => {
return Err!(Request(InvalidParam(debug_error!( return Err!(Request(InvalidParam(debug_error!(
@@ -77,15 +76,11 @@ pub(super) async fn handle_outlier_pdu(
}, },
}; };
let room_rules = room_version::rules(room_version)?;
check_pdu_format(&pdu_json, &room_rules.event_format)?;
// Now that we have checked the signature and hashes we can make mutations and // Now that we have checked the signature and hashes we can make mutations and
// convert to our PduEvent type. // convert to our PduEvent type.
let event = from_incoming_federation(room_id, event_id, &mut pdu_json, &room_rules)?; let room_rules = room_version::rules(room_version)?;
let (event, pdu_json) =
check_room_id(room_id, &event)?; PduEvent::from_object_federation(room_id, event_id, pdu_json, &room_rules)?;
if !auth_events_known { if !auth_events_known {
// 4. fetch any missing auth events doing all checks listed here starting at 1. // 4. fetch any missing auth events doing all checks listed here starting at 1.

View File

@@ -19,10 +19,10 @@ use std::{
}; };
use async_trait::async_trait; use async_trait::async_trait;
use ruma::{EventId, OwnedEventId, OwnedRoomId, RoomId}; use ruma::{EventId, OwnedEventId, OwnedRoomId};
use tuwunel_core::{ use tuwunel_core::{
Err, Result, implement, Result, implement,
matrix::{Event, PduEvent}, matrix::PduEvent,
utils::{MutexMap, bytes::pretty, continue_exponential_backoff}, utils::{MutexMap, bytes::pretty, continue_exponential_backoff},
}; };
@@ -146,16 +146,3 @@ async fn event_exists(&self, event_id: &EventId) -> bool {
async fn event_fetch(&self, event_id: &EventId) -> Result<PduEvent> { async fn event_fetch(&self, event_id: &EventId) -> Result<PduEvent> {
self.services.timeline.get_pdu(event_id).await self.services.timeline.get_pdu(event_id).await
} }
fn check_room_id<Pdu: Event>(room_id: &RoomId, pdu: &Pdu) -> Result {
if pdu.room_id() != room_id {
return Err!(Request(InvalidParam(error!(
pdu_event_id = ?pdu.event_id(),
pdu_room_id = ?pdu.room_id(),
?room_id,
"Found event from room in room",
))));
}
Ok(())
}

View File

@@ -7,7 +7,7 @@ use ruma::{
}; };
use tuwunel_core::{ use tuwunel_core::{
Err, Result, debug, debug_info, err, implement, is_equal_to, Err, Result, debug, debug_info, err, implement, is_equal_to,
matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_pdu_format, room_version}, matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_rules, room_version},
trace, trace,
utils::stream::{BroadbandExt, ReadyExt}, utils::stream::{BroadbandExt, ReadyExt},
warn, warn,
@@ -33,7 +33,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
origin: &ServerName, origin: &ServerName,
room_id: &RoomId, room_id: &RoomId,
incoming_pdu: PduEvent, incoming_pdu: PduEvent,
val: CanonicalJsonObject, pdu_json: CanonicalJsonObject,
room_version: &RoomVersionId, room_version: &RoomVersionId,
recursion_level: usize, recursion_level: usize,
create_event_id: &EventId, create_event_id: &EventId,
@@ -64,7 +64,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
let room_rules = room_version::rules(room_version)?; let room_rules = room_version::rules(room_version)?;
trace!(format = ?room_rules.event_format, "Checking format"); trace!(format = ?room_rules.event_format, "Checking format");
check_pdu_format(&val, &room_rules.event_format)?; check_rules(&pdu_json, &room_rules.event_format)?;
// 10. Fetch missing state and auth chain events by calling /state_ids at // 10. Fetch missing state and auth chain events by calling /state_ids at
// backwards extremities doing all the checks in this list starting at 1. // backwards extremities doing all the checks in this list starting at 1.
@@ -280,7 +280,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
.timeline .timeline
.append_incoming_pdu( .append_incoming_pdu(
&incoming_pdu, &incoming_pdu,
val, pdu_json,
extremities, extremities,
state_ids_compressed, state_ids_compressed,
soft_fail, soft_fail,

View File

@@ -13,7 +13,7 @@ use tuwunel_core::{
Error, Result, err, implement, Error, Result, err, implement,
matrix::{ matrix::{
event::{Event, StateKey, TypeExt}, event::{Event, StateKey, TypeExt},
pdu::{EventHash, PduBuilder, PduEvent, PrevEvents, check_pdu_format}, pdu::{EventHash, PduBuilder, PduEvent, PrevEvents, check_rules},
room_version, room_version,
}, },
utils::{ utils::{
@@ -195,7 +195,7 @@ pub async fn create_hash_and_sign_event(
pdu_json.insert("room_id".into(), CanonicalJsonValue::String(pdu.room_id.clone().into())); pdu_json.insert("room_id".into(), CanonicalJsonValue::String(pdu.room_id.clone().into()));
} }
check_pdu_format(&pdu_json, &version_rules.event_format)?; check_rules(&pdu_json, &version_rules.event_format)?;
// Generate short event id // Generate short event id
let _shorteventid = self let _shorteventid = self