Move state_res from tuwunel_core to tuwunel_service.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-02-16 05:43:03 +00:00
parent 6a550baf5f
commit 9ede830ffe
73 changed files with 134 additions and 131 deletions

View File

@@ -0,0 +1,41 @@
use std::borrow::Borrow;
use futures::{FutureExt, Stream};
use ruma::EventId;
use tuwunel_core::utils::stream::{IterStream, ReadyExt};
use super::AuthSet;
/// Get the auth difference for the given auth chains.
///
/// Definition in the specification:
///
/// The auth difference is calculated by first calculating the full auth chain
/// for each state _Si_, that is the union of the auth chains for each event in
/// _Si_, and then taking every event that doesn't appear in every auth chain.
/// If _Ci_ is the full auth chain of _Si_, then the auth difference is
/// ∪ _Ci_ − ∩ _Ci_.
///
/// NOTE(review): the implementation folds pairwise symmetric differences over
/// the input sets; for more than two chains this yields the events appearing
/// in an odd number of chains, which is not identical to the spec's
/// union-minus-intersection — confirm callers pass exactly two sets or that
/// this deviation is intended.
///
/// ## Arguments
///
/// * `auth_chains` - The list of full recursive sets of `auth_events`. Inputs
/// must be sorted.
///
/// ## Returns
///
/// Outputs the event IDs that are not present in all the auth chains.
#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn auth_difference<'a, AuthSets, Id>(auth_sets: AuthSets) -> impl Stream<Item = Id>
where
AuthSets: Stream<Item = AuthSet<Id>>,
Id: Borrow<EventId> + Clone + Eq + Ord + Send + 'a,
{
auth_sets
// Accumulate the running symmetric difference across all input sets,
// starting from the default (empty) set.
.ready_fold_default(|ret: AuthSet<Id>, set| {
ret.symmetric_difference(&set)
.cloned()
.collect::<AuthSet<Id>>()
})
// Flatten the final set into a stream of individual event IDs.
.map(|set: AuthSet<Id>| set.into_iter().stream())
.flatten_stream()
}

View File

@@ -0,0 +1,146 @@
use std::{
collections::HashSet as Set,
iter::once,
mem::take,
sync::{Arc, Mutex},
};
use futures::{Future, FutureExt, Stream, StreamExt};
use ruma::OwnedEventId;
use tuwunel_core::{
Result, debug,
matrix::{Event, pdu::AuthEvents},
smallvec::SmallVec,
utils::stream::{IterStream, automatic_width},
};
/// Accumulator shared by all concurrent descents, guarded by mutexes.
#[derive(Default)]
struct Global {
// Events determined to belong to the conflicted subgraph.
subgraph: Mutex<Set<OwnedEventId>>,
// Events already visited by any descent; avoids duplicate traversal work.
seen: Mutex<Set<OwnedEventId>>,
}
/// Per-descent traversal state for the iterative depth-first search.
#[derive(Default, Debug)]
struct Local {
// Chain of event IDs from the root of this descent toward the current event.
path: Path,
// One frame per descended event, holding its not-yet-visited auth events.
stack: Stack,
}
type Path = SmallVec<[OwnedEventId; PATH_INLINE]>;
type Stack = SmallVec<[Frame; STACK_INLINE]>;
type Frame = AuthEvents;
// Inline capacities before the SmallVecs spill to the heap.
const PATH_INLINE: usize = 48;
const STACK_INLINE: usize = 48;
/// Compute the conflicted subgraph: the conflicted events plus every event on
/// an `auth_events` path connecting two conflicted events.
///
/// One depth-first descent is launched per input event; descents run
/// concurrently and merge their findings into a shared `Global` accumulator.
///
/// ## Arguments
///
/// * `conflicted_event_ids` - The conflicted events to start the walks from.
/// * `fetch` - Function to fetch an event in the room given its event ID.
///
/// ## Returns
///
/// A stream of the subgraph's event IDs, in unspecified order.
#[tracing::instrument(name = "subgraph_dfs", level = "debug", skip_all)]
pub(super) fn conflicted_subgraph_dfs<Fetch, Fut, Pdu>(
conflicted_event_ids: &Set<&OwnedEventId>,
fetch: &Fetch,
) -> impl Stream<Item = OwnedEventId> + Send
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
let state = Arc::new(Global::default());
let state_ = state.clone();
conflicted_event_ids
.iter()
.stream()
.enumerate()
.map(move |(i, event_id)| (state_.clone(), event_id, i))
// Run up to automatic_width() descents concurrently; they share the
// Global accumulator behind its mutexes.
.for_each_concurrent(automatic_width(), async |(state, event_id, i)| {
subgraph_descent(conflicted_event_ids, fetch, &state, event_id, i)
.await
// subgraph_descent only fails on poisoned-mutex errors.
.expect("only mutex errors expected");
})
// Once every descent has finished, drain the accumulated subgraph into
// the output stream.
.map(move |()| {
let seen = state.seen.lock().expect("locked");
let mut state = state.subgraph.lock().expect("locked");
debug!(
input_events = conflicted_event_ids.len(),
seen_events = seen.len(),
output_events = state.len(),
"conflicted subgraph state"
);
take(&mut *state).into_iter().stream()
})
.flatten_stream()
}
/// One iterative depth-first descent through the auth chain of a single
/// conflicted event, recording into `state.subgraph` every path that reaches
/// another conflicted event (or an event already known to be in the subgraph).
///
/// Only fails when one of the shared mutexes is poisoned.
#[tracing::instrument(
name = "descent",
level = "trace",
skip_all,
fields(
event_ids = conflicted_event_ids.len(),
event_id = %conflicted_event_id,
%i,
)
)]
async fn subgraph_descent<Fetch, Fut, Pdu>(
conflicted_event_ids: &Set<&OwnedEventId>,
fetch: &Fetch,
state: &Arc<Global>,
conflicted_event_id: &OwnedEventId,
i: usize,
) -> Result
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
let Global { subgraph, seen } = &**state;
// Seed path and stack with the starting event; pop() drives the walk.
let mut local = Local {
path: once(conflicted_event_id.clone()).collect(),
stack: once(once(conflicted_event_id.clone()).collect()).collect(),
};
while let Some(event_id) = pop(&mut local) {
// Reaching an event already in the subgraph means the whole current
// path connects to it; absorb the path and back out of this node.
if subgraph.lock()?.contains(&event_id) {
if local.path.len() > 1 {
subgraph
.lock()?
.extend(local.path.iter().cloned());
}
local.path.pop();
continue;
}
// Skip events another descent (or this one) has already visited.
// NOTE(review): unlike the branch above, the path entry pushed by pop()
// is not popped here (nor on fetch failure below); pop() only retires
// path entries when a frame drains. Confirm stale path entries cannot
// over-extend the subgraph.
if !seen.lock()?.insert(event_id.clone()) {
continue;
}
// Reaching another conflicted event: the connecting path joins the
// subgraph.
if local.path.len() > 1 && conflicted_event_ids.contains(&event_id) {
subgraph
.lock()?
.extend(local.path.iter().cloned());
}
// Descend into this event's auth events; unfetchable events terminate
// the branch silently.
if let Ok(event) = fetch(event_id).await {
local
.stack
.push(event.auth_events_into().into_iter().collect());
}
}
Ok(())
}
/// Advance the depth-first traversal: unwind any fully-drained stack frames
/// (retiring their path entries), then take the next event ID from the top
/// frame, mirroring it onto the path.
fn pop(local: &mut Local) -> Option<OwnedEventId> {
	let Local { path, stack } = local;

	// Discard exhausted frames; each frame removal retires one path element.
	loop {
		match stack.last() {
			| Some(frame) if frame.is_empty() => {
				stack.pop();
				path.pop();
			},
			| _ => break,
		}
	}

	// Take the next event from the top frame and record it on the path.
	let event_id = stack.last_mut()?.pop()?;
	path.push(event_id.clone());
	Some(event_id)
}

