Move state_res from tuwunel_core to tuwunel_service.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
41
src/service/rooms/state_res/resolve/auth_difference.rs
Normal file
41
src/service/rooms/state_res/resolve/auth_difference.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use futures::{FutureExt, Stream};
|
||||
use ruma::EventId;
|
||||
use tuwunel_core::utils::stream::{IterStream, ReadyExt};
|
||||
|
||||
use super::AuthSet;
|
||||
|
||||
/// Get the auth difference for the given auth chains.
|
||||
///
|
||||
/// Definition in the specification:
|
||||
///
|
||||
/// The auth difference is calculated by first calculating the full auth chain
|
||||
/// for each state _Si_, that is the union of the auth chains for each event in
|
||||
/// _Si_, and then taking every event that doesn’t appear in every auth chain.
|
||||
/// If _Ci_ is the full auth chain of _Si_, then the auth difference is ∪_Ci_ −
|
||||
/// ∩_Ci_.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `auth_chains` - The list of full recursive sets of `auth_events`. Inputs
|
||||
/// must be sorted.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// Outputs the event IDs that are not present in all the auth chains.
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
pub(super) fn auth_difference<'a, AuthSets, Id>(auth_sets: AuthSets) -> impl Stream<Item = Id>
|
||||
where
|
||||
AuthSets: Stream<Item = AuthSet<Id>>,
|
||||
Id: Borrow<EventId> + Clone + Eq + Ord + Send + 'a,
|
||||
{
|
||||
auth_sets
|
||||
.ready_fold_default(|ret: AuthSet<Id>, set| {
|
||||
ret.symmetric_difference(&set)
|
||||
.cloned()
|
||||
.collect::<AuthSet<Id>>()
|
||||
})
|
||||
.map(|set: AuthSet<Id>| set.into_iter().stream())
|
||||
.flatten_stream()
|
||||
}
|
||||
146
src/service/rooms/state_res/resolve/conflicted_subgraph.rs
Normal file
146
src/service/rooms/state_res/resolve/conflicted_subgraph.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use std::{
|
||||
collections::HashSet as Set,
|
||||
iter::once,
|
||||
mem::take,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use futures::{Future, FutureExt, Stream, StreamExt};
|
||||
use ruma::OwnedEventId;
|
||||
use tuwunel_core::{
|
||||
Result, debug,
|
||||
matrix::{Event, pdu::AuthEvents},
|
||||
smallvec::SmallVec,
|
||||
utils::stream::{IterStream, automatic_width},
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
struct Global {
|
||||
subgraph: Mutex<Set<OwnedEventId>>,
|
||||
seen: Mutex<Set<OwnedEventId>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct Local {
|
||||
path: Path,
|
||||
stack: Stack,
|
||||
}
|
||||
|
||||
type Path = SmallVec<[OwnedEventId; PATH_INLINE]>;
|
||||
type Stack = SmallVec<[Frame; STACK_INLINE]>;
|
||||
type Frame = AuthEvents;
|
||||
|
||||
const PATH_INLINE: usize = 48;
|
||||
const STACK_INLINE: usize = 48;
|
||||
|
||||
#[tracing::instrument(name = "subgraph_dfs", level = "debug", skip_all)]
|
||||
pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>(
|
||||
conflicted_event_ids: &Set<&OwnedEventId>,
|
||||
fetch: &Fetch,
|
||||
) -> impl Stream<Item = OwnedEventId> + Send
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
let state = Arc::new(Global::default());
|
||||
let state_ = state.clone();
|
||||
conflicted_event_ids
|
||||
.iter()
|
||||
.stream()
|
||||
.enumerate()
|
||||
.map(move |(i, event_id)| (state_.clone(), event_id, i))
|
||||
.for_each_concurrent(automatic_width(), async |(state, event_id, i)| {
|
||||
subgraph_descent(conflicted_event_ids, fetch, &state, event_id, i)
|
||||
.await
|
||||
.expect("only mutex errors expected");
|
||||
})
|
||||
.map(move |()| {
|
||||
let seen = state.seen.lock().expect("locked");
|
||||
let mut state = state.subgraph.lock().expect("locked");
|
||||
debug!(
|
||||
input_events = conflicted_event_ids.len(),
|
||||
seen_events = seen.len(),
|
||||
output_events = state.len(),
|
||||
"conflicted subgraph state"
|
||||
);
|
||||
|
||||
take(&mut *state).into_iter().stream()
|
||||
})
|
||||
.flatten_stream()
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "descent",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
event_ids = conflicted_event_ids.len(),
|
||||
event_id = %conflicted_event_id,
|
||||
%i,
|
||||
)
|
||||
)]
|
||||
async fn subgraph_descent<Fetch, Fut, Pdu>(
|
||||
conflicted_event_ids: &Set<&OwnedEventId>,
|
||||
fetch: &Fetch,
|
||||
state: &Arc<Global>,
|
||||
conflicted_event_id: &OwnedEventId,
|
||||
i: usize,
|
||||
) -> Result
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
let Global { subgraph, seen } = &**state;
|
||||
|
||||
let mut local = Local {
|
||||
path: once(conflicted_event_id.clone()).collect(),
|
||||
stack: once(once(conflicted_event_id.clone()).collect()).collect(),
|
||||
};
|
||||
|
||||
while let Some(event_id) = pop(&mut local) {
|
||||
if subgraph.lock()?.contains(&event_id) {
|
||||
if local.path.len() > 1 {
|
||||
subgraph
|
||||
.lock()?
|
||||
.extend(local.path.iter().cloned());
|
||||
}
|
||||
|
||||
local.path.pop();
|
||||
continue;
|
||||
}
|
||||
|
||||
if !seen.lock()?.insert(event_id.clone()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if local.path.len() > 1 && conflicted_event_ids.contains(&event_id) {
|
||||
subgraph
|
||||
.lock()?
|
||||
.extend(local.path.iter().cloned());
|
||||
}
|
||||
|
||||
if let Ok(event) = fetch(event_id).await {
|
||||
local
|
||||
.stack
|
||||
.push(event.auth_events_into().into_iter().collect());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn pop(local: &mut Local) -> Option<OwnedEventId> {
|
||||
let Local { path, stack } = local;
|
||||
|
||||
while stack.last().is_some_and(Frame::is_empty) {
|
||||
stack.pop();
|
||||
path.pop();
|
||||
}
|
||||
|
||||
stack
|
||||
.last_mut()
|
||||
.and_then(Frame::pop)
|
||||
.inspect(|event_id| path.push(event_id.clone()))
|
||||
}
|
||||
209
src/service/rooms/state_res/resolve/iterative_auth_check.rs
Normal file
209
src/service/rooms/state_res/resolve/iterative_auth_check.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use ruma::{
|
||||
EventId, OwnedEventId,
|
||||
events::{StateEventType, TimelineEventType},
|
||||
room_version_rules::RoomVersionRules,
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Error, Result, debug_warn, err, error,
|
||||
matrix::{Event, EventTypeExt, StateKey},
|
||||
smallvec::SmallVec,
|
||||
trace,
|
||||
utils::stream::{IterStream, ReadyExt, TryReadyExt, TryWidebandExt},
|
||||
};
|
||||
|
||||
use super::{
|
||||
super::{auth_types_for_event, check_state_dependent_auth_rules},
|
||||
StateMap,
|
||||
};
|
||||
|
||||
/// Perform the iterative auth checks to the given list of events.
|
||||
///
|
||||
/// Definition in the specification:
|
||||
///
|
||||
/// The iterative auth checks algorithm takes as input an initial room state and
|
||||
/// a sorted list of state events, and constructs a new room state by iterating
|
||||
/// through the event list and applying the state event to the room state if the
|
||||
/// state event is allowed by the authorization rules. If the state event is not
|
||||
/// allowed by the authorization rules, then the event is ignored. If a
|
||||
/// (event_type, state_key) key that is required for checking the authorization
|
||||
/// rules is not present in the state, then the appropriate state event from the
|
||||
/// event’s auth_events is used if the auth event is not rejected.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `rules` - The authorization rules for the current room version.
|
||||
/// * `events` - The sorted state events to apply to the `partial_state`.
|
||||
/// * `state` - The current state that was partially resolved for the room.
|
||||
/// * `fetch_event` - Function to fetch an event in the room given its event ID.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns the partially resolved state, or an `Err(_)` if one of the state
|
||||
/// events in the room has an unexpected format.
|
||||
#[tracing::instrument(
|
||||
name = "iterative_auth",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
states = ?state.len(),
|
||||
)
|
||||
)]
|
||||
pub(super) async fn iterative_auth_check<'b, SortedPowerEvents, Fetch, Fut, Pdu>(
|
||||
rules: &RoomVersionRules,
|
||||
events: SortedPowerEvents,
|
||||
state: StateMap<OwnedEventId>,
|
||||
fetch: &Fetch,
|
||||
) -> Result<StateMap<OwnedEventId>>
|
||||
where
|
||||
SortedPowerEvents: Stream<Item = &'b EventId> + Send,
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
events
|
||||
.map(Ok)
|
||||
.wide_and_then(async |event_id| {
|
||||
let event = fetch(event_id.to_owned()).await?;
|
||||
let state_key: StateKey = event
|
||||
.state_key()
|
||||
.ok_or_else(|| err!(Request(InvalidParam("Missing state_key"))))?
|
||||
.into();
|
||||
|
||||
Ok((event_id, state_key, event))
|
||||
})
|
||||
.try_fold(state, |state, (event_id, state_key, event)| {
|
||||
auth_check(rules, state, event_id, state_key, event, fetch)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "check",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
%event_id,
|
||||
?state_key,
|
||||
)
|
||||
)]
|
||||
async fn auth_check<Fetch, Fut, Pdu>(
|
||||
rules: &RoomVersionRules,
|
||||
mut state: StateMap<OwnedEventId>,
|
||||
event_id: &EventId,
|
||||
state_key: StateKey,
|
||||
event: Pdu,
|
||||
fetch: &Fetch,
|
||||
) -> Result<StateMap<OwnedEventId>>
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
let Ok(auth_types) = auth_types_for_event(
|
||||
event.event_type(),
|
||||
event.sender(),
|
||||
Some(&state_key),
|
||||
event.content(),
|
||||
&rules.authorization,
|
||||
true,
|
||||
)
|
||||
.inspect_err(|e| error!("failed to get auth types for event: {e}")) else {
|
||||
return Ok(state);
|
||||
};
|
||||
|
||||
let auth_types_events = auth_types
|
||||
.stream()
|
||||
.ready_filter_map(|key| {
|
||||
state
|
||||
.get(&key)
|
||||
.map(move |auth_event_id| (auth_event_id, key))
|
||||
})
|
||||
.filter_map(async |(id, key)| {
|
||||
fetch(id.clone())
|
||||
.inspect_err(|e| debug_warn!(%id, "missing auth event: {e}"))
|
||||
.inspect_err(|e| debug_assert!(!cfg!(test), "missing auth {id:?}: {e:?}"))
|
||||
.map_ok(move |auth_event| (key, auth_event))
|
||||
.await
|
||||
.ok()
|
||||
})
|
||||
.ready_filter_map(|(key, auth_event)| {
|
||||
auth_event
|
||||
.rejected()
|
||||
.eq(&false)
|
||||
.then_some((key, auth_event))
|
||||
})
|
||||
.map(Ok);
|
||||
|
||||
// If the `m.room.create` event is not in the auth events, we need to add it,
|
||||
// because it's always part of the state and required in the auth rules.
|
||||
let also_need_create_event = *event.event_type() != TimelineEventType::RoomCreate
|
||||
&& rules
|
||||
.authorization
|
||||
.room_create_event_id_as_room_id;
|
||||
|
||||
let also_create_id: Option<OwnedEventId> = also_need_create_event
|
||||
.then(|| event.room_id().as_event_id().ok())
|
||||
.flatten();
|
||||
|
||||
let auth_events = event
|
||||
.auth_events()
|
||||
.chain(also_create_id.as_deref().into_iter())
|
||||
.stream()
|
||||
.filter_map(async |id| {
|
||||
fetch(id.to_owned())
|
||||
.inspect_err(|e| debug_warn!(%id, "missing auth event: {e}"))
|
||||
.inspect_err(|e| debug_assert!(!cfg!(test), "missing auth {id:?}: {e:?}"))
|
||||
.await
|
||||
.ok()
|
||||
})
|
||||
.map(Result::<Pdu, Error>::Ok)
|
||||
.ready_try_filter_map(|auth_event| {
|
||||
let state_key = auth_event
|
||||
.state_key()
|
||||
.ok_or_else(|| err!(Request(InvalidParam("Missing state_key"))))?;
|
||||
|
||||
let key_val = auth_event
|
||||
.rejected()
|
||||
.eq(&false)
|
||||
.then_some((auth_event.event_type().with_state_key(state_key), auth_event));
|
||||
|
||||
Ok(key_val)
|
||||
});
|
||||
|
||||
let auth_events = auth_events
|
||||
.chain(auth_types_events)
|
||||
.try_collect()
|
||||
.map_ok(|mut vec: SmallVec<[_; 4]>| {
|
||||
vec.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
vec.reverse();
|
||||
vec.dedup_by(|a, b| a.0.eq(&b.0));
|
||||
vec
|
||||
})
|
||||
.await?;
|
||||
|
||||
let fetch_state = async |ty: StateEventType, key: StateKey| -> Result<Pdu> {
|
||||
trace!(?ty, ?key, auth_events = auth_events.len(), "fetch state");
|
||||
auth_events
|
||||
.binary_search_by(|a| ty.cmp(&a.0.0).then(key.cmp(&a.0.1)))
|
||||
.map(|i| auth_events[i].1.clone())
|
||||
.map_err(|_| err!(Request(NotFound("Missing auth_event {ty:?},{key:?}"))))
|
||||
};
|
||||
|
||||
// Add authentic event to the partially resolved state.
|
||||
if check_state_dependent_auth_rules(rules, &event, &fetch_state)
|
||||
.await
|
||||
.inspect_err(|e| {
|
||||
debug_warn!(
|
||||
event_type = ?event.event_type(), ?state_key, %event_id,
|
||||
"event failed auth check: {e}"
|
||||
);
|
||||
})
|
||||
.is_ok()
|
||||
{
|
||||
let key = event.event_type().with_state_key(state_key);
|
||||
state.insert(key, event_id.to_owned());
|
||||
}
|
||||
|
||||
Ok(state)
|
||||
}
|
||||
192
src/service/rooms/state_res/resolve/mainline_sort.rs
Normal file
192
src/service/rooms/state_res/resolve/mainline_sort.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
use futures::{
|
||||
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, stream::try_unfold,
|
||||
};
|
||||
use ruma::{EventId, OwnedEventId, events::TimelineEventType};
|
||||
use tuwunel_core::{
|
||||
Error, Result, at, is_equal_to,
|
||||
matrix::Event,
|
||||
trace,
|
||||
utils::stream::{BroadbandExt, IterStream, TryReadyExt},
|
||||
};
|
||||
|
||||
/// Perform mainline ordering of the given events.
|
||||
///
|
||||
/// Definition in the spec:
|
||||
/// Given mainline positions calculated from P, the mainline ordering based on P
|
||||
/// of a set of events is the ordering, from smallest to largest, using the
|
||||
/// following comparison relation on events: for events x and y, x < y if
|
||||
///
|
||||
/// 1. the mainline position of x is greater than the mainline position of y
|
||||
/// (i.e. the auth chain of x is based on an earlier event in the mainline
|
||||
/// than y); or
|
||||
/// 2. the mainline positions of the events are the same, but x’s
|
||||
/// origin_server_ts is less than y’s origin_server_ts; or
|
||||
/// 3. the mainline positions of the events are the same and the events have the
|
||||
/// same origin_server_ts, but x’s event_id is less than y’s event_id.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `events` - The list of event IDs to sort.
|
||||
/// * `power_level` - The power level event in the current state.
|
||||
/// * `fetch_event` - Function to fetch an event in the room given its event ID.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns the sorted list of event IDs, or an `Err(_)` if one the event in the
|
||||
/// room has an unexpected format.
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
power_levels = power_level_event_id
|
||||
.as_deref()
|
||||
.map(EventId::as_str)
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
)]
|
||||
pub(super) async fn mainline_sort<'a, RemainingEvents, Fetch, Fut, Pdu>(
|
||||
power_level_event_id: Option<OwnedEventId>,
|
||||
events: RemainingEvents,
|
||||
fetch: &Fetch,
|
||||
) -> Result<Vec<OwnedEventId>>
|
||||
where
|
||||
RemainingEvents: Stream<Item = &'a EventId> + Send,
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
// Populate the mainline of the power level.
|
||||
let mainline: Vec<_> = try_unfold(power_level_event_id, async |power_level_event_id| {
|
||||
let Some(power_level_event_id) = power_level_event_id else {
|
||||
return Ok::<_, Error>(None);
|
||||
};
|
||||
|
||||
let power_level_event = fetch(power_level_event_id).await?;
|
||||
let this_event_id = power_level_event.event_id().to_owned();
|
||||
let next_event_id = get_power_levels_auth_event(&power_level_event, fetch)
|
||||
.map_ok(|event| {
|
||||
event
|
||||
.as_ref()
|
||||
.map(Event::event_id)
|
||||
.map(ToOwned::to_owned)
|
||||
})
|
||||
.await?;
|
||||
|
||||
trace!(?this_event_id, ?next_event_id, "mainline descent",);
|
||||
|
||||
Ok(Some((this_event_id, next_event_id)))
|
||||
})
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
let mainline = mainline.iter().rev().map(AsRef::as_ref);
|
||||
|
||||
events
|
||||
.map(ToOwned::to_owned)
|
||||
.broad_filter_map(async |event_id| {
|
||||
let event = fetch(event_id.clone()).await.ok()?;
|
||||
let origin_server_ts = event.origin_server_ts();
|
||||
let position = mainline_position(Some(event), &mainline, fetch)
|
||||
.await
|
||||
.ok()?;
|
||||
|
||||
Some((event_id, (position, origin_server_ts)))
|
||||
})
|
||||
.inspect(|(event_id, (position, origin_server_ts))| {
|
||||
trace!(position, ?origin_server_ts, ?event_id, "mainline position");
|
||||
})
|
||||
.collect()
|
||||
.map(|mut vec: Vec<_>| {
|
||||
vec.sort_by(|a, b| {
|
||||
let (a_pos, a_ots) = &a.1;
|
||||
let (b_pos, b_ots) = &b.1;
|
||||
a_pos
|
||||
.cmp(b_pos)
|
||||
.then(a_ots.cmp(b_ots))
|
||||
.then(a.cmp(b))
|
||||
});
|
||||
|
||||
vec.into_iter().map(at!(0)).collect()
|
||||
})
|
||||
.map(Ok)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Get the mainline position of the given event from the given mainline map.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `event` - The event to compute the mainline position of.
|
||||
/// * `mainline_map` - The mainline map of the m.room.power_levels event.
|
||||
/// * `fetch` - Function to fetch an event in the room given its event ID.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns the mainline position of the event, or an `Err(_)` if one of the
|
||||
/// events in the auth chain of the event was not found.
|
||||
#[tracing::instrument(
|
||||
name = "position",
|
||||
level = "trace",
|
||||
ret(level = "trace"),
|
||||
skip_all,
|
||||
fields(
|
||||
mainline = mainline.clone().count(),
|
||||
event = ?current_event.as_ref().map(Event::event_id).map(ToOwned::to_owned),
|
||||
)
|
||||
)]
|
||||
async fn mainline_position<'a, Mainline, Fetch, Fut, Pdu>(
|
||||
mut current_event: Option<Pdu>,
|
||||
mainline: &Mainline,
|
||||
fetch: &Fetch,
|
||||
) -> Result<usize>
|
||||
where
|
||||
Mainline: Iterator<Item = &'a EventId> + Clone + Send + Sync,
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
while let Some(event) = current_event {
|
||||
trace!(
|
||||
event_id = ?event.event_id(),
|
||||
"mainline position search",
|
||||
);
|
||||
|
||||
// If the current event is in the mainline map, return its position.
|
||||
if let Some(position) = mainline
|
||||
.clone()
|
||||
.position(is_equal_to!(event.event_id()))
|
||||
{
|
||||
return Ok(position);
|
||||
}
|
||||
|
||||
// Look for the power levels event in the auth events.
|
||||
current_event = get_power_levels_auth_event(&event, fetch).await?;
|
||||
}
|
||||
|
||||
// Did not find a power level event so we default to zero.
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
#[expect(clippy::redundant_closure)]
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
async fn get_power_levels_auth_event<Fetch, Fut, Pdu>(
|
||||
event: &Pdu,
|
||||
fetch: &Fetch,
|
||||
) -> Result<Option<Pdu>>
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
let power_level_event = event
|
||||
.auth_events()
|
||||
.try_stream()
|
||||
.map_ok(ToOwned::to_owned)
|
||||
.and_then(|auth_event_id| fetch(auth_event_id))
|
||||
.ready_try_skip_while(|auth_event| {
|
||||
Ok(!auth_event.is_type_and_state_key(&TimelineEventType::RoomPowerLevels, ""))
|
||||
});
|
||||
|
||||
pin_mut!(power_level_event);
|
||||
power_level_event.try_next().await
|
||||
}
|
||||
273
src/service/rooms/state_res/resolve/power_sort.rs
Normal file
273
src/service/rooms/state_res/resolve/power_sort.rs
Normal file
@@ -0,0 +1,273 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::{HashMap, HashSet},
|
||||
};
|
||||
|
||||
use futures::{StreamExt, TryFutureExt, TryStreamExt};
|
||||
use ruma::{
|
||||
EventId, OwnedEventId,
|
||||
events::{TimelineEventType, room::power_levels::UserPowerLevel},
|
||||
room_version_rules::RoomVersionRules,
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Result, err,
|
||||
matrix::Event,
|
||||
utils::stream::{BroadbandExt, IterStream, TryBroadbandExt},
|
||||
};
|
||||
|
||||
use super::super::{
|
||||
events::{
|
||||
RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event,
|
||||
power_levels::RoomPowerLevelsEventOptionExt,
|
||||
},
|
||||
topological_sort,
|
||||
topological_sort::ReferencedIds,
|
||||
};
|
||||
|
||||
/// Enlarge the given list of conflicted power events by adding the events in
|
||||
/// their auth chain that are in the full conflicted set, and sort it using
|
||||
/// reverse topological power ordering.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `conflicted_power_events` - The list of power events in the full
|
||||
/// conflicted set.
|
||||
///
|
||||
/// * `full_conflicted_set` - The full conflicted set.
|
||||
///
|
||||
/// * `rules` - The authorization rules for the current room version.
|
||||
///
|
||||
/// * `fetch` - Function to fetch an event in the room given its event ID.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns the ordered list of event IDs from earliest to latest.
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
conflicted = full_conflicted_set.len(),
|
||||
)
|
||||
)]
|
||||
pub(super) async fn power_sort<Fetch, Fut, Pdu>(
|
||||
rules: &RoomVersionRules,
|
||||
full_conflicted_set: &HashSet<OwnedEventId>,
|
||||
fetch: &Fetch,
|
||||
) -> Result<Vec<OwnedEventId>>
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
// A representation of the DAG, a map of event ID to its list of auth events
|
||||
// that are in the full conflicted set. Fill the graph.
|
||||
let graph = full_conflicted_set
|
||||
.iter()
|
||||
.stream()
|
||||
.broad_filter_map(async |id| is_power_event_id(id, fetch).await.then_some(id))
|
||||
.fold(HashMap::new(), |graph, event_id| {
|
||||
add_event_auth_chain(graph, full_conflicted_set, event_id, fetch)
|
||||
})
|
||||
.await;
|
||||
|
||||
// The map of event ID to the power level of the sender of the event.
|
||||
// Get the power level of the sender of each event in the graph.
|
||||
let event_to_power_level: HashMap<_, _> = graph
|
||||
.keys()
|
||||
.try_stream()
|
||||
.map_ok(AsRef::as_ref)
|
||||
.broad_and_then(|event_id| {
|
||||
power_level_for_sender(event_id, rules, fetch)
|
||||
.map_ok(move |sender_power| (event_id, sender_power))
|
||||
.map_err(|e| err!(Request(NotFound("Missing PL for sender: {e}"))))
|
||||
})
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
let query = async |event_id: OwnedEventId| {
|
||||
let power_level = *event_to_power_level
|
||||
.get(&event_id.borrow())
|
||||
.ok_or_else(|| err!(Request(NotFound("Missing PL event: {event_id}"))))?;
|
||||
|
||||
let event = fetch(event_id).await?;
|
||||
Ok((power_level, event.origin_server_ts()))
|
||||
};
|
||||
|
||||
topological_sort(&graph, &query).await
|
||||
}
|
||||
|
||||
/// Add the event with the given event ID and all the events in its auth chain
|
||||
/// that are in the full conflicted set to the graph.
|
||||
#[tracing::instrument(
|
||||
name = "auth_chain",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
?event_id,
|
||||
graph = graph.len(),
|
||||
)
|
||||
)]
|
||||
async fn add_event_auth_chain<Fetch, Fut, Pdu>(
|
||||
mut graph: HashMap<OwnedEventId, ReferencedIds>,
|
||||
full_conflicted_set: &HashSet<OwnedEventId>,
|
||||
event_id: &EventId,
|
||||
fetch: &Fetch,
|
||||
) -> HashMap<OwnedEventId, ReferencedIds>
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
let mut state = vec![event_id.to_owned()];
|
||||
|
||||
// Iterate through the auth chain of the event.
|
||||
while let Some(event_id) = state.pop() {
|
||||
// Iterate through the auth events of this event.
|
||||
let event = fetch(event_id.clone()).await.ok();
|
||||
|
||||
// Add the current event to the graph.
|
||||
graph.entry(event_id).or_default();
|
||||
|
||||
event
|
||||
.as_ref()
|
||||
.map(Event::auth_events)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(ToOwned::to_owned)
|
||||
.filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
|
||||
.for_each(|auth_event_id| {
|
||||
// If the auth event ID is not in the graph, check its auth events later.
|
||||
if !graph.contains_key(&auth_event_id) {
|
||||
state.push(auth_event_id.clone());
|
||||
}
|
||||
|
||||
let event_id = event
|
||||
.as_ref()
|
||||
.expect("event is Some if there are auth_events")
|
||||
.event_id();
|
||||
|
||||
// Add the auth event ID to the list of incoming edges.
|
||||
let references = graph
|
||||
.get_mut(event_id)
|
||||
.expect("event_id must be added to graph");
|
||||
|
||||
if !references.contains(&auth_event_id) {
|
||||
references.push(auth_event_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
graph
|
||||
}
|
||||
|
||||
/// Find the power level for the sender of the event of the given event ID or
|
||||
/// return a default value of zero.
|
||||
///
|
||||
/// We find the most recent `m.room.power_levels` by walking backwards in the
|
||||
/// auth chain of the event.
|
||||
///
|
||||
/// Do NOT use this anywhere but topological sort.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `event_id` - The event ID of the event to get the power level of the
|
||||
/// sender of.
|
||||
///
|
||||
/// * `rules` - The authorization rules for the current room version.
|
||||
///
|
||||
/// * `fetch` - Function to fetch an event in the room given its event ID.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns the power level of the sender of the event or an `Err(_)` if one of
|
||||
/// the auth events if malformed.
|
||||
#[tracing::instrument(
|
||||
name = "sender_power",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
?event_id,
|
||||
)
|
||||
)]
|
||||
async fn power_level_for_sender<Fetch, Fut, Pdu>(
|
||||
event_id: &EventId,
|
||||
rules: &RoomVersionRules,
|
||||
fetch: &Fetch,
|
||||
) -> Result<UserPowerLevel>
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
let mut room_create_event = None;
|
||||
let mut room_power_levels_event = None;
|
||||
let event = fetch(event_id.to_owned()).await;
|
||||
if let Ok(event) = &event
|
||||
&& rules
|
||||
.authorization
|
||||
.room_create_event_id_as_room_id
|
||||
{
|
||||
let create_id = event.room_id().as_event_id()?;
|
||||
let fetched = fetch(create_id).await?;
|
||||
room_create_event = Some(RoomCreateEvent::new(fetched));
|
||||
}
|
||||
|
||||
for auth_event_id in event
|
||||
.as_ref()
|
||||
.map(Event::auth_events)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
if let Ok(auth_event) = fetch(auth_event_id.to_owned()).await {
|
||||
if auth_event.is_type_and_state_key(&TimelineEventType::RoomPowerLevels, "") {
|
||||
room_power_levels_event = Some(RoomPowerLevelsEvent::new(auth_event));
|
||||
} else if !rules
|
||||
.authorization
|
||||
.room_create_event_id_as_room_id
|
||||
&& auth_event.is_type_and_state_key(&TimelineEventType::RoomCreate, "")
|
||||
{
|
||||
room_create_event = Some(RoomCreateEvent::new(auth_event));
|
||||
}
|
||||
|
||||
if room_power_levels_event.is_some() && room_create_event.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let auth_rules = &rules.authorization;
|
||||
let creators = room_create_event
|
||||
.as_ref()
|
||||
.and_then(|event| event.creators(auth_rules).ok());
|
||||
|
||||
if let Some((event, creators)) = event.ok().zip(creators) {
|
||||
room_power_levels_event.user_power_level(event.sender(), creators, auth_rules)
|
||||
} else {
|
||||
room_power_levels_event
|
||||
.get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, auth_rules)
|
||||
.map(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the given event ID belongs to a power event.
|
||||
///
|
||||
/// See the docs of `is_power_event()` for the definition of a power event.
|
||||
#[tracing::instrument(
|
||||
name = "is_power_event",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
?event_id,
|
||||
)
|
||||
)]
|
||||
async fn is_power_event_id<Fetch, Fut, Pdu>(event_id: &EventId, fetch: &Fetch) -> bool
|
||||
where
|
||||
Fetch: Fn(OwnedEventId) -> Fut + Sync,
|
||||
Fut: Future<Output = Result<Pdu>> + Send,
|
||||
Pdu: Event,
|
||||
{
|
||||
match fetch(event_id.to_owned()).await {
|
||||
| Ok(state) => is_power_event(&state),
|
||||
| _ => false,
|
||||
}
|
||||
}
|
||||
82
src/service/rooms/state_res/resolve/split_conflicted.rs
Normal file
82
src/service/rooms/state_res/resolve/split_conflicted.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use std::{collections::HashMap, hash::Hash, iter::IntoIterator};
|
||||
|
||||
use futures::{Stream, StreamExt};
|
||||
use tuwunel_core::validated;
|
||||
|
||||
use super::{ConflictMap, StateMap};
|
||||
|
||||
/// Split the unconflicted state map and the conflicted state set.
|
||||
///
|
||||
/// Definition in the specification:
|
||||
///
|
||||
/// If a given key _K_ is present in every _Si_ with the same value _V_ in each
|
||||
/// state map, then the pair (_K_, _V_) belongs to the unconflicted state map.
|
||||
/// Otherwise, _V_ belongs to the conflicted state set.
|
||||
///
|
||||
/// It means that, for a given (event type, state key) tuple, if all state maps
|
||||
/// have the same event ID, it lands in the unconflicted state map, otherwise
|
||||
/// the event IDs land in the conflicted state set.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `state_maps` - The incoming states to resolve. Each `StateMap` represents
|
||||
/// a possible fork in the state of a room.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// Returns an `(unconflicted_state, conflicted_states)` tuple.
|
||||
#[tracing::instrument(name = "split", level = "debug", skip_all)]
|
||||
pub(super) async fn split_conflicted_state<'a, Maps, Id>(
|
||||
state_maps: Maps,
|
||||
) -> (StateMap<Id>, ConflictMap<Id>)
|
||||
where
|
||||
Maps: Stream<Item = StateMap<Id>>,
|
||||
Id: Clone + Eq + Hash + Ord + Send + Sync + 'a,
|
||||
{
|
||||
let state_maps: Vec<_> = state_maps.collect().await;
|
||||
|
||||
let state_ids_est = state_maps.iter().flatten().count();
|
||||
|
||||
let state_set_count = state_maps
|
||||
.iter()
|
||||
.fold(0_usize, |acc, _| validated!(acc + 1));
|
||||
|
||||
let mut occurrences = HashMap::<_, HashMap<_, usize>>::with_capacity(state_ids_est);
|
||||
|
||||
for (k, v) in state_maps
|
||||
.into_iter()
|
||||
.flat_map(IntoIterator::into_iter)
|
||||
{
|
||||
let acc = occurrences
|
||||
.entry(k.clone())
|
||||
.or_default()
|
||||
.entry(v.clone())
|
||||
.or_default();
|
||||
|
||||
*acc = acc.saturating_add(1);
|
||||
}
|
||||
|
||||
let mut unconflicted_state_map = StateMap::new();
|
||||
let mut conflicted_state_set = ConflictMap::new();
|
||||
|
||||
for (k, v) in occurrences {
|
||||
for (id, occurrence_count) in v {
|
||||
if occurrence_count == state_set_count {
|
||||
unconflicted_state_map.insert((k.0.clone(), k.1.clone()), id.clone());
|
||||
} else {
|
||||
let conflicts = conflicted_state_set
|
||||
.entry((k.0.clone(), k.1.clone()))
|
||||
.or_default();
|
||||
|
||||
debug_assert!(
|
||||
!conflicts.contains(&id),
|
||||
"Unexpected duplicate conflicted state event"
|
||||
);
|
||||
|
||||
conflicts.push(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(unconflicted_state_map, conflicted_state_set)
|
||||
}
|
||||
843
src/service/rooms/state_res/resolve/tests.rs
Normal file
843
src/service/rooms/state_res/resolve/tests.rs
Normal file
@@ -0,0 +1,843 @@
|
||||
use std::{collections::HashMap, iter::once};
|
||||
|
||||
use futures::StreamExt;
|
||||
use maplit::hashmap;
|
||||
use rand::seq::SliceRandom;
|
||||
use ruma::{
|
||||
MilliSecondsSinceUnixEpoch, OwnedEventId,
|
||||
events::{
|
||||
StateEventType, TimelineEventType,
|
||||
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||
},
|
||||
int,
|
||||
room_version_rules::RoomVersionRules,
|
||||
uint,
|
||||
};
|
||||
use serde_json::{json, value::to_raw_value as to_raw_json_value};
|
||||
use tuwunel_core::{
|
||||
debug,
|
||||
matrix::{Event, EventTypeExt, PduEvent},
|
||||
utils::stream::IterStream,
|
||||
};
|
||||
|
||||
use super::{
|
||||
StateMap,
|
||||
test_utils::{
|
||||
INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id,
|
||||
member_content_ban, member_content_join, not_found, room_id, to_init_pdu_event,
|
||||
to_pdu_event, zara,
|
||||
},
|
||||
};
|
||||
|
||||
async fn test_event_sort() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let rules = RoomVersionRules::V6;
|
||||
let events = INITIAL_EVENTS();
|
||||
|
||||
let auth_chain = Default::default();
|
||||
|
||||
let sorted_power_events = super::power_sort(&rules, &auth_chain, &async |id| {
|
||||
events.get(&id).cloned().ok_or_else(not_found)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let sorted_power_events = sorted_power_events
|
||||
.iter()
|
||||
.stream()
|
||||
.map(AsRef::as_ref);
|
||||
|
||||
let resolved_power =
|
||||
super::iterative_auth_check(&rules, sorted_power_events, StateMap::new(), &async |id| {
|
||||
events.get(&id).cloned().ok_or_else(not_found)
|
||||
})
|
||||
.await
|
||||
.expect("iterative auth check failed on resolved events");
|
||||
|
||||
// don't remove any events so we know it sorts them all correctly
|
||||
let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
|
||||
|
||||
events_to_sort.shuffle(&mut rand::thread_rng());
|
||||
|
||||
let power_level = resolved_power
|
||||
.get(&(StateEventType::RoomPowerLevels, "".into()))
|
||||
.cloned();
|
||||
|
||||
let events_to_sort = events_to_sort.iter().stream().map(AsRef::as_ref);
|
||||
|
||||
let sorted_event_ids = super::mainline_sort(power_level, events_to_sort, &async |id| {
|
||||
events.get(&id).cloned().ok_or_else(not_found)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
vec![
|
||||
"$CREATE:foo",
|
||||
"$IMA:foo",
|
||||
"$IPOWER:foo",
|
||||
"$IJR:foo",
|
||||
"$IMB:foo",
|
||||
"$IMC:foo",
|
||||
"$START:foo",
|
||||
"$END:foo"
|
||||
],
|
||||
sorted_event_ids
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sort() {
|
||||
for _ in 0..20 {
|
||||
// since we shuffle the eventIds before we sort them introducing randomness
|
||||
// seems like we should test this a few times
|
||||
test_event_sort().await;
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ban_vs_power_level() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let events = &[
|
||||
to_init_pdu_event(
|
||||
"PA",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"MA",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(alice().to_string().as_str()),
|
||||
member_content_join(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"MB",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(bob().to_string().as_str()),
|
||||
member_content_ban(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PB",
|
||||
bob(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
];
|
||||
|
||||
let edges = vec![vec!["END", "MB", "MA", "PA", "START"], vec!["END", "PA", "PB"]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec!["PA", "MA", "MB"]
|
||||
.into_iter()
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn topic_basic() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let events = vec![
|
||||
to_init_pdu_event(
|
||||
"T1",
|
||||
alice(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PA1",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"T2",
|
||||
alice(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PA2",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PB",
|
||||
bob(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"T3",
|
||||
bob(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
];
|
||||
|
||||
let edges =
|
||||
vec![vec!["END", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "T3", "PB", "PA1"]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec!["PA2", "T2"]
|
||||
.into_iter()
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
do_check(&events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn topic_reset() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let events = &[
|
||||
to_init_pdu_event(
|
||||
"T1",
|
||||
alice(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PA",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"T2",
|
||||
bob(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"MB",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(bob().to_string().as_str()),
|
||||
member_content_ban(),
|
||||
),
|
||||
];
|
||||
|
||||
let edges = vec![vec!["END", "MB", "T2", "PA", "T1", "START"], vec!["END", "T1"]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec!["T1", "MB", "PA"]
|
||||
.into_iter()
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn join_rule_evasion() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let events = &[
|
||||
to_init_pdu_event(
|
||||
"JR",
|
||||
alice(),
|
||||
TimelineEventType::RoomJoinRules,
|
||||
Some(""),
|
||||
to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Private)).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"ME",
|
||||
ella(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(ella().to_string().as_str()),
|
||||
member_content_join(),
|
||||
),
|
||||
];
|
||||
|
||||
let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec![event_id("JR")];
|
||||
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn offtopic_power_level() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let events = &[
|
||||
to_init_pdu_event(
|
||||
"PA",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PB",
|
||||
bob(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 50 } }))
|
||||
.unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PC",
|
||||
charlie(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 0 } }))
|
||||
.unwrap(),
|
||||
),
|
||||
];
|
||||
|
||||
let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec!["PC"]
|
||||
.into_iter()
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
do_check(events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn topic_setting() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let events = vec![
|
||||
to_init_pdu_event(
|
||||
"T1",
|
||||
alice(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PA1",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"T2",
|
||||
alice(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PA2",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"PB",
|
||||
bob(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"T3",
|
||||
bob(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"MZ1",
|
||||
zara(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
to_init_pdu_event(
|
||||
"T4",
|
||||
alice(),
|
||||
TimelineEventType::RoomTopic,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({})).unwrap(),
|
||||
),
|
||||
];
|
||||
|
||||
let edges = vec![vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], vec![
|
||||
"END", "MZ1", "T3", "PB", "PA1",
|
||||
]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec!["T4", "PA2"]
|
||||
.into_iter()
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
do_check(&events, edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_event_map_none() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let mut store = TestStore(hashmap! {});
|
||||
|
||||
// build up the DAG
|
||||
let (state_at_bob, state_at_charlie, expected) = store.set_up();
|
||||
|
||||
let ev_map = store.0.clone();
|
||||
let state_sets = [state_at_bob, state_at_charlie];
|
||||
let auth_chains = state_sets
|
||||
.iter()
|
||||
.map(|map| {
|
||||
store
|
||||
.auth_event_ids(room_id(), map.values().cloned().collect())
|
||||
.unwrap()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let rules = RoomVersionRules::V1;
|
||||
let resolved = match super::resolve(
|
||||
&rules,
|
||||
state_sets.into_iter().stream(),
|
||||
auth_chains.into_iter().stream(),
|
||||
&async |id| ev_map.get(&id).cloned().ok_or_else(not_found),
|
||||
&async |id| ev_map.contains_key(&id),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
{
|
||||
| Ok(state) => state,
|
||||
| Err(e) => panic!("{e}"),
|
||||
};
|
||||
|
||||
assert_eq!(expected, resolved);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[expect(
|
||||
clippy::iter_on_single_items,
|
||||
clippy::iter_on_empty_collections
|
||||
)]
|
||||
async fn test_reverse_topological_power_sort() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
|
||||
let graph = hashmap! {
|
||||
event_id("l") => [event_id("o")].into_iter().collect(),
|
||||
event_id("m") => [event_id("n"), event_id("o")].into_iter().collect(),
|
||||
event_id("n") => [event_id("o")].into_iter().collect(),
|
||||
event_id("o") => [].into_iter().collect(), // "o" has zero outgoing edges but 4 incoming edges
|
||||
event_id("p") => [event_id("o")].into_iter().collect(),
|
||||
};
|
||||
|
||||
let res = super::super::topological_sort(&graph, &async |_id| {
|
||||
Ok((int!(0).into(), MilliSecondsSinceUnixEpoch(uint!(0))))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
vec!["o", "l", "n", "m", "p"],
|
||||
res.iter()
|
||||
.map(ToString::to_string)
|
||||
.map(|s| s.replace('$', "").replace(":foo", ""))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ban_with_auth_chains() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
let ban = BAN_STATE_SET();
|
||||
|
||||
let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec!["PA", "MB"]
|
||||
.into_iter()
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
do_check(&ban.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ban_with_auth_chains2() {
|
||||
_ = tracing::subscriber::set_default(
|
||||
tracing_subscriber::fmt()
|
||||
.with_test_writer()
|
||||
.finish(),
|
||||
);
|
||||
let init = INITIAL_EVENTS();
|
||||
let ban = BAN_STATE_SET();
|
||||
|
||||
let mut inner = init.clone();
|
||||
inner.extend(ban);
|
||||
let store = TestStore(inner.clone());
|
||||
|
||||
let state_set_a = [
|
||||
&inner[&event_id("CREATE")],
|
||||
&inner[&event_id("IJR")],
|
||||
&inner[&event_id("IMA")],
|
||||
&inner[&event_id("IMB")],
|
||||
&inner[&event_id("IMC")],
|
||||
&inner[&event_id("MB")],
|
||||
&inner[&event_id("PA")],
|
||||
]
|
||||
.iter()
|
||||
.map(|ev| {
|
||||
(
|
||||
ev.event_type()
|
||||
.with_state_key(ev.state_key().unwrap()),
|
||||
ev.event_id.clone(),
|
||||
)
|
||||
})
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let state_set_b = [
|
||||
&inner[&event_id("CREATE")],
|
||||
&inner[&event_id("IJR")],
|
||||
&inner[&event_id("IMA")],
|
||||
&inner[&event_id("IMB")],
|
||||
&inner[&event_id("IMC")],
|
||||
&inner[&event_id("IME")],
|
||||
&inner[&event_id("PA")],
|
||||
]
|
||||
.iter()
|
||||
.map(|ev| {
|
||||
(
|
||||
ev.event_type()
|
||||
.with_state_key(ev.state_key().unwrap()),
|
||||
ev.event_id.clone(),
|
||||
)
|
||||
})
|
||||
.collect::<StateMap<_>>();
|
||||
|
||||
let ev_map = &store.0;
|
||||
let state_sets = [state_set_a, state_set_b];
|
||||
let auth_chains = state_sets
|
||||
.iter()
|
||||
.map(|map| {
|
||||
store
|
||||
.auth_event_ids(room_id(), map.values().cloned().collect())
|
||||
.unwrap()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let resolved = match super::resolve(
|
||||
&RoomVersionRules::V6,
|
||||
state_sets.into_iter().stream(),
|
||||
auth_chains.into_iter().stream(),
|
||||
&async |id| ev_map.get(&id).cloned().ok_or_else(not_found),
|
||||
&async |id| ev_map.contains_key(&id),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
{
|
||||
| Ok(state) => state,
|
||||
| Err(e) => panic!("{e}"),
|
||||
};
|
||||
|
||||
debug!(
|
||||
resolved = ?resolved
|
||||
.iter()
|
||||
.map(|((ty, key), id)| format!("(({ty}{key:?}), {id})"))
|
||||
.collect::<Vec<_>>(),
|
||||
"resolved state",
|
||||
);
|
||||
|
||||
let expected = [
|
||||
"$CREATE:foo",
|
||||
"$IJR:foo",
|
||||
"$PA:foo",
|
||||
"$IMA:foo",
|
||||
"$IMB:foo",
|
||||
"$IMC:foo",
|
||||
"$MB:foo",
|
||||
];
|
||||
|
||||
for id in expected.iter().map(|i| event_id(i)) {
|
||||
// make sure our resolved events are equal to the expected list
|
||||
assert!(resolved.values().any(|eid| eid == &id) || init.contains_key(&id), "{id}");
|
||||
}
|
||||
assert_eq!(expected.len(), resolved.len());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn join_rule_with_auth_chain() {
|
||||
let join_rule = JOIN_RULE();
|
||||
|
||||
let edges = vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]]
|
||||
.into_iter()
|
||||
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected_state_ids = vec!["JR"]
|
||||
.into_iter()
|
||||
.map(event_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
do_check(&join_rule.values().cloned().collect::<Vec<_>>(), edges, expected_state_ids).await;
|
||||
}
|
||||
|
||||
#[expect(non_snake_case)]
|
||||
fn BAN_STATE_SET() -> HashMap<OwnedEventId, PduEvent> {
|
||||
vec![
|
||||
to_pdu_event(
|
||||
"PA",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"], // auth_events
|
||||
&["START"], // prev_events
|
||||
),
|
||||
to_pdu_event(
|
||||
"PB",
|
||||
alice(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"],
|
||||
&["END"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"MB",
|
||||
alice(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(ella().as_str()),
|
||||
member_content_ban(),
|
||||
&["CREATE", "IMA", "PB"],
|
||||
&["PA"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IME",
|
||||
ella(),
|
||||
TimelineEventType::RoomMember,
|
||||
Some(ella().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE", "IJR", "PA"],
|
||||
&["MB"],
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|ev| (ev.event_id.clone(), ev))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[expect(non_snake_case)]
|
||||
fn JOIN_RULE() -> HashMap<OwnedEventId, PduEvent> {
|
||||
vec![
|
||||
to_pdu_event(
|
||||
"JR",
|
||||
alice(),
|
||||
TimelineEventType::RoomJoinRules,
|
||||
Some(""),
|
||||
to_raw_json_value(&json!({ "join_rule": "invite" })).unwrap(),
|
||||
&["CREATE", "IMA", "IPOWER"],
|
||||
&["START"],
|
||||
),
|
||||
to_pdu_event(
|
||||
"IMZ",
|
||||
zara(),
|
||||
TimelineEventType::RoomPowerLevels,
|
||||
Some(zara().as_str()),
|
||||
member_content_join(),
|
||||
&["CREATE", "JR", "IPOWER"],
|
||||
&["START"],
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|ev| (ev.event_id.clone(), ev))
|
||||
.collect()
|
||||
}
|
||||
|
||||
macro_rules! state_set {
|
||||
($($kind:expr => $key:expr => $id:expr),* $(,)?) => {{
|
||||
let mut x = StateMap::new();
|
||||
$(
|
||||
x.insert(($kind, $key.into()), $id);
|
||||
)*
|
||||
x
|
||||
}};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn split_conflicted_state_set_conflicted_unique_state_keys() {
|
||||
let (unconflicted, conflicted) = super::split_conflicted_state(
|
||||
[
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 0],
|
||||
state_set![StateEventType::RoomMember => "@b:hs1" => 1],
|
||||
state_set![StateEventType::RoomMember => "@c:hs1" => 2],
|
||||
]
|
||||
.into_iter()
|
||||
.stream(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (unconflicted, conflicted): (StateMap<_>, StateMap<_>) =
|
||||
(unconflicted.into_iter().collect(), conflicted.into_iter().collect());
|
||||
|
||||
assert_eq!(unconflicted, StateMap::new());
|
||||
assert_eq!(conflicted, state_set![
|
||||
StateEventType::RoomMember => "@a:hs1" => once(0).collect(),
|
||||
StateEventType::RoomMember => "@b:hs1" => once(1).collect(),
|
||||
StateEventType::RoomMember => "@c:hs1" => once(2).collect(),
|
||||
],);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn split_conflicted_state_set_conflicted_same_state_key() {
|
||||
let (unconflicted, conflicted) = super::split_conflicted_state(
|
||||
[
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 0],
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 1],
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 2],
|
||||
]
|
||||
.into_iter()
|
||||
.stream(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (unconflicted, mut conflicted): (StateMap<_>, StateMap<_>) =
|
||||
(unconflicted.into_iter().collect(), conflicted.into_iter().collect());
|
||||
|
||||
// HashMap iteration order is random, so sort this before asserting on it
|
||||
for v in conflicted.values_mut() {
|
||||
v.sort_unstable();
|
||||
}
|
||||
|
||||
assert_eq!(unconflicted, StateMap::new());
|
||||
assert_eq!(conflicted, state_set![
|
||||
StateEventType::RoomMember => "@a:hs1" => [0, 1, 2].into_iter().collect(),
|
||||
],);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn split_conflicted_state_set_unconflicted() {
|
||||
let (unconflicted, conflicted) = super::split_conflicted_state(
|
||||
[
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 0],
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 0],
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 0],
|
||||
]
|
||||
.into_iter()
|
||||
.stream(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (unconflicted, conflicted): (StateMap<_>, StateMap<_>) =
|
||||
(unconflicted.into_iter().collect(), conflicted.into_iter().collect());
|
||||
|
||||
assert_eq!(unconflicted, state_set![
|
||||
StateEventType::RoomMember => "@a:hs1" => 0,
|
||||
],);
|
||||
assert_eq!(conflicted, StateMap::new());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn split_conflicted_state_set_mixed() {
|
||||
let (unconflicted, conflicted) = super::split_conflicted_state(
|
||||
[
|
||||
state_set![StateEventType::RoomMember => "@a:hs1" => 0],
|
||||
state_set![
|
||||
StateEventType::RoomMember => "@a:hs1" => 0,
|
||||
StateEventType::RoomMember => "@b:hs1" => 1,
|
||||
],
|
||||
state_set![
|
||||
StateEventType::RoomMember => "@a:hs1" => 0,
|
||||
StateEventType::RoomMember => "@c:hs1" => 2,
|
||||
],
|
||||
]
|
||||
.into_iter()
|
||||
.stream(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (unconflicted, conflicted): (StateMap<_>, StateMap<_>) =
|
||||
(unconflicted.into_iter().collect(), conflicted.into_iter().collect());
|
||||
|
||||
assert_eq!(unconflicted, state_set![
|
||||
StateEventType::RoomMember => "@a:hs1" => 0,
|
||||
],);
|
||||
assert_eq!(conflicted, state_set![
|
||||
StateEventType::RoomMember => "@b:hs1" => once(1).collect(),
|
||||
StateEventType::RoomMember => "@c:hs1" => once(2).collect(),
|
||||
],);
|
||||
}
|
||||
Reference in New Issue
Block a user