Add PduCount value to userroomid/roomuserid_joined; move PduCount to argument for update_membership.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-10-01 23:01:57 +00:00
parent eda45e445c
commit f1c2548807
8 changed files with 55 additions and 26 deletions

View File

@@ -22,6 +22,7 @@ use ruma::{
use tuwunel_core::{
Err, Result, debug, debug_info, debug_warn, err, extract_variant, info,
matrix::{
PduCount,
event::{Event, gen_event_id},
pdu::{PduBuilder, PduEvent},
},
@@ -276,6 +277,7 @@ async fn knock_room_helper_local(
.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
info!("Updating membership locally to knock state with provided stripped state events");
let count = services.globals.next_count();
services
.state_cache
.update_membership(
@@ -294,6 +296,7 @@ async fn knock_room_helper_local(
),
None,
false,
PduCount::Normal(*count),
)
.await?;
@@ -490,6 +493,7 @@ async fn knock_room_helper_remote(
.await?;
info!("Updating membership locally to knock state with provided stripped state events");
let count = services.globals.next_count();
services
.state_cache
.update_membership(
@@ -508,6 +512,7 @@ async fn knock_room_helper_remote(
),
None,
false,
PduCount::Normal(*count),
)
.await?;

View File

@@ -12,7 +12,7 @@ use ruma::{
};
use tuwunel_core::{
Err, Error, Result, err, extract_variant,
matrix::{Event, PduEvent, event::gen_event_id},
matrix::{Event, PduCount, PduEvent, event::gen_event_id},
utils,
utils::hash::sha256,
warn,
@@ -143,6 +143,7 @@ pub(crate) async fn create_invite_route(
.server_in_room(services.globals.server_name(), &body.room_id)
.await
{
let count = services.globals.next_count();
services
.state_cache
.update_membership(
@@ -153,8 +154,10 @@ pub(crate) async fn create_invite_route(
Some(invite_state),
body.via.clone(),
true,
PduCount::Normal(*count),
)
.await?;
drop(count);
for appservice in services.appservice.read().await.values() {
if appservice.is_user_match(&invited_user) {

View File

@@ -11,6 +11,7 @@ use ruma::{
};
use tuwunel_core::{
Err, Result, debug_info, debug_warn, err, implement,
matrix::PduCount,
pdu::PduBuilder,
utils::{self, FutureBoolExt, future::ReadyEqExt},
warn,
@@ -51,6 +52,7 @@ pub async fn leave(
if is_banned.or(is_disabled).await {
// the room is banned/disabled, the room must be rejected locally since we
// cant/dont want to federate with this server
let count = self.services.globals.next_count();
self.services
.state_cache
.update_membership(
@@ -61,6 +63,7 @@ pub async fn leave(
None,
None,
true,
PduCount::Normal(*count),
)
.await?;
@@ -104,6 +107,7 @@ pub async fn leave(
.ok();
// We always drop the invite, we can't rely on other servers
let count = self.services.globals.next_count();
self.services
.state_cache
.update_membership(
@@ -114,6 +118,7 @@ pub async fn leave(
last_state,
None,
true,
PduCount::Normal(*count),
)
.await?;
} else {
@@ -131,6 +136,7 @@ pub async fn leave(
"Trying to leave a room you are not a member of, marking room as left locally."
);
let count = self.services.globals.next_count();
return self
.services
.state_cache
@@ -142,6 +148,7 @@ pub async fn leave(
None,
None,
true,
PduCount::Normal(*count),
)
.await;
};

View File

@@ -11,6 +11,7 @@ use ruma::{
};
use tuwunel_core::{
Err, Result, debug, debug_info, debug_warn, error, info,
matrix::PduCount,
result::NotFound,
utils::{
IterStream, ReadyExt,
@@ -450,16 +451,18 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services)
for user_id in &joined_members {
debug_info!("User is joined, marking as joined");
let count = services.globals.next_count();
services
.state_cache
.mark_as_joined(user_id, room_id);
.mark_as_joined(user_id, room_id, PduCount::Normal(*count));
}
for user_id in &non_joined_members {
debug_info!("User is left or banned, marking as left");
let count = services.globals.next_count();
services
.state_cache
.mark_as_left(user_id, room_id);
.mark_as_left(user_id, room_id, PduCount::Normal(*count));
}
}

View File

@@ -12,7 +12,7 @@ use tuwunel_core::{
Event, PduEvent, Result, err,
error::inspect_debug_log,
implement,
matrix::{RoomVersionRules, StateKey, TypeStateKey, room_version},
matrix::{PduCount, RoomVersionRules, StateKey, TypeStateKey, room_version},
result::{AndThenRef, FlatOk},
state_res::{StateMap, auth_types_for_event},
trace,
@@ -126,6 +126,7 @@ pub async fn force_state(
return Ok(());
};
let count = self.services.globals.next_count();
self.services
.state_cache
.update_membership(
@@ -136,6 +137,7 @@ pub async fn force_state(
None,
None,
false,
PduCount::Normal(*count),
)
.await
},

View File

@@ -371,6 +371,17 @@ pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result
.deserialized()
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "trace")]
pub async fn get_joined_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db
.roomuserid_joined
.qry(&key)
.await
.deserialized()
}
/// Returns an iterator over all rooms this user joined.
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]

View File

@@ -14,7 +14,7 @@ use ruma::{
},
serde::Raw,
};
use tuwunel_core::{Result, implement, is_not_empty, utils::ReadyExt, warn};
use tuwunel_core::{Result, implement, is_not_empty, matrix::PduCount, utils::ReadyExt, warn};
use tuwunel_database::{Json, serialize_key};
/// Update current membership data.
@@ -26,6 +26,7 @@ use tuwunel_database::{Json, serialize_key};
%room_id,
%user_id,
%sender,
%count,
?membership_event,
),
)]
@@ -39,6 +40,7 @@ pub async fn update_membership(
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
update_joined_count: bool,
count: PduCount,
) -> Result {
let membership = membership_event.membership;
@@ -58,11 +60,6 @@ pub async fn update_membership(
match &membership {
| MembershipState::Join => {
// Increment the counter for a unique value and hold the guard even though the
// value is not used. This will allow proper sync to clients similar to the
// other membership state changes.
let _next_count = self.services.globals.next_count();
// Check if the user never joined this room
if !self.once_joined(user_id, room_id).await {
// Add the user ID to the join list then
@@ -128,7 +125,7 @@ pub async fn update_membership(
}
}
self.mark_as_joined(user_id, room_id);
self.mark_as_joined(user_id, room_id, count);
},
| MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
@@ -141,11 +138,11 @@ pub async fn update_membership(
return Ok(());
}
self.mark_as_invited(user_id, room_id, last_state, invite_via)
self.mark_as_invited(user_id, room_id, count, last_state, invite_via)
.await;
},
| MembershipState::Leave | MembershipState::Ban => {
self.mark_as_left(user_id, room_id);
self.mark_as_left(user_id, room_id, count);
if self.services.globals.user_is_local(user_id)
&& (self.services.config.forget_forced_upon_leave
@@ -241,15 +238,19 @@ pub async fn update_joined_count(&self, room_id: &RoomId) {
/// `update_membership` instead
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub(crate) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
pub(crate) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
self.db.userroomid_joined.insert(&userroom_id, []);
self.db.roomuserid_joined.insert(&roomuser_id, []);
self.db
.userroomid_joined
.raw_aput::<8, _, _>(&userroom_id, count.into_unsigned());
self.db
.roomuserid_joined
.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
self.db
.userroomid_invitestate
@@ -276,9 +277,7 @@ pub(crate) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
/// `update_membership` instead
#[implement(super::Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub(crate) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
let count = self.services.globals.next_count();
pub(crate) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId, count: PduCount) {
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
@@ -293,7 +292,7 @@ pub(crate) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
.raw_put(&userroom_id, Json(leftstate));
self.db
.roomuserid_leftcount
.raw_aput::<8, _, _>(&roomuser_id, *count);
.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
@@ -324,10 +323,9 @@ pub(crate) fn _mark_as_knocked(
&self,
user_id: &UserId,
room_id: &RoomId,
count: PduCount,
knocked_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
) {
let count = self.services.globals.next_count();
let userroom_id = (user_id, room_id);
let userroom_id = serialize_key(userroom_id).expect("failed to serialize userroom_id");
@@ -339,7 +337,7 @@ pub(crate) fn _mark_as_knocked(
.raw_put(&userroom_id, Json(knocked_state.unwrap_or_default()));
self.db
.roomuserid_knockedcount
.raw_aput::<8, _, _>(&roomuser_id, *count);
.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
@@ -381,11 +379,10 @@ pub(crate) async fn mark_as_invited(
&self,
user_id: &UserId,
room_id: &RoomId,
count: PduCount,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) {
let count = self.services.globals.next_count();
let roomuser_id = (room_id, user_id);
let roomuser_id = serialize_key(roomuser_id).expect("failed to serialize roomuser_id");
@@ -397,7 +394,7 @@ pub(crate) async fn mark_as_invited(
.raw_put(&userroom_id, Json(last_state.unwrap_or_default()));
self.db
.roomuserid_invitecount
.raw_aput::<8, _, _>(&roomuser_id, *count);
.raw_aput::<8, _, _>(&roomuser_id, count.into_unsigned());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);

View File

@@ -363,6 +363,7 @@ where
stripped_state,
None,
true,
count,
)
.await?;
}