View File

@@ -0,0 +1,209 @@
use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{
EventId, OwnedEventId,
events::{StateEventType, TimelineEventType},
room_version_rules::RoomVersionRules,
};
use tuwunel_core::{
Error, Result, debug_warn, err, error,
matrix::{Event, EventTypeExt, StateKey},
smallvec::SmallVec,
trace,
utils::stream::{IterStream, ReadyExt, TryReadyExt, TryWidebandExt},
};
use super::{
super::{auth_types_for_event, check_state_dependent_auth_rules},
StateMap,
};
/// Perform the iterative auth checks to the given list of events.
///
/// Definition in the specification:
///
/// The iterative auth checks algorithm takes as input an initial room state and
/// a sorted list of state events, and constructs a new room state by iterating
/// through the event list and applying the state event to the room state if the
/// state event is allowed by the authorization rules. If the state event is not
/// allowed by the authorization rules, then the event is ignored. If an
/// (event_type, state_key) key that is required for checking the authorization
/// rules is not present in the state, then the appropriate state event from the
/// event's `auth_events` is used if the auth event is not rejected.
///
/// ## Arguments
///
/// * `rules` - The authorization rules for the current room version.
/// * `events` - The sorted state events to apply to the `partial_state`.
/// * `state` - The current state that was partially resolved for the room.
/// * `fetch_event` - Function to fetch an event in the room given its event ID.
///
/// ## Returns
///
/// Returns the partially resolved state, or an `Err(_)` if one of the state
/// events in the room has an unexpected format.
#[tracing::instrument(
name = "iterative_auth",
level = "debug",
skip_all,
fields(
states = ?state.len(),
)
)]
pub(super) async fn iterative_auth_check<'b, SortedPowerEvents, Fetch, Fut, Pdu>(
rules: &RoomVersionRules,
events: SortedPowerEvents,
state: StateMap<OwnedEventId>,
fetch: &Fetch,
) -> Result<StateMap<OwnedEventId>>
where
SortedPowerEvents: Stream<Item = &'b EventId> + Send,
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
events
.map(Ok)
// Fetch each event and its state key concurrently (wideband), keeping
// stream order for the sequential fold below.
.wide_and_then(async |event_id| {
let event = fetch(event_id.to_owned()).await?;
let state_key: StateKey = event
.state_key()
.ok_or_else(|| err!(Request(InvalidParam("Missing state_key"))))?
.into();
Ok((event_id, state_key, event))
})
// Apply events one at a time: each auth_check sees the state produced
// by all prior events, as the iterative algorithm requires.
.try_fold(state, |state, (event_id, state_key, event)| {
auth_check(rules, state, event_id, state_key, event, fetch)
})
.await
}
/// Check a single state event against the state-dependent authorization rules
/// and, when it passes, insert it into the partially resolved state.
///
/// Builds the auth-event lookup table from the event's own `auth_events`
/// (plus, for room versions that use the create event ID as the room ID, the
/// implied `m.room.create` event) combined with the auth-type events already
/// present in `state`. Rejected auth events are excluded. Events that fail the
/// check are logged and skipped, not errors.
#[tracing::instrument(
name = "check",
level = "trace",
skip_all,
fields(
%event_id,
?state_key,
)
)]
async fn auth_check<Fetch, Fut, Pdu>(
rules: &RoomVersionRules,
mut state: StateMap<OwnedEventId>,
event_id: &EventId,
state_key: StateKey,
event: Pdu,
fetch: &Fetch,
) -> Result<StateMap<OwnedEventId>>
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
// Determine which (type, state_key) pairs are needed to authorize this
// event; on failure the event is skipped and the state passes unchanged.
let Ok(auth_types) = auth_types_for_event(
event.event_type(),
event.sender(),
Some(&state_key),
event.content(),
&rules.authorization,
true,
)
.inspect_err(|e| error!("failed to get auth types for event: {e}")) else {
return Ok(state);
};
// Resolve the needed auth types from the partially resolved state,
// dropping unfetchable or rejected events.
let auth_types_events = auth_types
.stream()
.ready_filter_map(|key| {
state
.get(&key)
.map(move |auth_event_id| (auth_event_id, key))
})
.filter_map(async |(id, key)| {
fetch(id.clone())
.inspect_err(|e| debug_warn!(%id, "missing auth event: {e}"))
.inspect_err(|e| debug_assert!(!cfg!(test), "missing auth {id:?}: {e:?}"))
.map_ok(move |auth_event| (key, auth_event))
.await
.ok()
})
.ready_filter_map(|(key, auth_event)| {
auth_event
.rejected()
.eq(&false)
.then_some((key, auth_event))
})
.map(Ok);
// If the `m.room.create` event is not in the auth events, we need to add it,
// because it's always part of the state and required in the auth rules.
let also_need_create_event = *event.event_type() != TimelineEventType::RoomCreate
&& rules
.authorization
.room_create_event_id_as_room_id;
let also_create_id: Option<OwnedEventId> = also_need_create_event
.then(|| event.room_id().as_event_id().ok())
.flatten();
// Resolve the event's own auth_events (plus the implied create event),
// keyed by (type, state_key); rejected events are filtered out.
let auth_events = event
.auth_events()
.chain(also_create_id.as_deref().into_iter())
.stream()
.filter_map(async |id| {
fetch(id.to_owned())
.inspect_err(|e| debug_warn!(%id, "missing auth event: {e}"))
.inspect_err(|e| debug_assert!(!cfg!(test), "missing auth {id:?}: {e:?}"))
.await
.ok()
})
.map(Result::<Pdu, Error>::Ok)
.ready_try_filter_map(|auth_event| {
let state_key = auth_event
.state_key()
.ok_or_else(|| err!(Request(InvalidParam("Missing state_key"))))?;
let key_val = auth_event
.rejected()
.eq(&false)
.then_some((auth_event.event_type().with_state_key(state_key), auth_event));
Ok(key_val)
});
// Merge both sources into one table sorted descending by key. The reverse
// before dedup makes entries from the first source (the event's own
// auth_events) win over state-derived duplicates.
let auth_events = auth_events
.chain(auth_types_events)
.try_collect()
.map_ok(|mut vec: SmallVec<[_; 4]>| {
vec.sort_by(|a, b| a.0.cmp(&b.0));
vec.reverse();
vec.dedup_by(|a, b| a.0.eq(&b.0));
vec
})
.await?;
// Lookup closure over the table; the comparator is written target-first,
// which matches the descending sort order above.
let fetch_state = async |ty: StateEventType, key: StateKey| -> Result<Pdu> {
trace!(?ty, ?key, auth_events = auth_events.len(), "fetch state");
auth_events
.binary_search_by(|a| ty.cmp(&a.0.0).then(key.cmp(&a.0.1)))
.map(|i| auth_events[i].1.clone())
.map_err(|_| err!(Request(NotFound("Missing auth_event {ty:?},{key:?}"))))
};
// Add authentic event to the partially resolved state.
if check_state_dependent_auth_rules(rules, &event, &fetch_state)
.await
.inspect_err(|e| {
debug_warn!(
event_type = ?event.event_type(), ?state_key, %event_id,
"event failed auth check: {e}"
);
})
.is_ok()
{
let key = event.event_type().with_state_key(state_key);
state.insert(key, event_id.to_owned());
}
Ok(state)
}

