Increment accumulators prior to release-action (fixes #253).

↳ userroomid_notificationcount and userroomid_highlightcount should be
incremented prior to touching useridcount_notification.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-13 06:01:20 +00:00
parent cf8b57b751
commit fd8ee422dd
2 changed files with 44 additions and 50 deletions

View File

@@ -1,8 +1,11 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use futures::{FutureExt, StreamExt, future::join}; use futures::{
FutureExt, StreamExt,
future::{OptionFuture, join},
};
use ruma::{ use ruma::{
OwnedUserId, RoomId, UserId, RoomId, UserId,
api::client::push::ProfileTag, api::client::push::ProfileTag,
events::{GlobalAccountDataEventType, TimelineEventType, push_rules::PushRulesEvent}, events::{GlobalAccountDataEventType, TimelineEventType, push_rules::PushRulesEvent},
push::{Action, Actions, Ruleset, Tweak}, push::{Action, Actions, Ruleset, Tweak},
@@ -68,9 +71,6 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result {
let (mut push_target, power_levels) = join(push_target, power_levels).boxed().await; let (mut push_target, power_levels) = join(push_target, power_levels).boxed().await;
let mut notifies = Vec::with_capacity(push_target.len().saturating_add(1));
let mut highlights = Vec::with_capacity(push_target.len().saturating_add(1));
if *pdu.kind() == TimelineEventType::RoomMember { if *pdu.kind() == TimelineEventType::RoomMember {
if let Some(Ok(target_user_id)) = pdu.state_key().map(UserId::parse) { if let Some(Ok(target_user_id)) = pdu.state_key().map(UserId::parse) {
if self if self
@@ -85,6 +85,7 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result {
} }
let serialized = pdu.to_format(); let serialized = pdu.to_format();
let _cork = self.db.db.cork();
for user in &push_target { for user in &push_target {
let rules_for_user = self let rules_for_user = self
.services .services
@@ -96,37 +97,29 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result {
|ev: PushRulesEvent| ev.content.global, |ev: PushRulesEvent| ev.content.global,
); );
let mut highlight = false;
let mut notify = false;
let actions = self let actions = self
.services .services
.pusher .pusher
.get_actions(user, &rules_for_user, power_levels.as_ref(), &serialized, pdu.room_id()) .get_actions(user, &rules_for_user, power_levels.as_ref(), &serialized, pdu.room_id())
.await; .await;
for action in actions { let notify = actions
match action { .iter()
| Action::Notify => notify = true, .any(|action| matches!(action, Action::Notify));
| Action::SetTweak(Tweak::Highlight(true)) => {
highlight = true;
},
| _ => {},
}
// Break early if both conditions are true let highlight = actions
if notify && highlight { .iter()
break; .any(|action| matches!(action, Action::SetTweak(Tweak::Highlight(true))));
}
}
if notify { let increment_notify: OptionFuture<_> = notify
notifies.push(user.clone()); .then(|| self.increment_notificationcount(pdu.room_id(), user))
} .into();
if highlight { let increment_highlight: OptionFuture<_> = highlight
highlights.push(user.clone()); .then(|| self.increment_highlightcount(pdu.room_id(), user))
} .into();
join(increment_notify, increment_highlight).await;
if notify || highlight { if notify || highlight {
let id: PduId = pdu_id.into(); let id: PduId = pdu_id.into();
@@ -159,31 +152,27 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result {
} }
} }
self.increment_notification_counts(pdu.room_id(), notifies, highlights)
.await;
Ok(()) Ok(())
} }
#[implement(super::Service)] #[implement(super::Service)]
async fn increment_notification_counts( async fn increment_notificationcount(&self, room_id: &RoomId, user_id: &UserId) {
&self, let db = &self.db.userroomid_notificationcount;
room_id: &RoomId, let key = (room_id.to_owned(), user_id.to_owned());
notifies: Vec<OwnedUserId>, let _lock = self.notification_increment_mutex.lock(&key).await;
highlights: Vec<OwnedUserId>,
) {
let _cork = self.db.db.cork();
for user in notifies { increment(db, (user_id, room_id)).await;
increment(&self.db.userroomid_notificationcount, (&user, room_id)).await; }
}
#[implement(super::Service)]
for user in highlights { async fn increment_highlightcount(&self, room_id: &RoomId, user_id: &UserId) {
increment(&self.db.userroomid_highlightcount, (&user, room_id)).await; let db = &self.db.userroomid_highlightcount;
} let key = (room_id.to_owned(), user_id.to_owned());
let _lock = self.highlight_increment_mutex.lock(&key).await;
increment(db, (user_id, room_id)).await;
} }
// TODO: this is an ABA problem
async fn increment(db: &Arc<Map>, key: (&UserId, &RoomId)) { async fn increment(db: &Arc<Map>, key: (&UserId, &RoomId)) {
let old: u64 = db.qry(&key).await.deserialized().unwrap_or(0); let old: u64 = db.qry(&key).await.deserialized().unwrap_or(0);
let new = old.saturating_add(1); let new = old.saturating_add(1);

View File

@@ -8,7 +8,7 @@ use std::sync::Arc;
use futures::{Stream, StreamExt, TryFutureExt, future::join}; use futures::{Stream, StreamExt, TryFutureExt, future::join};
use ipaddress::IPAddress; use ipaddress::IPAddress;
use ruma::{ use ruma::{
DeviceId, OwnedDeviceId, RoomId, UserId, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId,
api::client::push::{Pusher, PusherKind, set_pusher}, api::client::push::{Pusher, PusherKind, set_pusher},
events::{AnySyncTimelineEvent, room::power_levels::RoomPowerLevels}, events::{AnySyncTimelineEvent, room::power_levels::RoomPowerLevels},
push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, Ruleset}, push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, Ruleset},
@@ -18,6 +18,7 @@ use ruma::{
use tuwunel_core::{ use tuwunel_core::{
Err, Result, err, implement, Err, Result, err, implement,
utils::{ utils::{
MutexMap,
future::TryExtExt, future::TryExtExt,
stream::{BroadbandExt, ReadyExt, TryIgnore}, stream::{BroadbandExt, ReadyExt, TryIgnore},
}, },
@@ -27,24 +28,30 @@ use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Json, Map};
pub use self::append::Notified; pub use self::append::Notified;
pub struct Service { pub struct Service {
db: Data,
services: Arc<crate::services::OnceServices>, services: Arc<crate::services::OnceServices>,
notification_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>,
highlight_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>,
db: Data,
} }
struct Data { struct Data {
db: Arc<Database>,
senderkey_pusher: Arc<Map>, senderkey_pusher: Arc<Map>,
pushkey_deviceid: Arc<Map>, pushkey_deviceid: Arc<Map>,
useridcount_notification: Arc<Map>, useridcount_notification: Arc<Map>,
userroomid_highlightcount: Arc<Map>, userroomid_highlightcount: Arc<Map>,
userroomid_notificationcount: Arc<Map>, userroomid_notificationcount: Arc<Map>,
roomuserid_lastnotificationread: Arc<Map>, roomuserid_lastnotificationread: Arc<Map>,
db: Arc<Database>,
} }
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self { Ok(Arc::new(Self {
services: args.services.clone(),
notification_increment_mutex: MutexMap::new(),
highlight_increment_mutex: MutexMap::new(),
db: Data { db: Data {
db: args.db.clone(),
senderkey_pusher: args.db["senderkey_pusher"].clone(), senderkey_pusher: args.db["senderkey_pusher"].clone(),
pushkey_deviceid: args.db["pushkey_deviceid"].clone(), pushkey_deviceid: args.db["pushkey_deviceid"].clone(),
useridcount_notification: args.db["useridcount_notification"].clone(), useridcount_notification: args.db["useridcount_notification"].clone(),
@@ -52,9 +59,7 @@ impl crate::Service for Service {
userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(),
roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"] roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"]
.clone(), .clone(),
db: args.db.clone(),
}, },
services: args.services.clone(),
})) }))
} }