diff --git a/src/service/pusher/append.rs b/src/service/pusher/append.rs index 64530d76..12162189 100644 --- a/src/service/pusher/append.rs +++ b/src/service/pusher/append.rs @@ -1,8 +1,11 @@ use std::{collections::HashSet, sync::Arc}; -use futures::{FutureExt, StreamExt, future::join}; +use futures::{ + FutureExt, StreamExt, + future::{OptionFuture, join}, +}; use ruma::{ - OwnedUserId, RoomId, UserId, + RoomId, UserId, api::client::push::ProfileTag, events::{GlobalAccountDataEventType, TimelineEventType, push_rules::PushRulesEvent}, push::{Action, Actions, Ruleset, Tweak}, @@ -68,9 +71,6 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result { let (mut push_target, power_levels) = join(push_target, power_levels).boxed().await; - let mut notifies = Vec::with_capacity(push_target.len().saturating_add(1)); - let mut highlights = Vec::with_capacity(push_target.len().saturating_add(1)); - if *pdu.kind() == TimelineEventType::RoomMember { if let Some(Ok(target_user_id)) = pdu.state_key().map(UserId::parse) { if self @@ -85,6 +85,7 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result { } let serialized = pdu.to_format(); + let _cork = self.db.db.cork(); for user in &push_target { let rules_for_user = self .services @@ -96,37 +97,29 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result { |ev: PushRulesEvent| ev.content.global, ); - let mut highlight = false; - let mut notify = false; - let actions = self .services .pusher .get_actions(user, &rules_for_user, power_levels.as_ref(), &serialized, pdu.room_id()) .await; - for action in actions { - match action { - | Action::Notify => notify = true, - | Action::SetTweak(Tweak::Highlight(true)) => { - highlight = true; - }, - | _ => {}, - } + let notify = actions + .iter() + .any(|action| matches!(action, Action::Notify)); - // Break early if both conditions are true - if notify && highlight { - break; - } - } + let highlight = actions + .iter() + .any(|action| matches!(action, Action::SetTweak(Tweak::Highlight(true)))); - if notify { - notifies.push(user.clone()); - } + let increment_notify: OptionFuture<_> = notify + .then(|| self.increment_notificationcount(pdu.room_id(), user)) + .into(); - if highlight { - highlights.push(user.clone()); - } + let increment_highlight: OptionFuture<_> = highlight + .then(|| self.increment_highlightcount(pdu.room_id(), user)) + .into(); + + join(increment_notify, increment_highlight).await; if notify || highlight { let id: PduId = pdu_id.into(); @@ -159,31 +152,27 @@ pub(crate) async fn append_pdu(&self, pdu_id: RawPduId, pdu: &Pdu) -> Result { } } - self.increment_notification_counts(pdu.room_id(), notifies, highlights) - .await; - Ok(()) } #[implement(super::Service)] -async fn increment_notification_counts( - &self, - room_id: &RoomId, - notifies: Vec, - highlights: Vec, -) { - let _cork = self.db.db.cork(); +async fn increment_notificationcount(&self, room_id: &RoomId, user_id: &UserId) { + let db = &self.db.userroomid_notificationcount; + let key = (room_id.to_owned(), user_id.to_owned()); + let _lock = self.notification_increment_mutex.lock(&key).await; - for user in notifies { - increment(&self.db.userroomid_notificationcount, (&user, room_id)).await; - } - - for user in highlights { - increment(&self.db.userroomid_highlightcount, (&user, room_id)).await; - } + increment(db, (user_id, room_id)).await; +} + +#[implement(super::Service)] +async fn increment_highlightcount(&self, room_id: &RoomId, user_id: &UserId) { + let db = &self.db.userroomid_highlightcount; + let key = (room_id.to_owned(), user_id.to_owned()); + let _lock = self.highlight_increment_mutex.lock(&key).await; + + increment(db, (user_id, room_id)).await; } -// TODO: this is an ABA problem async fn increment(db: &Arc, key: (&UserId, &RoomId)) { let old: u64 = db.qry(&key).await.deserialized().unwrap_or(0); let new = old.saturating_add(1); diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index d83cb7ea..9d4a4f45 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use futures::{Stream, StreamExt, TryFutureExt, future::join}; use ipaddress::IPAddress; use ruma::{ - DeviceId, OwnedDeviceId, RoomId, UserId, + DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId, api::client::push::{Pusher, PusherKind, set_pusher}, events::{AnySyncTimelineEvent, room::power_levels::RoomPowerLevels}, push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, Ruleset}, @@ -18,6 +18,7 @@ use ruma::{ use tuwunel_core::{ Err, Result, err, implement, utils::{ + MutexMap, future::TryExtExt, stream::{BroadbandExt, ReadyExt, TryIgnore}, }, @@ -27,24 +28,30 @@ use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Json, Map}; pub use self::append::Notified; pub struct Service { - db: Data, services: Arc, + notification_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>, + highlight_increment_mutex: MutexMap<(OwnedRoomId, OwnedUserId), ()>, + db: Data, } struct Data { + db: Arc, senderkey_pusher: Arc, pushkey_deviceid: Arc, useridcount_notification: Arc, userroomid_highlightcount: Arc, userroomid_notificationcount: Arc, roomuserid_lastnotificationread: Arc, - db: Arc, } impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { Ok(Arc::new(Self { + services: args.services.clone(), + notification_increment_mutex: MutexMap::new(), + highlight_increment_mutex: MutexMap::new(), db: Data { + db: args.db.clone(), senderkey_pusher: args.db["senderkey_pusher"].clone(), pushkey_deviceid: args.db["pushkey_deviceid"].clone(), useridcount_notification: args.db["useridcount_notification"].clone(), @@ -52,9 +59,7 @@ impl crate::Service for Service { userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"] .clone(), - db: args.db.clone(), }, - services: args.services.clone(), })) }