View File

@@ -0,0 +1,192 @@
use futures::{
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, pin_mut, stream::try_unfold,
};
use ruma::{EventId, OwnedEventId, events::TimelineEventType};
use tuwunel_core::{
Error, Result, at, is_equal_to,
matrix::Event,
trace,
utils::stream::{BroadbandExt, IterStream, TryReadyExt},
};
/// Perform mainline ordering of the given events.
///
/// Definition in the spec:
/// Given mainline positions calculated from P, the mainline ordering based on P
/// of a set of events is the ordering, from smallest to largest, using the
/// following comparison relation on events: for events x and y, x < y if
///
/// 1. the mainline position of x is greater than the mainline position of y
/// (i.e. the auth chain of x is based on an earlier event in the mainline
/// than y); or
/// 2. the mainline positions of the events are the same, but x's
/// origin_server_ts is less than y's origin_server_ts; or
/// 3. the mainline positions of the events are the same and the events have the
/// same origin_server_ts, but x's event_id is less than y's event_id.
///
/// ## Arguments
///
/// * `events` - The list of event IDs to sort.
/// * `power_level` - The power level event in the current state.
/// * `fetch_event` - Function to fetch an event in the room given its event ID.
///
/// ## Returns
///
/// Returns the sorted list of event IDs, or an `Err(_)` if one of the events
/// in the room has an unexpected format.
#[tracing::instrument(
level = "debug",
skip_all,
fields(
power_levels = power_level_event_id
.as_deref()
.map(EventId::as_str)
.unwrap_or_default(),
)
)]
pub(super) async fn mainline_sort<'a, RemainingEvents, Fetch, Fut, Pdu>(
power_level_event_id: Option<OwnedEventId>,
events: RemainingEvents,
fetch: &Fetch,
) -> Result<Vec<OwnedEventId>>
where
RemainingEvents: Stream<Item = &'a EventId> + Send,
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
// Populate the mainline of the power level.
let mainline: Vec<_> = try_unfold(power_level_event_id, async |power_level_event_id| {
let Some(power_level_event_id) = power_level_event_id else {
return Ok::<_, Error>(None);
};
let power_level_event = fetch(power_level_event_id).await?;
let this_event_id = power_level_event.event_id().to_owned();
let next_event_id = get_power_levels_auth_event(&power_level_event, fetch)
.map_ok(|event| {
event
.as_ref()
.map(Event::event_id)
.map(ToOwned::to_owned)
})
.await?;
trace!(?this_event_id, ?next_event_id, "mainline descent",);
Ok(Some((this_event_id, next_event_id)))
})
.try_collect()
.await?;
// The unfold walks newest-to-oldest; reverse so index 0 is the oldest
// mainline event and larger positions mean more recent ancestors.
let mainline = mainline.iter().rev().map(AsRef::as_ref);
events
.map(ToOwned::to_owned)
// Events that cannot be fetched or positioned are dropped from the
// ordering rather than failing the sort.
.broad_filter_map(async |event_id| {
let event = fetch(event_id.clone()).await.ok()?;
let origin_server_ts = event.origin_server_ts();
let position = mainline_position(Some(event), &mainline, fetch)
.await
.ok()?;
Some((event_id, (position, origin_server_ts)))
})
.inspect(|(event_id, (position, origin_server_ts))| {
trace!(position, ?origin_server_ts, ?event_id, "mainline position");
})
.collect()
.map(|mut vec: Vec<_>| {
// Sort by position, then timestamp, then the tuple itself (which at
// that point only differs by event_id).
vec.sort_by(|a, b| {
let (a_pos, a_ots) = &a.1;
let (b_pos, b_ots) = &b.1;
a_pos
.cmp(b_pos)
.then(a_ots.cmp(b_ots))
.then(a.cmp(b))
});
vec.into_iter().map(at!(0)).collect()
})
.map(Ok)
.await
}
/// Get the mainline position of the given event from the given mainline map.
///
/// Walks upward through `m.room.power_levels` auth events starting at
/// `current_event` until an event on the mainline is reached, returning that
/// event's index in the mainline iterator.
///
/// ## Arguments
///
/// * `current_event` - The event to compute the mainline position of.
/// * `mainline` - Iterator over the mainline of the power levels event.
/// * `fetch` - Function to fetch an event in the room given its event ID.
///
/// ## Returns
///
/// Returns the mainline position of the event (zero when the walk ends
/// without touching the mainline), or an `Err(_)` if one of the events in the
/// auth chain of the event was not found.
#[tracing::instrument(
name = "position",
level = "trace",
ret(level = "trace"),
skip_all,
fields(
mainline = mainline.clone().count(),
event = ?current_event.as_ref().map(Event::event_id).map(ToOwned::to_owned),
)
)]
async fn mainline_position<'a, Mainline, Fetch, Fut, Pdu>(
	mut current_event: Option<Pdu>,
	mainline: &Mainline,
	fetch: &Fetch,
) -> Result<usize>
where
	Mainline: Iterator<Item = &'a EventId> + Clone + Send + Sync,
	Fetch: Fn(OwnedEventId) -> Fut + Sync,
	Fut: Future<Output = Result<Pdu>> + Send,
	Pdu: Event,
{
	loop {
		// Auth chain exhausted without touching the mainline: default to zero.
		let Some(event) = current_event else {
			return Ok(0);
		};

		trace!(
			event_id = ?event.event_id(),
			"mainline position search",
		);

		// The event sits on the mainline; its index is the position.
		let found = mainline
			.clone()
			.position(is_equal_to!(event.event_id()));
		if let Some(position) = found {
			return Ok(position);
		}

		// Otherwise continue upward via the power levels auth event.
		current_event = get_power_levels_auth_event(&event, fetch).await?;
	}
}
/// Find the `m.room.power_levels` state event among the auth events of the
/// given event, if any.
///
/// Auth events are fetched lazily in order; the first one whose type and
/// empty state key mark it as the power levels event is returned. Yields
/// `Ok(None)` when no such auth event exists; fetch errors propagate.
// NOTE(review): the `redundant_closure` expectation presumably exists because
// passing `fetch` directly to `and_then` trips closure/lifetime inference —
// confirm before removing.
#[expect(clippy::redundant_closure)]
#[tracing::instrument(level = "trace", skip_all)]
async fn get_power_levels_auth_event<Fetch, Fut, Pdu>(
event: &Pdu,
fetch: &Fetch,
) -> Result<Option<Pdu>>
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
let power_level_event = event
.auth_events()
.try_stream()
.map_ok(ToOwned::to_owned)
.and_then(|auth_event_id| fetch(auth_event_id))
// Skip everything until the power levels state event, so try_next()
// below yields it (or None when the stream ends first).
.ready_try_skip_while(|auth_event| {
Ok(!auth_event.is_type_and_state_key(&TimelineEventType::RoomPowerLevels, ""))
});
pin_mut!(power_level_event);
power_level_event.try_next().await
}

