diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs index c0f4473d..0f7ccba7 100644 --- a/src/api/client/read_marker.rs +++ b/src/api/client/read_marker.rs @@ -28,7 +28,7 @@ pub(crate) async fn set_read_marker_route( if body.private_read_receipt.is_some() || body.read_receipt.is_some() { services - .user + .pusher .reset_notification_counts(sender_user, &body.room_id); } @@ -115,7 +115,7 @@ pub(crate) async fn create_receipt_route( create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate ) { services - .user + .pusher .reset_notification_counts(sender_user, &body.room_id); } diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index f5a54f3b..59c0cdba 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -809,7 +809,7 @@ async fn load_joined_room( .is_empty() .then(|| { services - .user + .pusher .last_notification_read(sender_user, room_id) }) .into(); @@ -877,7 +877,7 @@ async fn load_joined_room( let notification_count: OptionFuture<_> = send_notification_counts .then(|| { services - .user + .pusher .notification_count(sender_user, room_id) .map(TryInto::try_into) .unwrap_or(uint!(0)) @@ -887,7 +887,7 @@ async fn load_joined_room( let highlight_count: OptionFuture<_> = send_notification_counts .then(|| { services - .user + .pusher .highlight_count(sender_user, room_id) .map(TryInto::try_into) .unwrap_or(uint!(0)) diff --git a/src/api/client/sync/v5/rooms.rs b/src/api/client/sync/v5/rooms.rs index 875fa3e9..89f7502c 100644 --- a/src/api/client/sync/v5/rooms.rs +++ b/src/api/client/sync/v5/rooms.rs @@ -273,13 +273,13 @@ async fn handle_room( .map(Option::flatten); let highlight_count = services - .user + .pusher .highlight_count(sender_user, room_id) .map(TryInto::try_into) .map(Result::ok); let notification_count = services - .user + .pusher .notification_count(sender_user, room_id) .map(TryInto::try_into) .map(Result::ok); @@ -304,7 +304,7 @@ async fn handle_room( .map(|is_dm| is_dm.then_some(is_dm)); let last_read_count = services - .user + .pusher .last_notification_read(sender_user, room_id); let timeline = timeline_pdus diff --git a/src/api/client/sync/v5/selector.rs b/src/api/client/sync/v5/selector.rs index e8bf4148..917d872b 100644 --- a/src/api/client/sync/v5/selector.rs +++ b/src/api/client/sync/v5/selector.rs @@ -100,7 +100,7 @@ async fn matcher( let last_notification: OptionFuture<_> = matched .then(|| { services - .user + .pusher .last_notification_read(sender_user, &room_id) }) .into(); diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index a0946ba3..01849b63 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,36 +1,23 @@ mod append; +mod notification; +mod request; +mod send; -use std::{fmt::Debug, mem, sync::Arc}; +use std::sync::Arc; -use bytes::BytesMut; use futures::{Stream, StreamExt}; use ipaddress::IPAddress; use ruma::{ - DeviceId, OwnedDeviceId, RoomId, UInt, UserId, - api::{ - IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, SupportedVersions, - client::push::{Pusher, PusherKind, set_pusher}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - }, - events::{AnySyncTimelineEvent, TimelineEventType, room::power_levels::RoomPowerLevels}, - push::{ - Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak, - }, + DeviceId, OwnedDeviceId, RoomId, UserId, + api::client::push::{Pusher, PusherKind, set_pusher}, + events::{AnySyncTimelineEvent, room::power_levels::RoomPowerLevels}, + push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, Ruleset}, serde::Raw, uint, }; use tuwunel_core::{ - Err, Result, debug_warn, err, - matrix::Event, - trace, - utils::{ - stream::{BroadbandExt, TryIgnore}, - string_from_bytes, - }, - warn, + Err, Result, err, implement, + utils::stream::{BroadbandExt, TryIgnore}, }; use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Json, Map}; @@ -44,6 +31,7 @@ struct Data { pushkey_deviceid: Arc, userroomid_highlightcount: Arc, userroomid_notificationcount: Arc, + roomuserid_lastnotificationread: Arc, db: Arc, } @@ -55,6 +43,8 @@ impl crate::Service for Service { pushkey_deviceid: args.db["pushkey_deviceid"].clone(), userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(), userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), + roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"] + .clone(), db: args.db.clone(), }, services: args.services.clone(), @@ -64,347 +54,32 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - pub async fn set_pusher( - &self, - sender: &UserId, - sender_device: &DeviceId, - pusher: &set_pusher::v3::PusherAction, - ) -> Result { - match pusher { - | set_pusher::v3::PusherAction::Post(data) => { - let pushkey = data.pusher.ids.pushkey.as_str(); +#[implement(Service)] +pub async fn set_pusher( + &self, + sender: &UserId, + sender_device: &DeviceId, + pusher: &set_pusher::v3::PusherAction, +) -> Result { + match pusher { + | set_pusher::v3::PusherAction::Post(data) => { + let pushkey = data.pusher.ids.pushkey.as_str(); - if pushkey.len() > 512 { - return Err!(Request(InvalidParam( - "Push key length cannot be greater than 512 bytes." - ))); - } - - if data.pusher.ids.app_id.as_str().len() > 64 { - return Err!(Request(InvalidParam( - "App ID length cannot be greater than 64 bytes." - ))); - } - - // add some validation to the pusher URL - let pusher_kind = &data.pusher.kind; - if let PusherKind::Http(http) = pusher_kind { - let url = &http.url; - let url = url::Url::parse(&http.url).map_err(|e| { - err!(Request(InvalidParam( - warn!(%url, "HTTP pusher URL is not a valid URL: {e}") - ))) - })?; - - if ["http", "https"] - .iter() - .all(|&scheme| scheme != url.scheme().to_lowercase()) - { - return Err!(Request(InvalidParam( - warn!(%url, "HTTP pusher URL is not a valid HTTP/HTTPS URL") - ))); - } - - if let Ok(ip) = - IPAddress::parse(url.host_str().expect("URL previously validated")) - { - if !self.services.client.valid_cidr_range(&ip) { - return Err!(Request(InvalidParam( - warn!(%url, "HTTP pusher URL is a forbidden remote address") - ))); - } - } - } - - let pushkey = data.pusher.ids.pushkey.as_str(); - let key = (sender, pushkey); - self.db.senderkey_pusher.put(key, Json(pusher)); - self.db - .pushkey_deviceid - .insert(pushkey, sender_device); - }, - | set_pusher::v3::PusherAction::Delete(ids) => { - self.delete_pusher(sender, ids.pushkey.as_str()) - .await; - }, - } - - Ok(()) - } - - pub async fn delete_pusher(&self, sender: &UserId, pushkey: &str) { - let key = (sender, pushkey); - self.db.senderkey_pusher.del(key); - self.db.pushkey_deviceid.remove(pushkey); - - self.services - .sending - .cleanup_events(None, Some(sender), Some(pushkey)) - .await - .ok(); - } - - pub async fn get_device_pushkeys( - &self, - sender: &UserId, - device_id: &DeviceId, - ) -> Vec { - self.get_pushkeys(sender) - .map(ToOwned::to_owned) - .broad_filter_map(async |pushkey| { - self.get_pusher_device(&pushkey) - .await - .ok() - .filter(|pusher_device| pusher_device == device_id) - .is_some() - .then_some(pushkey) - }) - .collect() - .await - } - - pub async fn get_pusher_device(&self, pushkey: &str) -> Result { - self.db - .pushkey_deviceid - .get(pushkey) - .await - .deserialized() - } - - pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result { - let senderkey = (sender, pushkey); - self.db - .senderkey_pusher - .qry(&senderkey) - .await - .deserialized() - } - - pub async fn get_pushers(&self, sender: &UserId) -> Vec { - let prefix = (sender, Interfix); - self.db - .senderkey_pusher - .stream_prefix(&prefix) - .ignore_err() - .map(|(_, pusher): (Ignore, Pusher)| pusher) - .collect() - .await - } - - pub fn get_pushkeys<'a>( - &'a self, - sender: &'a UserId, - ) -> impl Stream + Send + 'a { - let prefix = (sender, Interfix); - self.db - .senderkey_pusher - .keys_prefix(&prefix) - .ignore_err() - .map(|(_, pushkey): (Ignore, &str)| pushkey) - } - - #[tracing::instrument(skip(self, dest, request))] - pub async fn send_request(&self, dest: &str, request: T) -> Result - where - T: OutgoingRequest + Debug + Send, - { - const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0]; - let supported = SupportedVersions { - versions: VERSIONS.into(), - features: Default::default(), - }; - - let dest = dest.replace(&self.services.config.notification_push_path, ""); - trace!("Push gateway destination: {dest}"); - - let http_request = request - .try_into_http_request::(&dest, SendAccessToken::IfRequired(""), &supported) - .map_err(|e| { - err!(BadServerResponse(warn!( - "Failed to find destination {dest} for push gateway: {e}" - ))) - })? - .map(BytesMut::freeze); - - let reqwest_request = reqwest::Request::try_from(http_request)?; - - if let Some(url_host) = reqwest_request.url().host_str() { - trace!("Checking request URL for IP"); - if let Ok(ip) = IPAddress::parse(url_host) { - if !self.services.client.valid_cidr_range(&ip) { - return Err!(BadServerResponse("Not allowed to send requests to this IP")); - } - } - } - - let response = self - .services - .client - .pusher - .execute(reqwest_request) - .await; - - match response { - | Ok(mut response) => { - // reqwest::Response -> http::Response conversion - - trace!("Checking response destination's IP"); - if let Some(remote_addr) = response.remote_addr() { - if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - if !self.services.client.valid_cidr_range(&ip) { - return Err!(BadServerResponse( - "Not allowed to send requests to this IP" - )); - } - } - } - - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await?; // TODO: handle timeout - - if !status.is_success() { - debug_warn!("Push gateway response body: {:?}", string_from_bytes(&body)); - return Err!(BadServerResponse(warn!( - "Push gateway {dest} returned unsuccessful HTTP response: {status}" - ))); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|e| { - err!(BadServerResponse(warn!( - "Push gateway {dest} returned invalid response: {e}" - ))) - }) - }, - | Err(e) => { - warn!("Could not send request to pusher {dest}: {e}"); - Err(e.into()) - }, - } - } - - #[tracing::instrument(skip(self, user, unread, pusher, ruleset, event))] - pub async fn send_push_notice( - &self, - user: &UserId, - unread: UInt, - pusher: &Pusher, - ruleset: Ruleset, - event: &E, - ) -> Result - where - E: Event, - { - let mut notify = None; - let mut tweaks = Vec::new(); - - let power_levels: RoomPowerLevels = self - .services - .state_accessor - .get_power_levels(event.room_id()) - .await?; - - let serialized = event.to_format(); - for action in self - .get_actions(user, &ruleset, &power_levels, &serialized, event.room_id()) - .await - { - let n = match action { - | Action::Notify => true, - | Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - }, - | _ => false, - }; - - if notify.is_some() { - return Err!(Database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"# - )); + if pushkey.len() > 512 { + return Err!(Request(InvalidParam( + "Push key length cannot be greater than 512 bytes." + ))); } - notify = Some(n); - } + if data.pusher.ids.app_id.as_str().len() > 64 { + return Err!(Request(InvalidParam( + "App ID length cannot be greater than 64 bytes." + ))); + } - if notify == Some(true) { - self.send_notice(unread, pusher, tweaks, event) - .await?; - } - // Else the event triggered no actions - - Ok(()) - } - - #[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")] - pub async fn get_actions<'a>( - &self, - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevels, - pdu: &Raw, - room_id: &RoomId, - ) -> &'a [Action] { - let power_levels = PushConditionPowerLevelsCtx { - users: power_levels.users.clone(), - users_default: power_levels.users_default, - notifications: power_levels.notifications.clone(), - rules: power_levels.rules.clone(), - }; - - let room_joined_count = self - .services - .state_cache - .room_joined_count(room_id) - .await - .unwrap_or(1) - .try_into() - .unwrap_or_else(|_| uint!(0)); - - let user_display_name = self - .services - .users - .displayname(user) - .await - .unwrap_or_else(|_| user.localpart().to_owned()); - - let ctx = PushConditionRoomCtx { - room_id: room_id.to_owned(), - member_count: room_joined_count, - user_id: user.to_owned(), - user_display_name, - power_levels: Some(power_levels), - }; - - ruleset.get_actions(pdu, &ctx).await - } - - #[tracing::instrument(skip(self, unread, pusher, tweaks, event))] - async fn send_notice( - &self, - unread: UInt, - pusher: &Pusher, - tweaks: Vec, - event: &Pdu, - ) -> Result { - // TODO: email - match &pusher.kind { - | PusherKind::Http(http) => { + // add some validation to the pusher URL + let pusher_kind = &data.pusher.kind; + if let PusherKind::Http(http) = pusher_kind { let url = &http.url; let url = url::Url::parse(&http.url).map_err(|e| { err!(Request(InvalidParam( @@ -430,86 +105,134 @@ impl Service { ))); } } + } - // TODO (timo): can pusher/devices have conflicting formats - let event_id_only = http.format == Some(PushFormat::EventIdOnly); - - let mut device = - Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); - device.data.data.clone_from(&http.data); - device.data.format.clone_from(&http.format); - - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks.clone_from(&tweaks); - } - - let d = vec![device]; - let mut notifi = Notification::new(d); - - notifi.event_id = Some(event.event_id().to_owned()); - notifi.room_id = Some(event.room_id().to_owned()); - if http - .data - .get("org.matrix.msc4076.disable_badge_count") - .is_none() && http.data.get("disable_badge_count").is_none() - { - notifi.counts = NotificationCounts::new(unread, uint!(0)); - } else { - // counts will not be serialised if it's the default (0, 0) - // skip_serializing_if = "NotificationCounts::is_default" - notifi.counts = NotificationCounts::default(); - } - - if !event_id_only { - if *event.kind() == TimelineEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) - { - notifi.prio = NotificationPriority::High; - } else { - notifi.prio = NotificationPriority::Low; - } - notifi.sender = Some(event.sender().to_owned()); - notifi.event_type = Some(event.kind().to_owned()); - notifi.content = serde_json::value::to_raw_value(event.content()).ok(); - - if *event.kind() == TimelineEventType::RoomMember { - notifi.user_is_target = - event.state_key() == Some(event.sender().as_str()); - } - - notifi.sender_display_name = self - .services - .users - .displayname(event.sender()) - .await - .ok(); - - notifi.room_name = self - .services - .state_accessor - .get_name(event.room_id()) - .await - .ok(); - - notifi.room_alias = self - .services - .state_accessor - .get_canonical_alias(event.room_id()) - .await - .ok(); - } - - self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) - .await?; - - Ok(()) - }, - // TODO: Handle email - //PusherKind::Email(_) => Ok(()), - | _ => Ok(()), - } + let pushkey = data.pusher.ids.pushkey.as_str(); + let key = (sender, pushkey); + self.db.senderkey_pusher.put(key, Json(pusher)); + self.db + .pushkey_deviceid + .insert(pushkey, sender_device); + }, + | set_pusher::v3::PusherAction::Delete(ids) => { + self.delete_pusher(sender, ids.pushkey.as_str()) + .await; + }, } + + Ok(()) +} + +#[implement(Service)] +pub async fn delete_pusher(&self, sender: &UserId, pushkey: &str) { + let key = (sender, pushkey); + self.db.senderkey_pusher.del(key); + self.db.pushkey_deviceid.remove(pushkey); + + self.services + .sending + .cleanup_events(None, Some(sender), Some(pushkey)) + .await + .ok(); +} + +#[implement(Service)] +pub async fn get_device_pushkeys(&self, sender: &UserId, device_id: &DeviceId) -> Vec { + self.get_pushkeys(sender) + .map(ToOwned::to_owned) + .broad_filter_map(async |pushkey| { + self.get_pusher_device(&pushkey) + .await + .ok() + .filter(|pusher_device| pusher_device == device_id) + .is_some() + .then_some(pushkey) + }) + .collect() + .await +} + +#[implement(Service)] +pub async fn get_pusher_device(&self, pushkey: &str) -> Result { + self.db + .pushkey_deviceid + .get(pushkey) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result { + let senderkey = (sender, pushkey); + self.db + .senderkey_pusher + .qry(&senderkey) + .await + .deserialized() +} + +#[implement(Service)] +pub async fn get_pushers(&self, sender: &UserId) -> Vec { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .stream_prefix(&prefix) + .ignore_err() + .map(|(_, pusher): (Ignore, Pusher)| pusher) + .collect() + .await +} + +#[implement(Service)] +pub fn get_pushkeys<'a>(&'a self, sender: &'a UserId) -> impl Stream + Send + 'a { + let prefix = (sender, Interfix); + self.db + .senderkey_pusher + .keys_prefix(&prefix) + .ignore_err() + .map(|(_, pushkey): (Ignore, &str)| pushkey) +} + +#[implement(Service)] +#[tracing::instrument(level = "debug", skip_all)] +pub async fn get_actions<'a>( + &self, + user: &UserId, + ruleset: &'a Ruleset, + power_levels: &RoomPowerLevels, + pdu: &Raw, + room_id: &RoomId, +) -> &'a [Action] { + let power_levels = PushConditionPowerLevelsCtx { + users: power_levels.users.clone(), + users_default: power_levels.users_default, + notifications: power_levels.notifications.clone(), + rules: power_levels.rules.clone(), + }; + + let room_joined_count = self + .services + .state_cache + .room_joined_count(room_id) + .await + .unwrap_or(1) + .try_into() + .unwrap_or_else(|_| uint!(0)); + + let user_display_name = self + .services + .users + .displayname(user) + .await + .unwrap_or_else(|_| user.localpart().to_owned()); + + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: room_joined_count, + user_id: user.to_owned(), + user_display_name, + power_levels: Some(power_levels), + }; + + ruleset.get_actions(pdu, &ctx).await } diff --git a/src/service/rooms/user/mod.rs b/src/service/pusher/notification.rs similarity index 66% rename from src/service/rooms/user/mod.rs rename to src/service/pusher/notification.rs index daf9e65c..5cbb7c93 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/pusher/notification.rs @@ -1,40 +1,11 @@ -use std::sync::Arc; - use ruma::{RoomId, UserId}; use tuwunel_core::{ Result, implement, trace, utils::stream::{ReadyExt, TryIgnore}, }; -use tuwunel_database::{Deserialized, Interfix, Map}; +use tuwunel_database::{Deserialized, Interfix}; -pub struct Service { - db: Data, - services: Arc, -} - -struct Data { - userroomid_notificationcount: Arc, - userroomid_highlightcount: Arc, - roomuserid_lastnotificationread: Arc, -} - -impl crate::Service for Service { - fn build(args: &crate::Args<'_>) -> Result> { - Ok(Arc::new(Self { - db: Data { - userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(), - userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(), - roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"] - .clone(), - }, - services: args.services.clone(), - })) - } - - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } -} - -#[implement(Service)] +#[implement(super::Service)] #[tracing::instrument(level = "debug", skip(self))] pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { let count = self.services.globals.next_count(); @@ -53,7 +24,7 @@ pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) { .put(roomuser_id, *count); } -#[implement(Service)] +#[implement(super::Service)] #[tracing::instrument(level = "debug", skip(self), ret(level = "trace"))] pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { let key = (user_id, room_id); @@ -65,7 +36,7 @@ pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u6 .unwrap_or(0) } -#[implement(Service)] +#[implement(super::Service)] #[tracing::instrument(level = "debug", skip(self), ret(level = "trace"))] pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { let key = (user_id, room_id); @@ -77,7 +48,7 @@ pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 { .unwrap_or(0) } -#[implement(Service)] +#[implement(super::Service)] #[tracing::instrument(level = "debug", skip(self), ret(level = "trace"))] pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 { let key = (room_id, user_id); @@ -89,7 +60,7 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) - .unwrap_or(0) } -#[implement(Service)] +#[implement(super::Service)] pub async fn delete_room_notification_read(&self, room_id: &RoomId) -> Result { let key = (room_id, Interfix); self.db diff --git a/src/service/pusher/request.rs b/src/service/pusher/request.rs new file mode 100644 index 00000000..d2a7cce1 --- /dev/null +++ b/src/service/pusher/request.rs @@ -0,0 +1,106 @@ +use std::{fmt::Debug, mem}; + +use bytes::BytesMut; +use ipaddress::IPAddress; +use ruma::api::{ + IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, SupportedVersions, +}; +use tuwunel_core::{ + Err, Result, debug_warn, err, implement, trace, utils::string_from_bytes, warn, +}; + +#[implement(super::Service)] +#[tracing::instrument(level = "debug", skip_all)] +pub(super) async fn send_request(&self, dest: &str, request: T) -> Result +where + T: OutgoingRequest + Debug + Send, +{ + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0]; + let supported = SupportedVersions { + versions: VERSIONS.into(), + features: Default::default(), + }; + + let dest = dest.replace(&self.services.config.notification_push_path, ""); + trace!("Push gateway destination: {dest}"); + + let http_request = request + .try_into_http_request::(&dest, SendAccessToken::IfRequired(""), &supported) + .map_err(|e| { + err!(BadServerResponse(warn!( + "Failed to find destination {dest} for push gateway: {e}" + ))) + })? + .map(BytesMut::freeze); + + let reqwest_request = reqwest::Request::try_from(http_request)?; + if let Some(url_host) = reqwest_request.url().host_str() { + trace!("Checking request URL for IP"); + if let Ok(ip) = IPAddress::parse(url_host) { + if !self.services.client.valid_cidr_range(&ip) { + return Err!(BadServerResponse("Not allowed to send requests to this IP")); + } + } + } + + let response = self + .services + .client + .pusher + .execute(reqwest_request) + .await; + + match response { + | Ok(mut response) => { + // reqwest::Response -> http::Response conversion + + trace!("Checking response destination's IP"); + if let Some(remote_addr) = response.remote_addr() { + if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { + if !self.services.client.valid_cidr_range(&ip) { + return Err!(BadServerResponse( + "Not allowed to send requests to this IP" + )); + } + } + } + + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await?; // TODO: handle timeout + + if !status.is_success() { + debug_warn!("Push gateway response body: {:?}", string_from_bytes(&body)); + return Err!(BadServerResponse(warn!( + "Push gateway {dest} returned unsuccessful HTTP response: {status}" + ))); + } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + + response.map_err(|e| { + err!(BadServerResponse(warn!( + "Push gateway {dest} returned invalid response: {e}" + ))) + }) + }, + | Err(e) => { + warn!("Could not send request to pusher {dest}: {e}"); + Err(e.into()) + }, + } +} diff --git a/src/service/pusher/send.rs b/src/service/pusher/send.rs new file mode 100644 index 00000000..f746c2df --- /dev/null +++ b/src/service/pusher/send.rs @@ -0,0 +1,191 @@ +use ipaddress::IPAddress; +use ruma::{ + UInt, UserId, + api::{ + client::push::{Pusher, PusherKind}, + push_gateway::send_event_notification::{ + self, + v1::{Device, Notification, NotificationCounts, NotificationPriority}, + }, + }, + events::{TimelineEventType, room::power_levels::RoomPowerLevels}, + push::{Action, PushFormat, Ruleset, Tweak}, + uint, +}; +use tuwunel_core::{Err, Result, err, implement, matrix::Event}; + +#[implement(super::Service)] +#[tracing::instrument(level = "debug", skip_all)] +pub async fn send_push_notice( + &self, + user_id: &UserId, + pusher: &Pusher, + ruleset: &Ruleset, + event: &E, +) -> Result +where + E: Event, +{ + let mut notify = None; + let mut tweaks = Vec::new(); + + let unread: UInt = self + .services + .pusher + .notification_count(user_id, event.room_id()) + .await + .try_into()?; + + let power_levels: RoomPowerLevels = self + .services + .state_accessor + .get_power_levels(event.room_id()) + .await?; + + let serialized = event.to_format(); + for action in self + .get_actions(user_id, ruleset, &power_levels, &serialized, event.room_id()) + .await + { + let n = match action { + | Action::Notify => true, + | Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + }, + | _ => false, + }; + + if notify.is_some() { + return Err!(Database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"# + )); + } + + notify = Some(n); + } + + if notify == Some(true) { + self.send_notice(unread, pusher, tweaks, event) + .await?; + } + // Else the event triggered no actions + + Ok(()) +} + +#[implement(super::Service)] +#[tracing::instrument(level = "debug", skip_all)] +async fn send_notice( + &self, + unread: UInt, + pusher: &Pusher, + tweaks: Vec, + event: &Pdu, +) -> Result { + // TODO: email + match &pusher.kind { + | PusherKind::Http(http) => { + let url = &http.url; + let url = url::Url::parse(&http.url).map_err(|e| { + err!(Request(InvalidParam( + warn!(%url, "HTTP pusher URL is not a valid URL: {e}") + ))) + })?; + + if ["http", "https"] + .iter() + .all(|&scheme| scheme != url.scheme().to_lowercase()) + { + return Err!(Request(InvalidParam( + warn!(%url, "HTTP pusher URL is not a valid HTTP/HTTPS URL") + ))); + } + + if let Ok(ip) = IPAddress::parse(url.host_str().expect("URL previously validated")) { + if !self.services.client.valid_cidr_range(&ip) { + return Err!(Request(InvalidParam( + warn!(%url, "HTTP pusher URL is a forbidden remote address") + ))); + } + } + + // TODO (timo): can pusher/devices have conflicting formats + let event_id_only = http.format == Some(PushFormat::EventIdOnly); + + let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); + device.data.data.clone_from(&http.data); + device.data.format.clone_from(&http.format); + + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks.clone_from(&tweaks); + } + + let d = vec![device]; + let mut notifi = Notification::new(d); + + notifi.event_id = Some(event.event_id().to_owned()); + notifi.room_id = Some(event.room_id().to_owned()); + if http + .data + .get("org.matrix.msc4076.disable_badge_count") + .is_none() && http.data.get("disable_badge_count").is_none() + { + notifi.counts = NotificationCounts::new(unread, uint!(0)); + } else { + // counts will not be serialised if it's the default (0, 0) + // skip_serializing_if = "NotificationCounts::is_default" + notifi.counts = NotificationCounts::default(); + } + + if !event_id_only { + if *event.kind() == TimelineEventType::RoomEncrypted + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High; + } else { + notifi.prio = NotificationPriority::Low; + } + notifi.sender = Some(event.sender().to_owned()); + notifi.event_type = Some(event.kind().to_owned()); + notifi.content = serde_json::value::to_raw_value(event.content()).ok(); + + if *event.kind() == TimelineEventType::RoomMember { + notifi.user_is_target = event.state_key() == Some(event.sender().as_str()); + } + + notifi.sender_display_name = self + .services + .users + .displayname(event.sender()) + .await + .ok(); + + notifi.room_name = self + .services + .state_accessor + .get_name(event.room_id()) + .await + .ok(); + + notifi.room_alias = self + .services + .state_accessor + .get_canonical_alias(event.room_id()) + .await + .ok(); + } + + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) + .await?; + + Ok(()) + }, + // TODO: Handle email + //PusherKind::Email(_) => Ok(()), + | _ => Ok(()), + } +} diff --git a/src/service/rooms/delete/mod.rs b/src/service/rooms/delete/mod.rs index 86803bc1..c0f50b3a 100644 --- a/src/service/rooms/delete/mod.rs +++ b/src/service/rooms/delete/mod.rs @@ -159,7 +159,7 @@ impl Service { debug!("Deleting the room's last notifications read."); self.services - .user + .pusher .delete_room_notification_read(room_id) .await .log_err() diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index 1f8c5029..0fc2e1d9 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -17,4 +17,3 @@ pub mod state_compressor; pub mod threads; pub mod timeline; pub mod typing; -pub mod user; diff --git a/src/service/rooms/timeline/append.rs b/src/service/rooms/timeline/append.rs index cba6a3ac..0e4e9d4a 100644 --- a/src/service/rooms/timeline/append.rs +++ b/src/service/rooms/timeline/append.rs @@ -168,7 +168,7 @@ where .private_read_set(pdu.room_id(), pdu.sender(), *next_count2); self.services - .user + .pusher .reset_notification_counts(pdu.sender(), pdu.room_id()); let count = PduCount::Normal(*next_count1); diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 6e6ce81f..480b2f1a 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -10,14 +10,14 @@ use std::{ use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; use futures::{ - FutureExt, StreamExt, - future::{BoxFuture, OptionFuture, join3}, + FutureExt, StreamExt, TryFutureExt, + future::{BoxFuture, OptionFuture, join3, try_join3}, pin_mut, stream::FuturesUnordered, }; use ruma::{ MilliSecondsSinceUnixEpoch, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, - UInt, + UserId, api::{ appservice::event::push_events::v1::EphemeralData, federation::transactions::{ @@ -39,7 +39,7 @@ use ruma::{ uint, }; use tuwunel_core::{ - Error, Event, Result, debug, err, error, + Error, Event, Result, debug, err, error, extract_variant, result::LogErr, trace, utils::{ @@ -792,114 +792,103 @@ impl Service { pushkey: String, events: Vec, ) -> SendingResult { - let Ok(pusher) = self + let suppressed = self.pushing_suppressed(&user_id).map(Ok); + + let pusher = self .services .pusher .get_pusher(&user_id, &pushkey) - .await - else { - return Err(( - Destination::Push(user_id.clone(), pushkey.clone()), - err!(Database(error!(?user_id, ?pushkey, "Missing pusher"))), - )); - }; + .map_err(|_| { + ( + Destination::Push(user_id.clone(), pushkey.clone()), + err!(Database(error!(?user_id, ?pushkey, "Missing pusher"))), + ) + }); - let mut pdus = Vec::with_capacity( - events - .iter() - .filter(|event| matches!(event, SendingEvent::Pdu(_))) - .count(), - ); - for event in &events { - match event { - | SendingEvent::Pdu(pdu_id) => { - if let Ok(pdu) = self - .services - .timeline - .get_pdu_from_id(pdu_id) - .await - { - pdus.push(pdu); - } - }, - | SendingEvent::Edu(_) | SendingEvent::Flush => { - // Push gateways don't need EDUs (?) and flush only; - // no new content - }, - } - } - - for pdu in pdus { - // Redacted events are not notification targets (we don't send push for them) - if pdu.contains_unsigned_property("redacted_because", serde_json::Value::is_string) { - continue; - } - - // optional suppression: heuristic combining presence age and recent sync - // activity. - if self.services.config.suppress_push_when_active - && let Ok(presence) = self - .services - .presence - .get_presence(&user_id) - .await - { - let is_online = presence.content.presence == PresenceState::Online; - - let presence_age_ms = presence - .content - .last_active_ago - .map(u64::from) - .unwrap_or(u64::MAX); - - let sync_gap_ms = self - .services - .presence - .last_sync_gap_ms(&user_id) - .await; - - let considered_active = is_online - && presence_age_ms < 65_000 - && sync_gap_ms.is_some_and(|gap| gap < 32_000); - - if considered_active { - trace!( - ?user_id, - presence_age_ms, sync_gap_ms, "suppressing push: active heuristic" - ); - continue; - } - } - - let rules_for_user = self - .services - .account_data - .get_global(&user_id, GlobalAccountDataEventType::PushRules) - .await - .map_or_else( + let rules_for_user = self + .services + .account_data + .get_global(&user_id, GlobalAccountDataEventType::PushRules) + .map(|ev| { + ev.map_or_else( |_| push::Ruleset::server_default(&user_id), |ev: PushRulesEvent| ev.content.global, - ); + ) + }) + .map(Ok); - let unread: UInt = self - .services - .user - .notification_count(&user_id, pdu.room_id()) - .await - .try_into() - .expect("notification count can't go that high"); + let (pusher, rules_for_user, suppressed) = + try_join3(pusher, rules_for_user, suppressed).await?; - let _response = self - .services - .pusher - .send_push_notice(&user_id, unread, &pusher, rules_for_user, &pdu) - .await - .map_err(|e| (Destination::Push(user_id.clone(), pushkey.clone()), e)); + if suppressed { + return Ok(Destination::Push(user_id, pushkey)); } + let _sent = events + .iter() + .stream() + .ready_filter_map(|event| extract_variant!(event, SendingEvent::Pdu)) + .wide_filter_map(|pdu_id| { + self.services + .timeline + .get_pdu_from_id(pdu_id) + .ok() + }) + .ready_filter(|pdu| !pdu.is_redacted()) + .wide_filter_map(async |pdu| { + self.services + .pusher + .send_push_notice(&user_id, &pusher, &rules_for_user, &pdu) + .await + .map_err(|e| (Destination::Push(user_id.clone(), pushkey.clone()), e)) + .ok() + }) + .count() + .await; + Ok(Destination::Push(user_id, pushkey)) } + // optional suppression: heuristic combining presence age and recent sync + // activity. + async fn pushing_suppressed(&self, user_id: &UserId) -> bool { + if !self.services.config.suppress_push_when_active { + return false; + } + + let Ok(presence) = self.services.presence.get_presence(user_id).await else { + return false; + }; + + if presence.content.presence != PresenceState::Online { + return false; + } + + let presence_age_ms = presence + .content + .last_active_ago + .map(u64::from) + .unwrap_or(u64::MAX); + + if presence_age_ms >= 65_000 { + return false; + } + + let sync_gap_ms = self + .services + .presence + .last_sync_gap_ms(user_id) + .await; + + let considered_active = sync_gap_ms.is_some_and(|gap| gap < 32_000); + + if considered_active { + trace!(?user_id, presence_age_ms, "suppressing push: active heuristic"); + } + + considered_active + } + async fn send_events_dest_federation( &self, server: OwnedServerName, diff --git a/src/service/services.rs b/src/service/services.rs index aa1eaa97..961bed21 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -49,7 +49,6 @@ pub struct Services { pub threads: Arc, pub timeline: Arc, pub typing: Arc, - pub user: Arc, pub federation: Arc, pub sending: Arc, pub server_keys: Arc, @@ -107,7 +106,6 @@ pub async fn build(server: Arc) -> Result> { threads: rooms::threads::Service::build(&args)?, timeline: rooms::timeline::Service::build(&args)?, typing: rooms::typing::Service::build(&args)?, - user: rooms::user::Service::build(&args)?, federation: federation::Service::build(&args)?, sending: sending::Service::build(&args)?, server_keys: server_keys::Service::build(&args)?, @@ -166,7 +164,6 @@ pub(crate) fn services(&self) -> impl Iterator> + Send { cast!(self.threads), cast!(self.timeline), cast!(self.typing), - cast!(self.user), cast!(self.federation), cast!(self.sending), cast!(self.server_keys),