diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 5ae11846..b850c0a1 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,5 +1,5 @@ use axum::extract::State; -use futures::{Stream, StreamExt, TryFutureExt, future::try_join}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt, future::try_join3}; use ruma::{ EventId, RoomId, UInt, UserId, api::{ @@ -12,12 +12,14 @@ use ruma::{ events::{TimelineEventType, relation::RelationType}, }; use tuwunel_core::{ - Result, at, + Err, Error, Result, at, err, matrix::{ + ShortRoomId, event::{Event, RelationTypeEqual}, - pdu::{Pdu, PduCount}, + pdu::{Pdu, PduCount, PduId}, }, utils::{ + BoolExt, result::FlatOk, stream::{IterStream, ReadyExt, WidebandExt}, }, @@ -122,55 +124,92 @@ async fn paginate_relations_with_filter( recurse: bool, dir: Direction, ) -> Result { - let start: PduCount = from - .map(str::parse) - .transpose()? - .unwrap_or_else(|| match dir { - | Direction::Forward => PduCount::min(), - | Direction::Backward => PduCount::max(), - }); + let from: Option = from.map(str::parse).transpose()?; let to: Option = to.map(str::parse).flat_ok(); - // Use limit or else 30, with maximum 100 + // Spec (v1.10) recommends depth of at least 3 + let depth: u8 = if recurse { 3 } else { 1 }; + let limit: usize = limit .map(TryInto::try_into) .flat_ok() .unwrap_or(30) .min(100); - // Spec (v1.10) recommends depth of at least 3 - let depth: u8 = if recurse { 3 } else { 1 }; + let shortroomid = services.short.get_shortroomid(room_id); - let events: Vec<_> = - get_relations(services, sender_user, room_id, target, start, limit, depth, dir) - .await - .ready_filter(|(_, pdu)| { - filter_event_type - .as_ref() - .is_none_or(|kind| kind == pdu.kind()) - }) - .ready_filter(|(_, pdu)| { - filter_rel_type - .as_ref() - .is_none_or(|rel_type| rel_type.relation_type_equal(pdu)) - }) - .ready_take_while(|(count, _)| Some(*count) != to) - .wide_filter_map(|item| visibility_filter(services, sender_user, item)) - .take(limit) - .collect() - .await; + let target = services + .timeline + .get_pdu_id(target) + .map_ok(PduId::from) + .map_ok(Ok::<_, Error>); - let next_batch = events - .last() - .map(at!(0)) - .as_ref() - .map(ToString::to_string); + let visible = services + .state_accessor + .user_can_see_state_events(sender_user, room_id) + .map(|visible| { + visible.ok_or_else(|| err!(Request(Forbidden("You cannot view this room.")))) + }); + + let (shortroomid, target, ()) = try_join3(shortroomid, target, visible).await?; + + let Ok(target) = target else { + return Ok(get_relating_events::v1::Response::new(Vec::new())); + }; + + if shortroomid != target.shortroomid { + return Err!(Request(NotFound("Event not found in room."))); + } + + //TODO: support backfilled relations + if let PduCount::Backfilled(_) = target.count { + return Ok(get_relating_events::v1::Response::new(Vec::new())); + } + + let events: Vec<_> = get_relations( + services, + sender_user, + target.shortroomid, + target.count, + from, + limit, + depth, + dir, + ) + .await //TODO: XXX + .ready_take_while(|(count, _)| Some(*count) != to) + .ready_filter(|(_, pdu)| { + filter_event_type + .as_ref() + .is_none_or(|kind| kind == pdu.kind()) + }) + .ready_filter(|(_, pdu)| { + filter_rel_type + .as_ref() + .is_none_or(|rel_type| rel_type.relation_type_equal(pdu)) + }) + .wide_filter_map(|item| visibility_filter(services, sender_user, item)) + .take(limit) + .collect() + .await; Ok(get_relating_events::v1::Response { - next_batch, - prev_batch: from.map(Into::into), recursion_depth: recurse.then_some(depth.into()), + + next_batch: events + .last() + .map(at!(0)) + .as_ref() + .map(ToString::to_string), + + prev_batch: events + .first() + .map(at!(0)) + .or(from) + .as_ref() + .map(ToString::to_string), + chunk: events .into_iter() .map(at!(1)) @@ -183,35 +222,17 @@ async fn paginate_relations_with_filter( async fn get_relations( services: &Services, sender_user: &UserId, - room_id: &RoomId, - target: &EventId, - from: PduCount, + shortroomid: ShortRoomId, + target: PduCount, + from: Option, limit: usize, max_depth: u8, dir: Direction, ) -> impl Stream + Send { - let room_id = services.short.get_shortroomid(room_id); - - let target = services - .timeline - .get_pdu_count(target) - .map_ok(|target| { - match target { - | PduCount::Normal(count) => count, - | _ => { - // TODO: Support backfilled relations - 0 // This will result in an empty iterator - }, - } - }); - - let Ok((room_id, target)) = try_join(room_id, target).await else { - return Vec::new().into_iter().stream(); - }; - let mut pdus: Vec<_> = services .pdu_metadata - .get_relations(room_id, target.into(), from, dir, Some(sender_user)) + .get_relations(shortroomid, target, from, dir, Some(sender_user)) + .take(limit) .collect() .await; @@ -221,28 +242,25 @@ async fn get_relations( .map(|(count, _)| (*count, 1)) .collect(); - 'limit: while let Some((count, depth)) = stack.pop() { - let target = match count { - | PduCount::Normal(count) => count, - | PduCount::Backfilled(_) => { - // TODO: Support backfilled relations - 0 // This will result in an empty iterator - }, + 'limit: while let Some((target, depth)) = stack.pop() { + let PduCount::Normal(target) = target else { + continue; }; let relations: Vec<_> = services .pdu_metadata - .get_relations(room_id, target.into(), from, dir, Some(sender_user)) + .get_relations(shortroomid, target.into(), from, dir, Some(sender_user)) + .take(limit.saturating_sub(pdus.len())) .collect() .await; - for (count, pdu) in relations { + for (target, pdu) in relations { if depth < max_depth { - stack.push((count, depth.saturating_add(1))); + stack.push((target, depth.saturating_add(1))); } if pdus.len() < limit { - pdus.push((count, pdu)); + pdus.push((target, pdu)); } else { break 'limit; } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index c662e940..e1aeca3d 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -65,12 +65,19 @@ pub fn get_relations<'a>( &'a self, shortroomid: ShortRoomId, target: PduCount, - from: PduCount, + from: Option, dir: Direction, user_id: Option<&'a UserId>, ) -> impl Stream + Send + '_ { let target = target.to_be_bytes(); - let from = from.saturating_inc(dir).to_be_bytes(); + let from = from + .map(|from| from.saturating_inc(dir)) + .unwrap_or_else(|| match dir { + | Direction::Backward => PduCount::max(), + | Direction::Forward => PduCount::default(), + }) + .to_be_bytes(); + let mut buf = ArrayVec::::new(); let start = { buf.extend(target); @@ -79,16 +86,16 @@ pub fn get_relations<'a>( }; match dir { - | Direction::Forward => self - .db - .tofrom_relation - .raw_keys_from(start) - .boxed(), | Direction::Backward => self .db .tofrom_relation .rev_raw_keys_from(start) .boxed(), + | Direction::Forward => self + .db + .tofrom_relation + .raw_keys_from(start) + .boxed(), } .ignore_err() .ready_take_while(move |key| key.starts_with(&target)) diff --git a/tests/complement/results.jsonl b/tests/complement/results.jsonl index 62a1fb5d..262b9886 100644 --- a/tests/complement/results.jsonl +++ b/tests/complement/results.jsonl @@ -508,7 +508,7 @@ {"Action":"pass","Test":"TestRegistration/parallel/POST_{}_returns_a_set_of_flows"} {"Action":"pass","Test":"TestRegistration/parallel/Registration_accepts_non-ascii_passwords"} {"Action":"pass","Test":"TestRelations"} -{"Action":"fail","Test":"TestRelationsPagination"} +{"Action":"pass","Test":"TestRelationsPagination"} {"Action":"pass","Test":"TestRelationsPaginationSync"} {"Action":"pass","Test":"TestRemoteAliasRequestsUnderstandUnicode"} {"Action":"pass","Test":"TestRemotePngThumbnail"}