View File

@@ -0,0 +1,273 @@
use std::{
borrow::Borrow,
collections::{HashMap, HashSet},
};
use futures::{StreamExt, TryFutureExt, TryStreamExt};
use ruma::{
EventId, OwnedEventId,
events::{TimelineEventType, room::power_levels::UserPowerLevel},
room_version_rules::RoomVersionRules,
};
use tuwunel_core::{
Result, err,
matrix::Event,
utils::stream::{BroadbandExt, IterStream, TryBroadbandExt},
};
use super::super::{
events::{
RoomCreateEvent, RoomPowerLevelsEvent, RoomPowerLevelsIntField, is_power_event,
power_levels::RoomPowerLevelsEventOptionExt,
},
topological_sort,
topological_sort::ReferencedIds,
};
/// Enlarge the set of conflicted power events by adding the events in
/// their auth chain that are in the full conflicted set, and sort it using
/// reverse topological power ordering.
///
/// ## Arguments
///
/// * `rules` - The authorization rules for the current room version.
///
/// * `full_conflicted_set` - The full conflicted set; the power events within
/// it seed the graph.
///
/// * `fetch` - Function to fetch an event in the room given its event ID.
///
/// ## Returns
///
/// Returns the ordered list of event IDs from earliest to latest.
#[tracing::instrument(
level = "debug",
skip_all,
fields(
conflicted = full_conflicted_set.len(),
)
)]
pub(super) async fn power_sort<Fetch, Fut, Pdu>(
rules: &RoomVersionRules,
full_conflicted_set: &HashSet<OwnedEventId>,
fetch: &Fetch,
) -> Result<Vec<OwnedEventId>>
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
// A representation of the DAG, a map of event ID to its list of auth events
// that are in the full conflicted set. Fill the graph.
let graph = full_conflicted_set
.iter()
.stream()
.broad_filter_map(async |id| is_power_event_id(id, fetch).await.then_some(id))
.fold(HashMap::new(), |graph, event_id| {
add_event_auth_chain(graph, full_conflicted_set, event_id, fetch)
})
.await;
// The map of event ID to the power level of the sender of the event.
// Get the power level of the sender of each event in the graph.
let event_to_power_level: HashMap<_, _> = graph
.keys()
.try_stream()
.map_ok(AsRef::as_ref)
.broad_and_then(|event_id| {
power_level_for_sender(event_id, rules, fetch)
.map_ok(move |sender_power| (event_id, sender_power))
.map_err(|e| err!(Request(NotFound("Missing PL for sender: {e}"))))
})
.try_collect()
.await?;
// Sort key for the topological sort: sender power level, then timestamp.
let query = async |event_id: OwnedEventId| {
let power_level = *event_to_power_level
.get(&event_id.borrow())
.ok_or_else(|| err!(Request(NotFound("Missing PL event: {event_id}"))))?;
let event = fetch(event_id).await?;
Ok((power_level, event.origin_server_ts()))
};
topological_sort(&graph, &query).await
}
/// Add the event with the given event ID and all the events in its auth chain
/// that are in the full conflicted set to the graph.
///
/// Returns the graph with this event, its qualifying auth-chain members, and
/// the corresponding edges added.
#[tracing::instrument(
name = "auth_chain",
level = "trace",
skip_all,
fields(
?event_id,
graph = graph.len(),
)
)]
async fn add_event_auth_chain<Fetch, Fut, Pdu>(
mut graph: HashMap<OwnedEventId, ReferencedIds>,
full_conflicted_set: &HashSet<OwnedEventId>,
event_id: &EventId,
fetch: &Fetch,
) -> HashMap<OwnedEventId, ReferencedIds>
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
// Worklist of events whose auth events still need to be examined.
let mut state = vec![event_id.to_owned()];
// Iterate through the auth chain of the event.
while let Some(event_id) = state.pop() {
// Iterate through the auth events of this event.
let event = fetch(event_id.clone()).await.ok();
// Add the current event to the graph.
graph.entry(event_id).or_default();
event
.as_ref()
.map(Event::auth_events)
.into_iter()
.flatten()
.map(ToOwned::to_owned)
// Only auth events inside the full conflicted set become graph nodes.
.filter(|auth_event_id| full_conflicted_set.contains(auth_event_id))
.for_each(|auth_event_id| {
// If the auth event ID is not in the graph, check its auth events later.
if !graph.contains_key(&auth_event_id) {
state.push(auth_event_id.clone());
}
let event_id = event
.as_ref()
.expect("event is Some if there are auth_events")
.event_id();
// Add the auth event ID to the list of incoming edges.
let references = graph
.get_mut(event_id)
.expect("event_id must be added to graph");
if !references.contains(&auth_event_id) {
references.push(auth_event_id);
}
});
}
graph
}
/// Find the power level for the sender of the event of the given event ID or
/// return a default value of zero.
///
/// We find the most recent `m.room.power_levels` by walking backwards in the
/// auth chain of the event.
///
/// Do NOT use this anywhere but topological sort.
///
/// ## Arguments
///
/// * `event_id` - The event ID of the event to get the power level of the
/// sender of.
///
/// * `rules` - The authorization rules for the current room version.
///
/// * `fetch` - Function to fetch an event in the room given its event ID.
///
/// ## Returns
///
/// Returns the power level of the sender of the event or an `Err(_)` if one of
/// the auth events is malformed.
#[tracing::instrument(
name = "sender_power",
level = "trace",
skip_all,
fields(
?event_id,
)
)]
async fn power_level_for_sender<Fetch, Fut, Pdu>(
event_id: &EventId,
rules: &RoomVersionRules,
fetch: &Fetch,
) -> Result<UserPowerLevel>
where
Fetch: Fn(OwnedEventId) -> Fut + Sync,
Fut: Future<Output = Result<Pdu>> + Send,
Pdu: Event,
{
let mut room_create_event = None;
let mut room_power_levels_event = None;
let event = fetch(event_id.to_owned()).await;
// In room versions where the create event ID doubles as the room ID, the
// create event can be derived directly instead of scanned from auth events.
if let Ok(event) = &event
&& rules
.authorization
.room_create_event_id_as_room_id
{
let create_id = event.room_id().as_event_id()?;
let fetched = fetch(create_id).await?;
room_create_event = Some(RoomCreateEvent::new(fetched));
}
// Scan the auth events for the power levels (and, in older room versions,
// the create) event; stop once both are found.
for auth_event_id in event
.as_ref()
.map(Event::auth_events)
.into_iter()
.flatten()
{
if let Ok(auth_event) = fetch(auth_event_id.to_owned()).await {
if auth_event.is_type_and_state_key(&TimelineEventType::RoomPowerLevels, "") {
room_power_levels_event = Some(RoomPowerLevelsEvent::new(auth_event));
} else if !rules
.authorization
.room_create_event_id_as_room_id
&& auth_event.is_type_and_state_key(&TimelineEventType::RoomCreate, "")
{
room_create_event = Some(RoomCreateEvent::new(auth_event));
}
if room_power_levels_event.is_some() && room_create_event.is_some() {
break;
}
}
}
let auth_rules = &rules.authorization;
let creators = room_create_event
.as_ref()
.and_then(|event| event.creators(auth_rules).ok());
// With both the event and the room creators known, compute the sender's
// power level; otherwise fall back to the users_default field.
if let Some((event, creators)) = event.ok().zip(creators) {
room_power_levels_event.user_power_level(event.sender(), creators, auth_rules)
} else {
room_power_levels_event
.get_as_int_or_default(RoomPowerLevelsIntField::UsersDefault, auth_rules)
.map(Into::into)
}
}
/// Whether the given event ID belongs to a power event.
///
/// Unfetchable events are treated as non-power events. See the docs of
/// `is_power_event()` for the definition of a power event.
#[tracing::instrument(
name = "is_power_event",
level = "trace",
skip_all,
fields(
?event_id,
)
)]
async fn is_power_event_id<Fetch, Fut, Pdu>(event_id: &EventId, fetch: &Fetch) -> bool
where
	Fetch: Fn(OwnedEventId) -> Fut + Sync,
	Fut: Future<Output = Result<Pdu>> + Send,
	Pdu: Event,
{
	fetch(event_id.to_owned())
		.await
		.is_ok_and(|state| is_power_event(&state))
}

