Fix relations pagination compliance.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2025-11-18 22:06:11 +00:00
parent f4eeaaf167
commit 5147b541b5
3 changed files with 105 additions and 80 deletions

View File

@@ -1,5 +1,5 @@
use axum::extract::State; use axum::extract::State;
use futures::{Stream, StreamExt, TryFutureExt, future::try_join}; use futures::{FutureExt, Stream, StreamExt, TryFutureExt, future::try_join3};
use ruma::{ use ruma::{
EventId, RoomId, UInt, UserId, EventId, RoomId, UInt, UserId,
api::{ api::{
@@ -12,12 +12,14 @@ use ruma::{
events::{TimelineEventType, relation::RelationType}, events::{TimelineEventType, relation::RelationType},
}; };
use tuwunel_core::{ use tuwunel_core::{
Result, at, Err, Error, Result, at, err,
matrix::{ matrix::{
ShortRoomId,
event::{Event, RelationTypeEqual}, event::{Event, RelationTypeEqual},
pdu::{Pdu, PduCount}, pdu::{Pdu, PduCount, PduId},
}, },
utils::{ utils::{
BoolExt,
result::FlatOk, result::FlatOk,
stream::{IterStream, ReadyExt, WidebandExt}, stream::{IterStream, ReadyExt, WidebandExt},
}, },
@@ -122,55 +124,92 @@ async fn paginate_relations_with_filter(
recurse: bool, recurse: bool,
dir: Direction, dir: Direction,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let start: PduCount = from let from: Option<PduCount> = from.map(str::parse).transpose()?;
.map(str::parse)
.transpose()?
.unwrap_or_else(|| match dir {
| Direction::Forward => PduCount::min(),
| Direction::Backward => PduCount::max(),
});
let to: Option<PduCount> = to.map(str::parse).flat_ok(); let to: Option<PduCount> = to.map(str::parse).flat_ok();
// Use limit or else 30, with maximum 100 // Spec (v1.10) recommends depth of at least 3
let depth: u8 = if recurse { 3 } else { 1 };
let limit: usize = limit let limit: usize = limit
.map(TryInto::try_into) .map(TryInto::try_into)
.flat_ok() .flat_ok()
.unwrap_or(30) .unwrap_or(30)
.min(100); .min(100);
// Spec (v1.10) recommends depth of at least 3 let shortroomid = services.short.get_shortroomid(room_id);
let depth: u8 = if recurse { 3 } else { 1 };
let events: Vec<_> = let target = services
get_relations(services, sender_user, room_id, target, start, limit, depth, dir) .timeline
.await .get_pdu_id(target)
.ready_filter(|(_, pdu)| { .map_ok(PduId::from)
filter_event_type .map_ok(Ok::<_, Error>);
.as_ref()
.is_none_or(|kind| kind == pdu.kind())
})
.ready_filter(|(_, pdu)| {
filter_rel_type
.as_ref()
.is_none_or(|rel_type| rel_type.relation_type_equal(pdu))
})
.ready_take_while(|(count, _)| Some(*count) != to)
.wide_filter_map(|item| visibility_filter(services, sender_user, item))
.take(limit)
.collect()
.await;
let next_batch = events let visible = services
.last() .state_accessor
.map(at!(0)) .user_can_see_state_events(sender_user, room_id)
.as_ref() .map(|visible| {
.map(ToString::to_string); visible.ok_or_else(|| err!(Request(Forbidden("You cannot view this room."))))
});
let (shortroomid, target, ()) = try_join3(shortroomid, target, visible).await?;
let Ok(target) = target else {
return Ok(get_relating_events::v1::Response::new(Vec::new()));
};
if shortroomid != target.shortroomid {
return Err!(Request(NotFound("Event not found in room.")));
}
//TODO: support backfilled relations
if let PduCount::Backfilled(_) = target.count {
return Ok(get_relating_events::v1::Response::new(Vec::new()));
}
let events: Vec<_> = get_relations(
services,
sender_user,
target.shortroomid,
target.count,
from,
limit,
depth,
dir,
)
.await //TODO: XXX
.ready_take_while(|(count, _)| Some(*count) != to)
.ready_filter(|(_, pdu)| {
filter_event_type
.as_ref()
.is_none_or(|kind| kind == pdu.kind())
})
.ready_filter(|(_, pdu)| {
filter_rel_type
.as_ref()
.is_none_or(|rel_type| rel_type.relation_type_equal(pdu))
})
.wide_filter_map(|item| visibility_filter(services, sender_user, item))
.take(limit)
.collect()
.await;
Ok(get_relating_events::v1::Response { Ok(get_relating_events::v1::Response {
next_batch,
prev_batch: from.map(Into::into),
recursion_depth: recurse.then_some(depth.into()), recursion_depth: recurse.then_some(depth.into()),
next_batch: events
.last()
.map(at!(0))
.as_ref()
.map(ToString::to_string),
prev_batch: events
.first()
.map(at!(0))
.or(from)
.as_ref()
.map(ToString::to_string),
chunk: events chunk: events
.into_iter() .into_iter()
.map(at!(1)) .map(at!(1))
@@ -183,35 +222,17 @@ async fn paginate_relations_with_filter(
async fn get_relations( async fn get_relations(
services: &Services, services: &Services,
sender_user: &UserId, sender_user: &UserId,
room_id: &RoomId, shortroomid: ShortRoomId,
target: &EventId, target: PduCount,
from: PduCount, from: Option<PduCount>,
limit: usize, limit: usize,
max_depth: u8, max_depth: u8,
dir: Direction, dir: Direction,
) -> impl Stream<Item = (PduCount, Pdu)> + Send { ) -> impl Stream<Item = (PduCount, Pdu)> + Send {
let room_id = services.short.get_shortroomid(room_id);
let target = services
.timeline
.get_pdu_count(target)
.map_ok(|target| {
match target {
| PduCount::Normal(count) => count,
| _ => {
// TODO: Support backfilled relations
0 // This will result in an empty iterator
},
}
});
let Ok((room_id, target)) = try_join(room_id, target).await else {
return Vec::new().into_iter().stream();
};
let mut pdus: Vec<_> = services let mut pdus: Vec<_> = services
.pdu_metadata .pdu_metadata
.get_relations(room_id, target.into(), from, dir, Some(sender_user)) .get_relations(shortroomid, target, from, dir, Some(sender_user))
.take(limit)
.collect() .collect()
.await; .await;
@@ -221,28 +242,25 @@ async fn get_relations(
.map(|(count, _)| (*count, 1)) .map(|(count, _)| (*count, 1))
.collect(); .collect();
'limit: while let Some((count, depth)) = stack.pop() { 'limit: while let Some((target, depth)) = stack.pop() {
let target = match count { let PduCount::Normal(target) = target else {
| PduCount::Normal(count) => count, continue;
| PduCount::Backfilled(_) => {
// TODO: Support backfilled relations
0 // This will result in an empty iterator
},
}; };
let relations: Vec<_> = services let relations: Vec<_> = services
.pdu_metadata .pdu_metadata
.get_relations(room_id, target.into(), from, dir, Some(sender_user)) .get_relations(shortroomid, target.into(), from, dir, Some(sender_user))
.take(limit.saturating_sub(pdus.len()))
.collect() .collect()
.await; .await;
for (count, pdu) in relations { for (target, pdu) in relations {
if depth < max_depth { if depth < max_depth {
stack.push((count, depth.saturating_add(1))); stack.push((target, depth.saturating_add(1)));
} }
if pdus.len() < limit { if pdus.len() < limit {
pdus.push((count, pdu)); pdus.push((target, pdu));
} else { } else {
break 'limit; break 'limit;
} }

View File

@@ -65,12 +65,19 @@ pub fn get_relations<'a>(
&'a self, &'a self,
shortroomid: ShortRoomId, shortroomid: ShortRoomId,
target: PduCount, target: PduCount,
from: PduCount, from: Option<PduCount>,
dir: Direction, dir: Direction,
user_id: Option<&'a UserId>, user_id: Option<&'a UserId>,
) -> impl Stream<Item = (PduCount, Pdu)> + Send + '_ { ) -> impl Stream<Item = (PduCount, Pdu)> + Send + '_ {
let target = target.to_be_bytes(); let target = target.to_be_bytes();
let from = from.saturating_inc(dir).to_be_bytes(); let from = from
.map(|from| from.saturating_inc(dir))
.unwrap_or_else(|| match dir {
| Direction::Backward => PduCount::max(),
| Direction::Forward => PduCount::default(),
})
.to_be_bytes();
let mut buf = ArrayVec::<u8, 16>::new(); let mut buf = ArrayVec::<u8, 16>::new();
let start = { let start = {
buf.extend(target); buf.extend(target);
@@ -79,16 +86,16 @@ pub fn get_relations<'a>(
}; };
match dir { match dir {
| Direction::Forward => self
.db
.tofrom_relation
.raw_keys_from(start)
.boxed(),
| Direction::Backward => self | Direction::Backward => self
.db .db
.tofrom_relation .tofrom_relation
.rev_raw_keys_from(start) .rev_raw_keys_from(start)
.boxed(), .boxed(),
| Direction::Forward => self
.db
.tofrom_relation
.raw_keys_from(start)
.boxed(),
} }
.ignore_err() .ignore_err()
.ready_take_while(move |key| key.starts_with(&target)) .ready_take_while(move |key| key.starts_with(&target))

View File

@@ -508,7 +508,7 @@
{"Action":"pass","Test":"TestRegistration/parallel/POST_{}_returns_a_set_of_flows"} {"Action":"pass","Test":"TestRegistration/parallel/POST_{}_returns_a_set_of_flows"}
{"Action":"pass","Test":"TestRegistration/parallel/Registration_accepts_non-ascii_passwords"} {"Action":"pass","Test":"TestRegistration/parallel/Registration_accepts_non-ascii_passwords"}
{"Action":"pass","Test":"TestRelations"} {"Action":"pass","Test":"TestRelations"}
{"Action":"fail","Test":"TestRelationsPagination"} {"Action":"pass","Test":"TestRelationsPagination"}
{"Action":"pass","Test":"TestRelationsPaginationSync"} {"Action":"pass","Test":"TestRelationsPaginationSync"}
{"Action":"pass","Test":"TestRemoteAliasRequestsUnderstandUnicode"} {"Action":"pass","Test":"TestRemoteAliasRequestsUnderstandUnicode"}
{"Action":"pass","Test":"TestRemotePngThumbnail"} {"Action":"pass","Test":"TestRemotePngThumbnail"}