Files
tuwunel/src/core/utils/set.rs
2026-03-03 06:12:09 +00:00

118 lines
2.8 KiB
Rust

use std::{
cmp::{Eq, Ord},
convert::identity,
pin::Pin,
sync::Arc,
};
use futures::{
Stream, StreamExt,
stream::{Peekable, unfold},
};
use tokio::sync::Mutex;
use crate::{is_equal_to, is_less_than, utils::stream::ReadyExt};
/// Intersection of sets
///
/// Yields every element of the first input set that is also present in each of
/// the remaining input sets. Inputs do not have to be sorted. Each candidate is
/// checked with a linear scan over a fresh clone of every other set. If inputs
/// are sorted, a more optimized function is available in this suite and should
/// be used.
///
/// Duplicates in the first set are yielded once per occurrence, as long as the
/// value appears somewhere in every other set.
pub fn intersection<Item, Iter, Iters>(mut input: Iters) -> impl Iterator<Item = Item> + Send
where
	Iters: Iterator<Item = Iter> + Clone + Send,
	Iter: Iterator<Item = Item> + Send,
	Item: Eq,
{
	let first = input.next();
	first.into_iter().flat_map(move |candidates| {
		let rest = input.clone();
		candidates.filter(move |candidate| {
			// A candidate survives only when every other set contains it.
			let contained_in = |mut set: Iter| set.any(is_equal_to!(*candidate));
			rest.clone().all(contained_in)
		})
	})
}
/// Intersection of sets
///
/// Outputs the set of elements common to all input sets. Inputs must be sorted.
pub fn intersection_sorted<Item, Iter, Iters>(
mut input: Iters,
) -> impl Iterator<Item = Item> + Send
where
Iters: Iterator<Item = Iter> + Clone + Send,
Iter: Iterator<Item = Item> + Send,
Item: Eq + Ord,
{
input.next().into_iter().flat_map(move |first| {
let mut input = input.clone().collect::<Vec<_>>();
first.filter(move |targ| {
input.iter_mut().all(|it| {
it.by_ref()
.skip_while(is_less_than!(targ))
.peekable()
.peek()
.is_some_and(is_equal_to!(targ))
})
})
})
}
/// Intersection of sets
///
/// Outputs the set of elements common to both streams. Streams must be sorted
/// in ascending order.
///
/// For each element pulled from `a`, elements of `b` that compare `<=` to it
/// are consumed. The `a` element is emitted if one of them is equal. Elements
/// of `b` greater than it are left in place. The output ends as soon as `a`
/// ends.
pub fn intersection_sorted_stream2<S, Item>(a: S, b: S) -> impl Stream<Item = Item> + Send
where
	S: Stream<Item = Item> + Send + Unpin,
	Item: Eq + PartialOrd + Send + Sync,
{
	struct State<S: Stream> {
		a: S,
		b: Peekable<S>,
	}

	let init = State { a, b: b.peekable() };
	unfold(init, async |mut state| {
		let ai = state.a.next().await?;

		// `hit` stays `None` unless an equal element is found in `b`.
		let mut hit = None;
		loop {
			let step = Pin::new(&mut state.b)
				.next_if(|bi| *bi <= ai)
				.await;

			let Some(bi) = step else {
				break;
			};

			if ai == bi {
				hit = Some(ai);
				break;
			}
		}

		Some((hit, state))
	})
	.ready_filter_map(identity)
}
/// Difference of sets
///
/// Outputs the set of elements found in `a` which are not found in `b`. Streams
/// must be sorted in ascending order.
///
/// `b` is wrapped as a peekable stream inside an `Arc<Mutex<..>>`. Each
/// per-element future of `a` receives its own `Arc` handle and locks it while
/// processing that element. Elements of `b` below the current `a` element are
/// discarded. An equal element in `b` is consumed and cancels the `a` element;
/// otherwise the `a` element is yielded.
pub fn difference_sorted_stream2<Item, A, B>(a: A, b: B) -> impl Stream<Item = Item> + Send
where
	A: Stream<Item = Item> + Send,
	B: Stream<Item = Item> + Send + Unpin,
	Item: Eq + PartialOrd + Send + Sync,
{
	let shared = Arc::new(Mutex::new(b.peekable()));

	a.map(move |ai| (ai, Arc::clone(&shared)))
		.filter_map(async move |(ai, shared)| {
			let mut guard = shared.lock().await;
			let mut cursor = Pin::new(&mut *guard);

			// Skip past everything in `b` that sorts before `ai`.
			while cursor.as_mut().next_if(|bi| *bi < ai).await.is_some() {}

			// An equal element in `b` removes `ai` from the output.
			let matched = cursor.as_mut().next_if_eq(&ai).await;
			matched.is_none().then_some(ai)
		})
}