View File

@@ -0,0 +1,82 @@
use std::{collections::HashMap, hash::Hash, iter::IntoIterator};
use futures::{Stream, StreamExt};
use tuwunel_core::validated;
use super::{ConflictMap, StateMap};
/// Split the incoming state maps into the unconflicted state map and the
/// conflicted state set.
///
/// Definition in the specification:
///
/// If a given key _K_ is present in every _Si_ with the same value _V_ in each
/// state map, then the pair (_K_, _V_) belongs to the unconflicted state map.
/// Otherwise, _V_ belongs to the conflicted state set.
///
/// Concretely: a (event type, state key) pair whose event ID agrees across all
/// incoming maps is unconflicted; any other event ID observed for a pair goes
/// into the conflicted set.
///
/// ## Arguments
///
/// * `state_maps` - The incoming states to resolve. Each `StateMap` represents
/// a possible fork in the state of a room.
///
/// ## Returns
///
/// Returns an `(unconflicted_state, conflicted_states)` tuple.
#[tracing::instrument(name = "split", level = "debug", skip_all)]
pub(super) async fn split_conflicted_state<'a, Maps, Id>(
	state_maps: Maps,
) -> (StateMap<Id>, ConflictMap<Id>)
where
	Maps: Stream<Item = StateMap<Id>>,
	Id: Clone + Eq + Hash + Ord + Send + Sync + 'a,
{
	let state_maps: Vec<_> = state_maps.collect().await;

	// Upper bound on the number of distinct keys; presizes the tally table.
	let estimated_ids = state_maps.iter().flatten().count();

	// Overflow-checked tally of how many state sets were provided.
	let state_set_count = state_maps
		.iter()
		.fold(0_usize, |acc, _| validated!(acc + 1));

	// occurrences[key][event_id] = number of state maps containing that pair.
	let mut occurrences = HashMap::<_, HashMap<_, usize>>::with_capacity(estimated_ids);
	let pairs = state_maps
		.into_iter()
		.flat_map(IntoIterator::into_iter);
	for (key, event_id) in pairs {
		let tally = occurrences
			.entry(key.clone())
			.or_default()
			.entry(event_id.clone())
			.or_default();

		*tally = tally.saturating_add(1);
	}

	let mut unconflicted_state_map = StateMap::new();
	let mut conflicted_state_set = ConflictMap::new();
	for (key, ids) in occurrences {
		for (id, occurrence_count) in ids {
			// Present in every state set with this value => unconflicted.
			if occurrence_count == state_set_count {
				unconflicted_state_map.insert((key.0.clone(), key.1.clone()), id.clone());
				continue;
			}

			let conflicts = conflicted_state_set
				.entry((key.0.clone(), key.1.clone()))
				.or_default();

			debug_assert!(
				!conflicts.contains(&id),
				"Unexpected duplicate conflicted state event"
			);

			conflicts.push(id.clone());
		}
	}

	(unconflicted_state_map, conflicted_state_set)
}

View File

@@ -0,0 +1,843 @@
use std::{collections::HashMap, iter::once};
use futures::StreamExt;
use maplit::hashmap;
use rand::seq::SliceRandom;
use ruma::{
MilliSecondsSinceUnixEpoch, OwnedEventId,
events::{
StateEventType, TimelineEventType,
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
},
int,
room_version_rules::RoomVersionRules,
uint,
};
use serde_json::{json, value::to_raw_value as to_raw_json_value};
use tuwunel_core::{
debug,
matrix::{Event, EventTypeExt, PduEvent},
utils::stream::IterStream,
};
use super::{
StateMap,
test_utils::{
INITIAL_EVENTS, TestStore, alice, bob, charlie, do_check, ella, event_id,
member_content_ban, member_content_join, not_found, room_id, to_init_pdu_event,
to_pdu_event, zara,
},
};
/// Helper driven by `test_sort()`: resolves the power events of the initial
/// fixture, then verifies `mainline_sort` restores the canonical order of a
/// shuffled copy of all event IDs.
async fn test_event_sort() {
_ = tracing::subscriber::set_default(
tracing_subscriber::fmt()
.with_test_writer()
.finish(),
);
let rules = RoomVersionRules::V6;
let events = INITIAL_EVENTS();
let auth_chain = Default::default();
// Sort the power events first, as state resolution does.
let sorted_power_events = super::power_sort(&rules, &auth_chain, &async |id| {
events.get(&id).cloned().ok_or_else(not_found)
})
.await
.unwrap();
let sorted_power_events = sorted_power_events
.iter()
.stream()
.map(AsRef::as_ref);
let resolved_power =
super::iterative_auth_check(&rules, sorted_power_events, StateMap::new(), &async |id| {
events.get(&id).cloned().ok_or_else(not_found)
})
.await
.expect("iterative auth check failed on resolved events");
// don't remove any events so we know it sorts them all correctly
let mut events_to_sort = events.keys().cloned().collect::<Vec<_>>();
events_to_sort.shuffle(&mut rand::thread_rng());
let power_level = resolved_power
.get(&(StateEventType::RoomPowerLevels, "".into()))
.cloned();
let events_to_sort = events_to_sort.iter().stream().map(AsRef::as_ref);
let sorted_event_ids = super::mainline_sort(power_level, events_to_sort, &async |id| {
events.get(&id).cloned().ok_or_else(not_found)
})
.await
.unwrap();
// The canonical fixture order must be reproduced regardless of shuffle.
assert_eq!(
vec![
"$CREATE:foo",
"$IMA:foo",
"$IPOWER:foo",
"$IJR:foo",
"$IMB:foo",
"$IMC:foo",
"$START:foo",
"$END:foo"
],
sorted_event_ids
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
);
}
/// The sort shuffles its input before ordering, so run the check repeatedly
/// to exercise many random permutations.
#[tokio::test]
async fn test_sort() {
	let mut remaining = 20;
	while remaining > 0 {
		test_event_sort().await;
		remaining -= 1;
	}
}
/// Alice bans Bob on one fork while Bob rewrites power levels on another;
/// resolution must keep Alice's power levels and the ban.
#[tokio::test]
async fn ban_vs_power_level() {
_ = tracing::subscriber::set_default(
tracing_subscriber::fmt()
.with_test_writer()
.finish(),
);
let events = &[
to_init_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"MA",
alice(),
TimelineEventType::RoomMember,
Some(alice().to_string().as_str()),
member_content_join(),
),
to_init_pdu_event(
"MB",
alice(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_ban(),
),
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
];
// Two forks: Alice's chain through the ban, and Bob's competing PL change.
let edges = vec![vec!["END", "MB", "MA", "PA", "START"], vec!["END", "PA", "PB"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["PA", "MA", "MB"]
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check(events, edges, expected_state_ids).await;
}
/// Competing topic changes across two forks with shifting power levels;
/// resolution must pick Alice's later power levels and topic.
#[tokio::test]
async fn topic_basic() {
_ = tracing::subscriber::set_default(
tracing_subscriber::fmt()
.with_test_writer()
.finish(),
);
let events = vec![
to_init_pdu_event(
"T1",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA1",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T2",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA2",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
),
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T3",
bob(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
];
let edges =
vec![vec!["END", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "T3", "PB", "PA1"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["PA2", "T2"]
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check(&events, edges, expected_state_ids).await;
}
/// Bob's topic on a fork where Alice later bans him; resolution must reset to
/// Alice's earlier topic and keep the ban.
#[tokio::test]
async fn topic_reset() {
_ = tracing::subscriber::set_default(
tracing_subscriber::fmt()
.with_test_writer()
.finish(),
);
let events = &[
to_init_pdu_event(
"T1",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T2",
bob(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"MB",
alice(),
TimelineEventType::RoomMember,
Some(bob().to_string().as_str()),
member_content_ban(),
),
];
let edges = vec![vec!["END", "MB", "T2", "PA", "T1", "START"], vec!["END", "T1"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["T1", "MB", "PA"]
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check(events, edges, expected_state_ids).await;
}
/// Alice makes the room invite-only on one fork while Ella joins on another;
/// resolution must keep the join rule and reject the evading join.
#[tokio::test]
async fn join_rule_evasion() {
	_ = tracing::subscriber::set_default(
		tracing_subscriber::fmt()
			.with_test_writer()
			.finish(),
	);

	let events = &[
		to_init_pdu_event(
			"JR",
			alice(),
			TimelineEventType::RoomJoinRules,
			Some(""),
			to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Private)).unwrap(),
		),
		to_init_pdu_event(
			"ME",
			ella(),
			TimelineEventType::RoomMember,
			Some(ella().to_string().as_str()),
			member_content_join(),
		),
	];

	let edges: Vec<Vec<_>> = [["END", "JR", "START"], ["END", "ME", "START"]]
		.into_iter()
		.map(|list| list.into_iter().map(event_id).collect())
		.collect();

	// Only the join rule survives resolution.
	let expected_state_ids = vec![event_id("JR")];

	do_check(events, edges, expected_state_ids).await;
}
/// Power level changes granting then demoting Charlie on a side fork;
/// resolution must keep the latest authorized power levels (PC).
#[tokio::test]
async fn offtopic_power_level() {
_ = tracing::subscriber::set_default(
tracing_subscriber::fmt()
.with_test_writer()
.finish(),
);
let events = &[
to_init_pdu_event(
"PA",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 50 } }))
.unwrap(),
),
to_init_pdu_event(
"PC",
charlie(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 0 } }))
.unwrap(),
),
];
let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["PC"]
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check(events, edges, expected_state_ids).await;
}
#[tokio::test]
async fn topic_setting() {
_ = tracing::subscriber::set_default(
tracing_subscriber::fmt()
.with_test_writer()
.finish(),
);
let events = vec![
to_init_pdu_event(
"T1",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA1",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T2",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"PA2",
alice(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(),
),
to_init_pdu_event(
"PB",
bob(),
TimelineEventType::RoomPowerLevels,
Some(""),
to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
),
to_init_pdu_event(
"T3",
bob(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"MZ1",
zara(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
to_init_pdu_event(
"T4",
alice(),
TimelineEventType::RoomTopic,
Some(""),
to_raw_json_value(&json!({})).unwrap(),
),
];
let edges = vec![vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], vec![
"END", "MZ1", "T3", "PB", "PA1",
]]
.into_iter()
.map(|list| list.into_iter().map(event_id).collect::<Vec<_>>())
.collect::<Vec<_>>();
let expected_state_ids = vec!["T4", "PA2"]
.into_iter()
.map(event_id)
.collect::<Vec<_>>();
do_check(&events, edges, expected_state_ids).await;
}
#[tokio::test]
async fn test_event_map_none() {
	// Route tracing output through the test harness.
	_ = tracing::subscriber::set_default(
		tracing_subscriber::fmt()
			.with_test_writer()
			.finish(),
	);

	// Build up the DAG from an initially empty store.
	let mut store = TestStore(hashmap! {});
	let (state_at_bob, state_at_charlie, expected) = store.set_up();
	let ev_map = store.0.clone();

	// Compute the full recursive auth chain for each state set, as
	// `resolve` requires.
	let state_sets = [state_at_bob, state_at_charlie];
	let auth_chains: Vec<_> = state_sets
		.iter()
		.map(|map| {
			store
				.auth_event_ids(room_id(), map.values().cloned().collect())
				.unwrap()
		})
		.collect();

	// Resolve under v1 rules, serving events from the local map only.
	let rules = RoomVersionRules::V1;
	let resolved = super::resolve(
		&rules,
		state_sets.into_iter().stream(),
		auth_chains.into_iter().stream(),
		&async |id| ev_map.get(&id).cloned().ok_or_else(not_found),
		&async |id| ev_map.contains_key(&id),
		false,
	)
	.await
	.unwrap_or_else(|e| panic!("{e}"));

	assert_eq!(expected, resolved);
}
#[tokio::test]
#[expect(
	clippy::iter_on_single_items,
	clippy::iter_on_empty_collections
)]
async fn test_reverse_topological_power_sort() {
	// Route tracing output through the test harness.
	_ = tracing::subscriber::set_default(
		tracing_subscriber::fmt()
			.with_test_writer()
			.finish(),
	);

	// Small DAG where every other node points at "o".
	let graph = hashmap! {
		event_id("l") => [event_id("o")].into_iter().collect(),
		event_id("m") => [event_id("n"), event_id("o")].into_iter().collect(),
		event_id("n") => [event_id("o")].into_iter().collect(),
		// "o" has zero outgoing edges but 4 incoming edges
		event_id("o") => [].into_iter().collect(),
		event_id("p") => [event_id("o")].into_iter().collect(),
	};

	// Every event reports the same power level and origin timestamp, so
	// ordering falls back to the topological tie-break.
	let sorted = super::super::topological_sort(&graph, &async |_id| {
		Ok((int!(0).into(), MilliSecondsSinceUnixEpoch(uint!(0))))
	})
	.await
	.unwrap();

	// Strip the "$…:foo" event-id decoration before comparing.
	let names: Vec<_> = sorted
		.iter()
		.map(ToString::to_string)
		.map(|s| s.replace('$', "").replace(":foo", ""))
		.collect();
	assert_eq!(vec!["o", "l", "n", "m", "p"], names);
}
#[tokio::test]
async fn ban_with_auth_chains() {
	// Route tracing output through the test harness.
	_ = tracing::subscriber::set_default(
		tracing_subscriber::fmt()
			.with_test_writer()
			.finish(),
	);

	let ban = BAN_STATE_SET();

	// One branch carries the ban (MB), the other carries the join (IME)
	// descending from it.
	let branches = [vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]];
	let edges: Vec<Vec<_>> = branches
		.into_iter()
		.map(|branch| branch.into_iter().map(event_id).collect())
		.collect();

	// The ban and its authorizing power levels must survive resolution.
	let expected: Vec<_> = ["PA", "MB"].into_iter().map(event_id).collect();
	let events: Vec<_> = ban.values().cloned().collect();
	do_check(&events, edges, expected).await;
}
#[tokio::test]
async fn ban_with_auth_chains2() {
	// Route tracing output through the test harness.
	_ = tracing::subscriber::set_default(
		tracing_subscriber::fmt()
			.with_test_writer()
			.finish(),
	);
	// Combine the initial room events with the ban fixture into a single
	// event store.
	let init = INITIAL_EVENTS();
	let ban = BAN_STATE_SET();
	let mut inner = init.clone();
	inner.extend(ban);
	let store = TestStore(inner.clone());
	// State set A: contains the ban of ella (MB) but not her join.
	let state_set_a = [
		&inner[&event_id("CREATE")],
		&inner[&event_id("IJR")],
		&inner[&event_id("IMA")],
		&inner[&event_id("IMB")],
		&inner[&event_id("IMC")],
		&inner[&event_id("MB")],
		&inner[&event_id("PA")],
	]
	.iter()
	.map(|ev| {
		(
			ev.event_type()
				.with_state_key(ev.state_key().unwrap()),
			ev.event_id.clone(),
		)
	})
	.collect::<StateMap<_>>();
	// State set B: contains ella's join (IME) but not the ban.
	let state_set_b = [
		&inner[&event_id("CREATE")],
		&inner[&event_id("IJR")],
		&inner[&event_id("IMA")],
		&inner[&event_id("IMB")],
		&inner[&event_id("IMC")],
		&inner[&event_id("IME")],
		&inner[&event_id("PA")],
	]
	.iter()
	.map(|ev| {
		(
			ev.event_type()
				.with_state_key(ev.state_key().unwrap()),
			ev.event_id.clone(),
		)
	})
	.collect::<StateMap<_>>();
	let ev_map = &store.0;
	let state_sets = [state_set_a, state_set_b];
	// Compute the full recursive auth chain for each state set, as
	// `resolve` requires.
	let auth_chains = state_sets
		.iter()
		.map(|map| {
			store
				.auth_event_ids(room_id(), map.values().cloned().collect())
				.unwrap()
		})
		.collect::<Vec<_>>();
	// Resolve under v6 rules, serving events from the local map only.
	let resolved = match super::resolve(
		&RoomVersionRules::V6,
		state_sets.into_iter().stream(),
		auth_chains.into_iter().stream(),
		&async |id| ev_map.get(&id).cloned().ok_or_else(not_found),
		&async |id| ev_map.contains_key(&id),
		false,
	)
	.await
	{
		| Ok(state) => state,
		| Err(e) => panic!("{e}"),
	};
	debug!(
		resolved = ?resolved
			.iter()
			.map(|((ty, key), id)| format!("(({ty}{key:?}), {id})"))
			.collect::<Vec<_>>(),
		"resolved state",
	);
	// The ban must win: MB appears in the expected state, IME does not.
	let expected = [
		"$CREATE:foo",
		"$IJR:foo",
		"$PA:foo",
		"$IMA:foo",
		"$IMB:foo",
		"$IMC:foo",
		"$MB:foo",
	];
	for id in expected.iter().map(|i| event_id(i)) {
		// make sure our resolved events are equal to the expected list
		assert!(resolved.values().any(|eid| eid == &id) || init.contains_key(&id), "{id}");
	}
	assert_eq!(expected.len(), resolved.len());
}
#[tokio::test]
async fn join_rule_with_auth_chain() {
	let join_rule = JOIN_RULE();

	// One branch carries the join rule (JR), the other zara's attempted
	// join (IMZ).
	let branches = [vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]];
	let edges: Vec<Vec<_>> = branches
		.into_iter()
		.map(|branch| branch.into_iter().map(event_id).collect())
		.collect();

	// Only the join rule remains in the resolved state.
	let expected: Vec<_> = ["JR"].into_iter().map(event_id).collect();
	let events: Vec<_> = join_rule.values().cloned().collect();
	do_check(&events, edges, expected).await;
}
/// Fixture: a ban of ella (MB) racing against her join (IME), together
/// with the power-level events (PA, PB) that authorize them, keyed by
/// event ID.
#[expect(non_snake_case)]
fn BAN_STATE_SET() -> HashMap<OwnedEventId, PduEvent> {
	let pdus = vec![
		to_pdu_event(
			"PA",
			alice(),
			TimelineEventType::RoomPowerLevels,
			Some(""),
			to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
			&["CREATE", "IMA", "IPOWER"], // auth_events
			&["START"],                   // prev_events
		),
		to_pdu_event(
			"PB",
			alice(),
			TimelineEventType::RoomPowerLevels,
			Some(""),
			to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(),
			&["CREATE", "IMA", "IPOWER"],
			&["END"],
		),
		to_pdu_event(
			"MB",
			alice(),
			TimelineEventType::RoomMember,
			Some(ella().as_str()),
			member_content_ban(),
			&["CREATE", "IMA", "PB"],
			&["PA"],
		),
		to_pdu_event(
			"IME",
			ella(),
			TimelineEventType::RoomMember,
			Some(ella().as_str()),
			member_content_join(),
			&["CREATE", "IJR", "PA"],
			&["MB"],
		),
	];

	pdus.into_iter()
		.map(|pdu| (pdu.event_id.clone(), pdu))
		.collect()
}
/// Fixture: an "invite" join rule (JR) plus zara's attempted join (IMZ),
/// keyed by event ID.
#[expect(non_snake_case)]
fn JOIN_RULE() -> HashMap<OwnedEventId, PduEvent> {
	vec![
		to_pdu_event(
			"JR",
			alice(),
			TimelineEventType::RoomJoinRules,
			Some(""),
			to_raw_json_value(&json!({ "join_rule": "invite" })).unwrap(),
			&["CREATE", "IMA", "IPOWER"], // auth_events
			&["START"],                   // prev_events
		),
		to_pdu_event(
			"IMZ",
			zara(),
			// NOTE(review): membership content under a RoomPowerLevels event
			// type — looks like it was meant to be RoomMember, but it matches
			// the upstream fixture; left as-is. TODO confirm intent.
			TimelineEventType::RoomPowerLevels,
			Some(zara().as_str()),
			member_content_join(),
			&["CREATE", "JR", "IPOWER"],
			&["START"],
		),
	]
	.into_iter()
	.map(|ev| (ev.event_id.clone(), ev))
	.collect()
}
/// Build a `StateMap` literal: `state_set![Type => "state_key" => id, ...]`.
macro_rules! state_set {
	($($kind:expr => $key:expr => $id:expr),* $(,)?) => {{
		let mut map = StateMap::new();
		$(
			map.insert(($kind, $key.into()), $id);
		)*
		map
	}};
}
#[tokio::test]
async fn split_conflicted_state_set_conflicted_unique_state_keys() {
	// Three state sets, each with a distinct member key: no key appears
	// in every set, so every entry is conflicted.
	let sets = [
		state_set![StateEventType::RoomMember => "@a:hs1" => 0],
		state_set![StateEventType::RoomMember => "@b:hs1" => 1],
		state_set![StateEventType::RoomMember => "@c:hs1" => 2],
	];
	let (unconflicted, conflicted) =
		super::split_conflicted_state(sets.into_iter().stream()).await;

	let unconflicted: StateMap<_> = unconflicted.into_iter().collect();
	let conflicted: StateMap<_> = conflicted.into_iter().collect();
	assert_eq!(unconflicted, StateMap::new());
	assert_eq!(conflicted, state_set![
		StateEventType::RoomMember => "@a:hs1" => once(0).collect(),
		StateEventType::RoomMember => "@b:hs1" => once(1).collect(),
		StateEventType::RoomMember => "@c:hs1" => once(2).collect(),
	],);
}
#[tokio::test]
async fn split_conflicted_state_set_conflicted_same_state_key() {
	// The same key maps to a different value in each set, so it is
	// conflicted with all three candidate values.
	let sets = [
		state_set![StateEventType::RoomMember => "@a:hs1" => 0],
		state_set![StateEventType::RoomMember => "@a:hs1" => 1],
		state_set![StateEventType::RoomMember => "@a:hs1" => 2],
	];
	let (unconflicted, conflicted) =
		super::split_conflicted_state(sets.into_iter().stream()).await;

	let unconflicted: StateMap<_> = unconflicted.into_iter().collect();
	let mut conflicted: StateMap<_> = conflicted.into_iter().collect();
	// HashMap iteration order is random, so sort this before asserting on it
	for values in conflicted.values_mut() {
		values.sort_unstable();
	}

	assert_eq!(unconflicted, StateMap::new());
	assert_eq!(conflicted, state_set![
		StateEventType::RoomMember => "@a:hs1" => [0, 1, 2].into_iter().collect(),
	],);
}
#[tokio::test]
async fn split_conflicted_state_set_unconflicted() {
	// Every set agrees on the single entry, so nothing is conflicted.
	let sets = [
		state_set![StateEventType::RoomMember => "@a:hs1" => 0],
		state_set![StateEventType::RoomMember => "@a:hs1" => 0],
		state_set![StateEventType::RoomMember => "@a:hs1" => 0],
	];
	let (unconflicted, conflicted) =
		super::split_conflicted_state(sets.into_iter().stream()).await;

	let unconflicted: StateMap<_> = unconflicted.into_iter().collect();
	let conflicted: StateMap<_> = conflicted.into_iter().collect();
	assert_eq!(unconflicted, state_set![
		StateEventType::RoomMember => "@a:hs1" => 0,
	],);
	assert_eq!(conflicted, StateMap::new());
}
#[tokio::test]
async fn split_conflicted_state_set_mixed() {
	// "@a:hs1" agrees across all three sets (unconflicted); "@b:hs1" and
	// "@c:hs1" each appear in only one set (conflicted).
	let sets = [
		state_set![StateEventType::RoomMember => "@a:hs1" => 0],
		state_set![
			StateEventType::RoomMember => "@a:hs1" => 0,
			StateEventType::RoomMember => "@b:hs1" => 1,
		],
		state_set![
			StateEventType::RoomMember => "@a:hs1" => 0,
			StateEventType::RoomMember => "@c:hs1" => 2,
		],
	];
	let (unconflicted, conflicted) =
		super::split_conflicted_state(sets.into_iter().stream()).await;

	let unconflicted: StateMap<_> = unconflicted.into_iter().collect();
	let conflicted: StateMap<_> = conflicted.into_iter().collect();
	assert_eq!(unconflicted, state_set![
		StateEventType::RoomMember => "@a:hs1" => 0,
	],);
	assert_eq!(conflicted, state_set![
		StateEventType::RoomMember => "@b:hs1" => once(1).collect(),
		StateEventType::RoomMember => "@c:hs1" => once(2).collect(),
	],);
}