chore: checkpoint before Python removal

This commit is contained in:
2026-03-26 22:33:59 +00:00
parent 683cec9307
commit e568ddf82a
29972 changed files with 11269302 additions and 2 deletions

80
vendor/tokio-util/src/cfg.rs vendored Normal file
View File

@@ -0,0 +1,80 @@
// Wraps items so they are compiled only when the `codec` feature is enabled,
// and tags them with a feature-gate badge in docs built on docs.rs.
macro_rules! cfg_codec {
($($item:item)*) => {
$(
#[cfg(feature = "codec")]
#[cfg_attr(docsrs, doc(cfg(feature = "codec")))]
$item
)*
}
}
// Wraps items so they are compiled only when the `compat` feature is enabled,
// with the matching docs.rs feature badge.
macro_rules! cfg_compat {
($($item:item)*) => {
$(
#[cfg(feature = "compat")]
#[cfg_attr(docsrs, doc(cfg(feature = "compat")))]
$item
)*
}
}
// Wraps items that require BOTH the `net` and `codec` features; the docs.rs
// badge advertises the combined requirement.
macro_rules! cfg_net {
($($item:item)*) => {
$(
#[cfg(all(feature = "net", feature = "codec"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "net", feature = "codec"))))]
$item
)*
}
}
// Wraps items so they are compiled only when the `io` feature is enabled,
// with the matching docs.rs feature badge.
macro_rules! cfg_io {
($($item:item)*) => {
$(
#[cfg(feature = "io")]
#[cfg_attr(docsrs, doc(cfg(feature = "io")))]
$item
)*
}
}
// `cfg_io_util` itself only exists when the `io` feature is on (it is defined
// inside `cfg_io!`); it then gates items behind the `io-util` feature.
cfg_io! {
macro_rules! cfg_io_util {
($($item:item)*) => {
$(
#[cfg(feature = "io-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
$item
)*
}
}
}
// Wraps items so they are compiled only when the `rt` feature is enabled,
// with the matching docs.rs feature badge.
macro_rules! cfg_rt {
($($item:item)*) => {
$(
#[cfg(feature = "rt")]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
$item
)*
}
}
// Inverse of `cfg_rt`: items compiled only when the `rt` feature is DISABLED.
// No docs.rs badge, since these items are fallbacks rather than feature APIs.
macro_rules! cfg_not_rt {
($($item:item)*) => {
$(
#[cfg(not(feature = "rt"))]
$item
)*
}
}
// Wraps items so they are compiled only when the `time` feature is enabled,
// with the matching docs.rs feature badge.
macro_rules! cfg_time {
($($item:item)*) => {
$(
#[cfg(feature = "time")]
#[cfg_attr(docsrs, doc(cfg(feature = "time")))]
$item
)*
}
}

View File

@@ -0,0 +1,261 @@
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{cmp, fmt, io, str};
// Delimiter set searched for by `AnyDelimiterCodec::default()` when decoding.
const DEFAULT_SEEK_DELIMITERS: &[u8] = b",;\n\r";
// Separator appended after each chunk by `AnyDelimiterCodec::default()` when encoding.
const DEFAULT_SEQUENCE_WRITER: &[u8] = b",";
/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into chunks based on any character in the given delimiter string.
///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
///
/// # Example
/// Decode string of bytes containing various different delimiters.
///
/// [`BytesMut`]: bytes::BytesMut
/// [`Error`]: std::io::Error
///
/// ```
/// use tokio_util::codec::{AnyDelimiterCodec, Decoder};
/// use bytes::{BufMut, BytesMut};
///
/// #
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> Result<(), std::io::Error> {
/// let mut codec = AnyDelimiterCodec::new(b",;\r\n".to_vec(),b";".to_vec());
/// let buf = &mut BytesMut::new();
/// buf.reserve(200);
/// buf.put_slice(b"chunk 1,chunk 2;chunk 3\n\r");
/// assert_eq!("chunk 1", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("chunk 2", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("chunk 3", codec.decode(buf).unwrap().unwrap());
/// assert_eq!("", codec.decode(buf).unwrap().unwrap());
/// assert_eq!(None, codec.decode(buf).unwrap());
/// # Ok(())
/// # }
/// ```
///
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct AnyDelimiterCodec {
// Stored index of the next index to examine for the delimiter character.
// This is used to optimize searching.
// For example, if `decode` was called with `abc` and no delimiter was
// found, it would hold `3`, because that is the next index to examine.
// The next time `decode` is called with `abcde}`, the method will
// only look at `de}` before returning.
next_index: usize,
/// The maximum length for a given chunk. If `usize::MAX`, chunks will be
/// read until a delimiter character is reached.
max_length: usize,
/// Are we currently discarding the remainder of a chunk which was over
/// the length limit?
is_discarding: bool,
/// The bytes used to search for a chunk boundary during decode; any one of
/// them terminates a chunk.
seek_delimiters: Vec<u8>,
/// The byte sequence appended after each chunk during encode.
sequence_writer: Vec<u8>,
}
impl AnyDelimiterCodec {
    /// Returns a `AnyDelimiterCodec` for splitting up data into chunks.
    ///
    /// # Note
    ///
    /// The codec returned here places no upper bound on the length of a
    /// buffered chunk. See the documentation for [`new_with_max_length`]
    /// for why an unbounded buffer can be a security risk.
    ///
    /// [`new_with_max_length`]: crate::codec::AnyDelimiterCodec::new_with_max_length()
    pub fn new(seek_delimiters: Vec<u8>, sequence_writer: Vec<u8>) -> AnyDelimiterCodec {
        AnyDelimiterCodec {
            seek_delimiters,
            sequence_writer,
            next_index: 0,
            max_length: usize::MAX,
            is_discarding: false,
        }
    }

    /// Returns a `AnyDelimiterCodec` with a maximum chunk length limit.
    ///
    /// With a limit set, `AnyDelimiterCodec::decode` returns a
    /// [`AnyDelimiterCodecError`] once a chunk exceeds the limit. Subsequent
    /// calls discard input up to the next delimiter character, yielding
    /// `None` until the oversized chunk has been skipped entirely; decoding
    /// then resumes as normal.
    ///
    /// # Note
    ///
    /// Setting a length limit is highly recommended for any
    /// `AnyDelimiterCodec` exposed to untrusted input. Without one, the
    /// internal buffer grows without bound while no delimiter arrives, so an
    /// attacker can drive unbounded memory consumption.
    ///
    /// [`AnyDelimiterCodecError`]: crate::codec::AnyDelimiterCodecError
    pub fn new_with_max_length(
        seek_delimiters: Vec<u8>,
        sequence_writer: Vec<u8>,
        max_length: usize,
    ) -> Self {
        // Start from the unbounded configuration, then cap it.
        let mut codec = AnyDelimiterCodec::new(seek_delimiters, sequence_writer);
        codec.max_length = max_length;
        codec
    }

    /// Returns the maximum chunk length when decoding.
    ///
    /// ```
    /// use std::usize;
    /// use tokio_util::codec::AnyDelimiterCodec;
    ///
    /// let codec = AnyDelimiterCodec::new(b",;\n".to_vec(), b";".to_vec());
    /// assert_eq!(codec.max_length(), usize::MAX);
    /// ```
    /// ```
    /// use tokio_util::codec::AnyDelimiterCodec;
    ///
    /// let codec = AnyDelimiterCodec::new_with_max_length(b",;\n".to_vec(), b";".to_vec(), 256);
    /// assert_eq!(codec.max_length(), 256);
    /// ```
    pub fn max_length(&self) -> usize {
        self.max_length
    }
}
impl Decoder for AnyDelimiterCodec {
type Item = Bytes;
type Error = AnyDelimiterCodecError;
// State machine over (is_discarding, found-delimiter?). `next_index`
// remembers how far the previous call already searched so each byte is
// scanned at most once across calls.
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
loop {
// Determine how far into the buffer we'll search for a delimiter. If
// there's no max_length set, we'll read to the end of the buffer.
// `max_length + 1` (saturating, so `usize::MAX` stays `usize::MAX`)
// lets us observe the first byte past the limit and detect overflow.
let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
let new_chunk_offset = buf[self.next_index..read_to]
.iter()
.position(|b| self.seek_delimiters.contains(b));
match (self.is_discarding, new_chunk_offset) {
(true, Some(offset)) => {
// If we found a new chunk, discard up to that offset and
// then stop discarding. On the next iteration, we'll try
// to read a chunk normally.
buf.advance(offset + self.next_index + 1);
self.is_discarding = false;
self.next_index = 0;
}
(true, None) => {
// Otherwise, we didn't find a new chunk, so we'll discard
// everything we read. On the next iteration, we'll continue
// discarding up to max_len bytes unless we find a new chunk.
buf.advance(read_to);
self.next_index = 0;
if buf.is_empty() {
return Ok(None);
}
}
(false, Some(offset)) => {
// Found a chunk!
// Split off the chunk plus its trailing delimiter, then drop
// the delimiter byte before handing the chunk out.
let new_chunk_index = offset + self.next_index;
self.next_index = 0;
let mut chunk = buf.split_to(new_chunk_index + 1);
chunk.truncate(chunk.len() - 1);
let chunk = chunk.freeze();
return Ok(Some(chunk));
}
(false, None) if buf.len() > self.max_length => {
// Reached the maximum length without finding a
// new chunk, return an error and start discarding on the
// next call.
self.is_discarding = true;
return Err(AnyDelimiterCodecError::MaxChunkLengthExceeded);
}
(false, None) => {
// We didn't find a chunk or reach the length limit, so the next
// call will resume searching at the current offset.
self.next_index = read_to;
return Ok(None);
}
}
}
}
fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, AnyDelimiterCodecError> {
Ok(match self.decode(buf)? {
Some(frame) => Some(frame),
None => {
// return remaining data, if any
// At EOF a final chunk may lack its closing delimiter; emit the
// leftover bytes as the last chunk instead of dropping them.
if buf.is_empty() {
None
} else {
let chunk = buf.split_to(buf.len());
self.next_index = 0;
Some(chunk.freeze())
}
}
})
}
}
impl<T> Encoder<T> for AnyDelimiterCodec
where
    T: AsRef<str>,
{
    type Error = AnyDelimiterCodecError;

    /// Appends `chunk` to `buf`, followed by the configured delimiter
    /// sequence (`sequence_writer`). Infallible in practice; the `Result`
    /// satisfies the `Encoder` contract.
    fn encode(&mut self, chunk: T, buf: &mut BytesMut) -> Result<(), AnyDelimiterCodecError> {
        let payload = chunk.as_ref().as_bytes();
        // One reservation covers both the payload and the trailing delimiter.
        buf.reserve(payload.len() + self.sequence_writer.len());
        buf.put(payload);
        buf.put(self.sequence_writer.as_ref());
        Ok(())
    }
}
impl Default for AnyDelimiterCodec {
fn default() -> Self {
Self::new(
DEFAULT_SEEK_DELIMITERS.to_vec(),
DEFAULT_SEQUENCE_WRITER.to_vec(),
)
}
}
/// An error occurred while encoding or decoding a chunk.
#[derive(Debug)]
pub enum AnyDelimiterCodecError {
/// The maximum chunk length was exceeded.
/// Returned by `decode` when a chunk grows past the limit configured via
/// `AnyDelimiterCodec::new_with_max_length`.
MaxChunkLengthExceeded,
/// An IO error occurred.
Io(io::Error),
}
impl fmt::Display for AnyDelimiterCodecError {
    /// Renders a fixed message for the length-limit error and defers to the
    /// wrapped error's own `Display` for I/O failures.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AnyDelimiterCodecError::MaxChunkLengthExceeded => {
                f.write_str("max chunk length exceeded")
            }
            AnyDelimiterCodecError::Io(e) => fmt::Display::fmt(e, f),
        }
    }
}
impl From<io::Error> for AnyDelimiterCodecError {
fn from(e: io::Error) -> AnyDelimiterCodecError {
AnyDelimiterCodecError::Io(e)
}
}
impl std::error::Error for AnyDelimiterCodecError {}

View File

@@ -0,0 +1,86 @@
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;
use bytes::{BufMut, Bytes, BytesMut};
use std::io;
/// A simple [`Decoder`] and [`Encoder`] implementation that just ships bytes around.
///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
///
/// # Example
///
/// Turn an [`AsyncRead`] into a stream of `Result<`[`BytesMut`]`, `[`Error`]`>`.
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`BytesMut`]: bytes::BytesMut
/// [`Error`]: std::io::Error
///
/// ```
/// # mod hidden {
/// # #[allow(unused_imports)]
/// use tokio::fs::File;
/// # }
/// use tokio::io::AsyncRead;
/// use tokio_util::codec::{FramedRead, BytesCodec};
///
/// # enum File {}
/// # impl File {
/// #     async fn open(_name: &str) -> Result<impl AsyncRead, std::io::Error> {
/// #         use std::io::Cursor;
/// #         Ok(Cursor::new(vec![0, 1, 2, 3, 4, 5]))
/// #     }
/// # }
/// #
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> Result<(), std::io::Error> {
/// let my_async_read = File::open("filename.txt").await?;
/// let my_stream_of_bytes = FramedRead::new(my_async_read, BytesCodec::new());
/// # Ok(())
/// # }
/// ```
///
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Default)]
// The private unit field prevents construction as a struct literal outside
// this module; use `BytesCodec::new()` or `Default`.
pub struct BytesCodec(());
impl BytesCodec {
/// Creates a new `BytesCodec` for shipping around raw bytes.
pub fn new() -> BytesCodec {
BytesCodec(())
}
}
impl Decoder for BytesCodec {
    type Item = BytesMut;
    type Error = io::Error;

    /// Yields every buffered byte as a single frame, or `None` when the
    /// buffer is empty. Never fails.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
        if buf.is_empty() {
            return Ok(None);
        }
        // Hand the entire buffered region to the caller in one piece.
        let len = buf.len();
        Ok(Some(buf.split_to(len)))
    }
}
impl Encoder<Bytes> for BytesCodec {
    type Error = io::Error;

    /// Copies `data` verbatim onto the end of the write buffer. Never fails.
    fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
        let needed = data.len();
        buf.reserve(needed);
        buf.put(data);
        Ok(())
    }
}
impl Encoder<BytesMut> for BytesCodec {
    type Error = io::Error;

    /// Copies `data` verbatim onto the end of the write buffer. Never fails.
    fn encode(&mut self, data: BytesMut, buf: &mut BytesMut) -> Result<(), io::Error> {
        let needed = data.len();
        buf.reserve(needed);
        buf.put(data);
        Ok(())
    }
}

184
vendor/tokio-util/src/codec/decoder.rs vendored Normal file
View File

@@ -0,0 +1,184 @@
use crate::codec::Framed;
use tokio::io::{AsyncRead, AsyncWrite};
use bytes::BytesMut;
use std::io;
/// Decoding of frames via buffers.
///
/// This trait is used when constructing an instance of [`Framed`] or
/// [`FramedRead`]. An implementation of `Decoder` takes a byte stream that has
/// already been buffered in `src` and decodes the data into a stream of
/// `Self::Item` frames.
///
/// Implementations are able to track state on `self`, which enables
/// implementing stateful streaming parsers. In many cases, though, this type
/// will simply be a unit struct (e.g. `struct HttpDecoder`).
///
/// For some underlying data-sources, namely files and FIFOs,
/// it's possible to temporarily read 0 bytes by reaching EOF.
///
/// In these cases `decode_eof` will be called until it signals
/// fulfillment of all closing frames by returning `Ok(None)`.
/// After that, repeated attempts to read from the [`Framed`] or [`FramedRead`]
/// will not invoke `decode` or `decode_eof` again, until data can be read
/// during a retry.
///
/// It is up to the Decoder to keep track of a restart after an EOF,
/// and to decide how to handle such an event by, for example,
/// allowing frames to cross EOF boundaries, re-emitting opening frames, or
/// resetting the entire internal state.
///
/// [`Framed`]: crate::codec::Framed
/// [`FramedRead`]: crate::codec::FramedRead
pub trait Decoder {
/// The type of decoded frames.
type Item;
/// The type of unrecoverable frame decoding errors.
///
/// If an individual message is ill-formed but can be ignored without
/// interfering with the processing of future messages, it may be more
/// useful to report the failure as an `Item`.
///
/// `From<io::Error>` is required in the interest of making `Error` suitable
/// for returning directly from a [`FramedRead`], and to enable the default
/// implementation of `decode_eof` to yield an `io::Error` when the decoder
/// fails to consume all available data.
///
/// Note that implementors of this trait can simply indicate `type Error =
/// io::Error` to use I/O errors as this type.
///
/// [`FramedRead`]: crate::codec::FramedRead
type Error: From<io::Error>;
/// Attempts to decode a frame from the provided buffer of bytes.
///
/// This method is called by [`FramedRead`] whenever bytes are ready to be
/// parsed. The provided buffer of bytes is what's been read so far, and
/// this instance of `Decoder` can determine whether an entire frame is in
/// the buffer and is ready to be returned.
///
/// If an entire frame is available, then this instance will remove those
/// bytes from the buffer provided and return them as a decoded
/// frame. Note that removing bytes from the provided buffer doesn't always
/// necessarily copy the bytes, so this should be an efficient operation in
/// most circumstances.
///
/// If the bytes look valid, but a frame isn't fully available yet, then
/// `Ok(None)` is returned. This indicates to the [`Framed`] instance that
/// it needs to read some more bytes before calling this method again.
///
/// Note that the bytes provided may be empty. If a previous call to
/// `decode` consumed all the bytes in the buffer then `decode` will be
/// called again until it returns `Ok(None)`, indicating that more bytes need to
/// be read.
///
/// Finally, if the bytes in the buffer are malformed then an error is
/// returned indicating why. This informs [`Framed`] that the stream is now
/// corrupt and should be terminated.
///
/// [`Framed`]: crate::codec::Framed
/// [`FramedRead`]: crate::codec::FramedRead
///
/// # Buffer management
///
/// Before returning from the function, implementations should ensure that
/// the buffer has appropriate capacity in anticipation of future calls to
/// `decode`. Failing to do so leads to inefficiency.
///
/// For example, if frames have a fixed length, or if the length of the
/// current frame is known from a header, a possible buffer management
/// strategy is:
///
/// ```no_run
/// # use std::io;
/// #
/// # use bytes::BytesMut;
/// # use tokio_util::codec::Decoder;
/// #
/// # struct MyCodec;
/// #
/// impl Decoder for MyCodec {
///     // ...
///     # type Item = BytesMut;
///     # type Error = io::Error;
///
///     fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
///         // ...
///
///         // Reserve enough to complete decoding of the current frame.
///         let current_frame_len: usize = 1000; // Example.
///         // And to start decoding the next frame.
///         let next_frame_header_len: usize = 10; // Example.
///         src.reserve(current_frame_len + next_frame_header_len);
///
///         return Ok(None);
///     }
/// }
/// ```
///
/// An optimal buffer management strategy minimizes reallocations and
/// over-allocations.
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>;
/// A default method available to be called when there are no more bytes
/// available to be read from the underlying I/O.
///
/// This method defaults to calling `decode` and returns an error if
/// `Ok(None)` is returned while there is unconsumed data in `buf`.
/// Typically this doesn't need to be implemented unless the framing
/// protocol differs near the end of the stream, or if you need to construct
/// frames _across_ eof boundaries on sources that can be resumed.
///
/// Note that the `buf` argument may be empty. If a previous call to
/// `decode_eof` consumed all the bytes in the buffer, `decode_eof` will be
/// called again until it returns `None`, indicating that there are no more
/// frames to yield. This behavior enables returning finalization frames
/// that may not be based on inbound data.
///
/// Once `None` has been returned, `decode_eof` won't be called again until
/// an attempt to resume the stream has been made, where the underlying stream
/// actually returned more data.
fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.decode(buf)? {
Some(frame) => Ok(Some(frame)),
None => {
if buf.is_empty() {
Ok(None)
} else {
// Leftover bytes at EOF with no decodable frame: by default
// that's treated as a truncated/corrupt stream.
Err(io::Error::new(io::ErrorKind::Other, "bytes remaining on stream").into())
}
}
}
}
/// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
/// `Io` object, using `Decode` and `Encode` to read and write the raw data.
///
/// Raw I/O objects work with byte sequences, but higher-level code usually
/// wants to batch these into meaningful chunks, called "frames". This
/// method layers framing on top of an I/O object, by using the `Codec`
/// traits to handle encoding and decoding of message frames. Note that
/// the incoming and outgoing frame types may be distinct.
///
/// This function returns a *single* object that is both `Stream` and
/// `Sink`; grouping this into a single object is often useful for layering
/// things like gzip or TLS, which require both read and write access to the
/// underlying object.
///
/// If you want to work more directly with the streams and sink, consider
/// calling `split` on the [`Framed`] returned by this method, which will
/// break them into separate objects, allowing them to interact more easily.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`Framed`]: crate::codec::Framed
fn framed<T: AsyncRead + AsyncWrite + Sized>(self, io: T) -> Framed<T, Self>
where
Self: Sized,
{
Framed::new(io, self)
}
}

25
vendor/tokio-util/src/codec/encoder.rs vendored Normal file
View File

@@ -0,0 +1,25 @@
use bytes::BytesMut;
use std::io;
/// Trait of helper objects to write out messages as bytes, for use with
/// [`FramedWrite`].
///
/// [`FramedWrite`]: crate::codec::FramedWrite
pub trait Encoder<Item> {
/// The type of encoding errors.
///
/// [`FramedWrite`] requires `Encoder`s errors to implement `From<io::Error>`
/// in the interest of letting it return `Error`s directly.
///
/// Note that `io::Error` itself satisfies this bound, so `type Error =
/// io::Error` is a common choice for simple encoders.
///
/// [`FramedWrite`]: crate::codec::FramedWrite
type Error: From<io::Error>;
/// Encodes a frame into the buffer provided.
///
/// This method will encode `item` into the byte buffer provided by `dst`.
/// The `dst` provided is an internal buffer of the [`FramedWrite`] instance and
/// will be written out when possible.
///
/// Implementations should append to `dst` (e.g. via `BytesMut::put`)
/// rather than overwrite it.
///
/// [`FramedWrite`]: crate::codec::FramedWrite
fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Self::Error>;
}

392
vendor/tokio-util/src/codec/framed.rs vendored Normal file
View File

@@ -0,0 +1,392 @@
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;
use crate::codec::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};
use futures_core::Stream;
use tokio::io::{AsyncRead, AsyncWrite};
use bytes::BytesMut;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
pin_project! {
/// A unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
///
/// You can create a `Framed` instance by using the [`Decoder::framed`] adapter, or
/// by using the `new` function seen below.
///
/// # Cancellation safety
///
/// * [`futures_util::sink::SinkExt::send`]: if send is used as the event in a
///   `tokio::select!` statement and some other branch completes first, then it is
///   guaranteed that the message was not sent, but the message itself is lost.
/// * [`tokio_stream::StreamExt::next`]: This method is cancel safe. The returned
///   future only holds onto a reference to the underlying stream, so dropping it will
///   never lose a value.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`Decoder::framed`]: crate::codec::Decoder::framed()
/// [`futures_util::sink::SinkExt::send`]: futures_util::sink::SinkExt::send
/// [`tokio_stream::StreamExt::next`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#method.next
pub struct Framed<T, U> {
// All Stream/Sink behavior is delegated to this shared implementation,
// which owns the I/O object, the codec, and both buffers.
#[pin]
inner: FramedImpl<T, U, RWFrames>
}
}
impl<T, U> Framed<T, U> {
/// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
/// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
///
/// Raw I/O objects work with byte sequences, but higher-level code usually
/// wants to batch these into meaningful chunks, called "frames". This
/// method layers framing on top of an I/O object, by using the codec
/// traits to handle encoding and decoding of message frames. Note that
/// the incoming and outgoing frame types may be distinct.
///
/// This function returns a *single* object that is both [`Stream`] and
/// [`Sink`]; grouping this into a single object is often useful for layering
/// things like gzip or TLS, which require both read and write access to the
/// underlying object.
///
/// If you want to work more directly with the streams and sink, consider
/// calling [`split`] on the `Framed` returned by this method, which will
/// break them into separate objects, allowing them to interact more easily.
///
/// Note that, for some byte sources, the stream can be resumed after an EOF
/// by reading from it, even after it has returned `None`. Repeated attempts
/// to do so, without new data available, continue to return `None` without
/// creating more (closing) frames.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
pub fn new(inner: T, codec: U) -> Framed<T, U> {
Framed {
inner: FramedImpl {
inner,
codec,
state: Default::default(),
},
}
}
/// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
/// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data,
/// with a specific read buffer initial capacity.
///
/// Raw I/O objects work with byte sequences, but higher-level code usually
/// wants to batch these into meaningful chunks, called "frames". This
/// method layers framing on top of an I/O object, by using the codec
/// traits to handle encoding and decoding of message frames. Note that
/// the incoming and outgoing frame types may be distinct.
///
/// This function returns a *single* object that is both [`Stream`] and
/// [`Sink`]; grouping this into a single object is often useful for layering
/// things like gzip or TLS, which require both read and write access to the
/// underlying object.
///
/// If you want to work more directly with the streams and sink, consider
/// calling [`split`] on the `Framed` returned by this method, which will
/// break them into separate objects, allowing them to interact more easily.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
pub fn with_capacity(inner: T, codec: U, capacity: usize) -> Framed<T, U> {
Framed {
inner: FramedImpl {
inner,
codec,
state: RWFrames {
read: ReadFrame {
eof: false,
is_readable: false,
buffer: BytesMut::with_capacity(capacity),
has_errored: false,
},
// The write buffer gets the same initial capacity, and the
// backpressure boundary is tied to it.
write: WriteFrame {
buffer: BytesMut::with_capacity(capacity),
backpressure_boundary: capacity,
},
},
},
}
}
/// Provides a [`Stream`] and [`Sink`] interface for reading and writing to this
/// I/O object, using [`Decoder`] and [`Encoder`] to read and write the raw data.
///
/// Raw I/O objects work with byte sequences, but higher-level code usually
/// wants to batch these into meaningful chunks, called "frames". This
/// method layers framing on top of an I/O object, by using the `Codec`
/// traits to handle encoding and decoding of message frames. Note that
/// the incoming and outgoing frame types may be distinct.
///
/// This function returns a *single* object that is both [`Stream`] and
/// [`Sink`]; grouping this into a single object is often useful for layering
/// things like gzip or TLS, which require both read and write access to the
/// underlying object.
///
/// This object takes a stream, a read buffer, and a write buffer. These fields
/// can be obtained from an existing `Framed` with the [`into_parts`] method.
///
/// If you want to work more directly with the streams and sink, consider
/// calling [`split`] on the `Framed` returned by this method, which will
/// break them into separate objects, allowing them to interact more easily.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
/// [`into_parts`]: crate::codec::Framed::into_parts()
/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
Framed {
inner: FramedImpl {
inner: parts.io,
codec: parts.codec,
state: RWFrames {
read: parts.read_buf.into(),
write: parts.write_buf.into(),
},
},
}
}
/// Returns a reference to the underlying I/O stream wrapped by
/// `Framed`.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn get_ref(&self) -> &T {
&self.inner.inner
}
/// Returns a mutable reference to the underlying I/O stream wrapped by
/// `Framed`.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner.inner
}
/// Returns a pinned mutable reference to the underlying I/O stream wrapped by
/// `Framed`.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().inner.project().inner
}
/// Returns a reference to the underlying codec wrapped by
/// `Framed`.
///
/// Note that care should be taken to not tamper with the underlying codec
/// as it may corrupt the stream of frames otherwise being worked with.
pub fn codec(&self) -> &U {
&self.inner.codec
}
/// Returns a mutable reference to the underlying codec wrapped by
/// `Framed`.
///
/// Note that care should be taken to not tamper with the underlying codec
/// as it may corrupt the stream of frames otherwise being worked with.
pub fn codec_mut(&mut self) -> &mut U {
&mut self.inner.codec
}
/// Maps the codec `U` to `C`, preserving the read and write buffers
/// wrapped by `Framed`.
///
/// Note that care should be taken to not tamper with the underlying codec
/// as it may corrupt the stream of frames otherwise being worked with.
pub fn map_codec<C, F>(self, map: F) -> Framed<T, C>
where
F: FnOnce(U) -> C,
{
// This could be potentially simplified once rust-lang/rust#86555 hits stable
let parts = self.into_parts();
Framed::from_parts(FramedParts {
io: parts.io,
codec: map(parts.codec),
read_buf: parts.read_buf,
write_buf: parts.write_buf,
_priv: (),
})
}
/// Returns a mutable reference to the underlying codec wrapped by a
/// pinned `Framed`. Pinned counterpart of [`codec_mut`](Self::codec_mut).
///
/// Note that care should be taken to not tamper with the underlying codec
/// as it may corrupt the stream of frames otherwise being worked with.
pub fn codec_pin_mut(self: Pin<&mut Self>) -> &mut U {
self.project().inner.project().codec
}
/// Returns a reference to the read buffer.
pub fn read_buffer(&self) -> &BytesMut {
&self.inner.state.read.buffer
}
/// Returns a mutable reference to the read buffer.
pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
&mut self.inner.state.read.buffer
}
/// Returns a reference to the write buffer.
pub fn write_buffer(&self) -> &BytesMut {
&self.inner.state.write.buffer
}
/// Returns a mutable reference to the write buffer.
pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
&mut self.inner.state.write.buffer
}
/// Returns the current backpressure boundary.
pub fn backpressure_boundary(&self) -> usize {
self.inner.state.write.backpressure_boundary
}
/// Updates the backpressure boundary.
pub fn set_backpressure_boundary(&mut self, boundary: usize) {
self.inner.state.write.backpressure_boundary = boundary;
}
/// Consumes the `Framed`, returning its underlying I/O stream.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn into_inner(self) -> T {
self.inner.inner
}
/// Consumes the `Framed`, returning its underlying I/O stream, the buffer
/// with unprocessed data, and the codec.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn into_parts(self) -> FramedParts<T, U> {
FramedParts {
io: self.inner.inner,
codec: self.inner.codec,
read_buf: self.inner.state.read.buffer,
write_buf: self.inner.state.write.buffer,
_priv: (),
}
}
}
// This impl just defers to the underlying FramedImpl.
// Items are frames decoded by `U` (or decode/transport errors).
impl<T, U> Stream for Framed<T, U>
where
T: AsyncRead,
U: Decoder,
{
type Item = Result<U::Item, U::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
}
// This impl just defers to the underlying FramedImpl.
// Every Sink method is forwarded through the pin projection.
impl<T, I, U> Sink<I> for Framed<T, U>
where
T: AsyncWrite,
U: Encoder<I>,
U::Error: From<io::Error>,
{
type Error = U::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
self.project().inner.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx)
}
}
impl<T, U> fmt::Debug for Framed<T, U>
where
    T: fmt::Debug,
    U: fmt::Debug,
{
    /// Shows the wrapped I/O object and codec; the internal buffers are
    /// intentionally omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Framed");
        builder.field("io", self.get_ref());
        builder.field("codec", self.codec());
        builder.finish()
    }
}
/// `FramedParts` contains an export of the data of a Framed transport.
/// It can be used to construct a new [`Framed`] with a different codec.
/// It contains all current buffers and the inner transport.
///
/// [`Framed`]: crate::codec::Framed
#[derive(Debug)]
#[allow(clippy::manual_non_exhaustive)]
pub struct FramedParts<T, U> {
/// The inner transport used to read bytes from and write bytes to.
pub io: T,
/// The codec used to encode and decode frames.
pub codec: U,
/// The buffer with read but unprocessed data.
pub read_buf: BytesMut,
/// A buffer with unprocessed data which has not been written out yet.
pub write_buf: BytesMut,
/// This private field allows us to add additional fields in the future in a
/// backwards compatible way.
pub(crate) _priv: (),
}
impl<T, U> FramedParts<T, U> {
    /// Create a new, default, `FramedParts`
    pub fn new<I>(io: T, codec: U) -> FramedParts<T, U>
    where
        U: Encoder<I>,
    {
        // Fresh parts start with empty (unallocated) buffers.
        Self {
            codec,
            io,
            read_buf: BytesMut::new(),
            write_buf: BytesMut::new(),
            _priv: (),
        }
    }
}

View File

@@ -0,0 +1,311 @@
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;
use futures_core::Stream;
use tokio::io::{AsyncRead, AsyncWrite};
use bytes::BytesMut;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::borrow::{Borrow, BorrowMut};
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
pin_project! {
    /// Shared driver behind `Framed`, `FramedRead` and `FramedWrite`: pairs
    /// an I/O resource `T` with a codec `U` and a per-direction buffer
    /// `State` (`ReadFrame`, `WriteFrame`, or `RWFrames` for both).
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        #[pin]
        pub(crate) inner: T,
        pub(crate) state: State,
        pub(crate) codec: U,
    }
}
// Initial capacity (8 KiB) used for both the read and write buffers.
const INITIAL_CAPACITY: usize = 8 * 1024;

/// Decode-side state for [`FramedImpl`].
#[derive(Debug)]
pub(crate) struct ReadFrame {
    // Set once the underlying reader returns a 0-byte read.
    pub(crate) eof: bool,
    // True while `buffer` might still contain a decodable frame.
    pub(crate) is_readable: bool,
    // Bytes read from the I/O resource but not yet decoded.
    pub(crate) buffer: BytesMut,
    // Set when the decoder returned an error; the stream then yields `None`.
    pub(crate) has_errored: bool,
}
/// Encode-side state for [`FramedImpl`].
///
/// Derives `Debug` for parity with [`ReadFrame`], so that
/// `FramedImpl<_, _, WriteFrame>` also gets the `Debug` impl generated by
/// the derive on `FramedImpl`.
#[derive(Debug)]
pub(crate) struct WriteFrame {
    // Encoded frames waiting to be written to the I/O resource.
    pub(crate) buffer: BytesMut,
    // Once `buffer.len()` reaches this, `poll_ready` flushes before
    // accepting another item.
    pub(crate) backpressure_boundary: usize,
}
/// Combined read + write state, used by `Framed`, which drives both halves
/// of the transport through the same `FramedImpl`.
#[derive(Default)]
pub(crate) struct RWFrames {
    // Decode-side state.
    pub(crate) read: ReadFrame,
    // Encode-side state.
    pub(crate) write: WriteFrame,
}
impl Default for ReadFrame {
fn default() -> Self {
Self {
eof: false,
is_readable: false,
buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
has_errored: false,
}
}
}
impl Default for WriteFrame {
fn default() -> Self {
Self {
buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
backpressure_boundary: INITIAL_CAPACITY,
}
}
}
impl From<BytesMut> for ReadFrame {
fn from(mut buffer: BytesMut) -> Self {
let size = buffer.capacity();
if size < INITIAL_CAPACITY {
buffer.reserve(INITIAL_CAPACITY - size);
}
Self {
buffer,
is_readable: size > 0,
eof: false,
has_errored: false,
}
}
}
impl From<BytesMut> for WriteFrame {
fn from(mut buffer: BytesMut) -> Self {
let size = buffer.capacity();
if size < INITIAL_CAPACITY {
buffer.reserve(INITIAL_CAPACITY - size);
}
Self {
buffer,
backpressure_boundary: INITIAL_CAPACITY,
}
}
}
// `FramedImpl` is generic over its `State` via `Borrow`/`BorrowMut` bounds;
// these impls let a `Framed` (which stores an `RWFrames`) reuse the shared
// read and write machinery by lending out just the half each one needs.
impl Borrow<ReadFrame> for RWFrames {
    fn borrow(&self) -> &ReadFrame {
        &self.read
    }
}
impl BorrowMut<ReadFrame> for RWFrames {
    fn borrow_mut(&mut self) -> &mut ReadFrame {
        &mut self.read
    }
}
impl Borrow<WriteFrame> for RWFrames {
    fn borrow(&self) -> &WriteFrame {
        &self.write
    }
}
impl BorrowMut<WriteFrame> for RWFrames {
    fn borrow_mut(&mut self) -> &mut WriteFrame {
        &mut self.write
    }
}
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    /// Drives the read half: repeatedly decodes frames out of the internal
    /// buffer, refilling it from the underlying `AsyncRead` as needed, and
    /// switching to `decode_eof` once the reader reports EOF.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;
        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loops implements a state machine with each state corresponding
        // to a combination of the `is_readable` and `eof` flags. States persist across
        // loop entries and most state transitions occur with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //                                                       `decode_eof` returns Err
        //                                          ┌────────────────────────────────────────────────────────┐
        //                   `decode_eof` returns   │                                                        │
        //                             `Ok(Some)`   │                                                        │
        //                 ┌─────┐     │     `decode_eof` returns               After returning              │
        //                 │     │     │  `Ok(None)`    ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
        // Read 0 bytes    ├─────▼──┴┐                  │        │     └───────────┤         │
        // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐        │ Errored │
        // │                └─────────┘                └─┬──▲───┘ │        └───▲───▲─┘
        // Pending read     │                            │  │     │            │   │
        // ┌──────┐         │     `decode` returns `Some`│  └─────┘            │   │
        // │      │         │                            │  Pending            │   │
        // │ ┌────▼──┴─┐  Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes   │ read  │   │
        // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘   │   │
        //   └──┬─▲────┘                └─────┬──┬┘                             │   │
        //      │ │                           │  │ `decode` returns Err         │   │
        //      │ └───`decode` returns `None`─┘  └───────────────────────────────────────────────────────┘   │
        //      │ read returns Err                                                                           │
        //      └────────────────────────────────────────────────────────────────────────────────────────────┘
        loop {
            // Return `None` if we have encountered an error from the underlying decoder
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // preparing has_errored -> paused
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }
            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of it's associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }
                // framing
                trace!("attempting to decode a frame");
                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }
                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            #[allow(clippy::blocks_in_conditions)]
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> paused
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }
            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
where
    T: AsyncWrite,
    U: Encoder<I>,
    U::Error: From<io::Error>,
    W: BorrowMut<WriteFrame>,
{
    type Error = U::Error;

    /// Applies backpressure: once the write buffer has reached
    /// `backpressure_boundary`, flush before accepting another item.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
            self.as_mut().poll_flush(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Encodes `item` onto the end of the write buffer; the actual I/O
    /// happens later, in `poll_flush`.
    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let pinned = self.project();
        pinned
            .codec
            .encode(item, &mut pinned.state.borrow_mut().buffer)?;
        Ok(())
    }

    /// Writes buffered bytes to the underlying I/O resource until the buffer
    /// is drained, then flushes the resource itself.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        use crate::util::poll_write_buf;
        trace!("flushing framed transport");
        let mut pinned = self.project();
        while !pinned.state.borrow_mut().buffer.is_empty() {
            let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
            trace!(remaining = buffer.len(), "writing;");
            let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
            if n == 0 {
                // A zero-length write with data still pending means the peer
                // will accept no more; surface an error instead of spinning.
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to \
                     write frame to transport",
                )
                .into()));
            }
        }
        // Try flushing the underlying IO
        ready!(pinned.inner.poll_flush(cx))?;
        trace!("framed transport flushed");
        Poll::Ready(Ok(()))
    }

    /// Flushes any buffered data, then shuts down the underlying I/O
    /// resource.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.as_mut().poll_flush(cx))?;
        ready!(self.project().inner.poll_shutdown(cx))?;
        Poll::Ready(Ok(()))
    }
}

View File

@@ -0,0 +1,217 @@
use crate::codec::framed_impl::{FramedImpl, ReadFrame};
use crate::codec::Decoder;
use futures_core::Stream;
use tokio::io::AsyncRead;
use bytes::BytesMut;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};
use super::FramedParts;
pin_project! {
    /// A [`Stream`] of messages decoded from an [`AsyncRead`].
    ///
    /// For examples of how to use `FramedRead` with a codec, see the
    /// examples on the [`codec`] module.
    ///
    /// # Cancellation safety
    /// * [`tokio_stream::StreamExt::next`]: This method is cancel safe. The returned
    ///   future only holds onto a reference to the underlying stream, so dropping it will
    ///   never lose a value.
    ///
    /// [`Stream`]: futures_core::Stream
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`codec`]: crate::codec
    /// [`tokio_stream::StreamExt::next`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html#method.next
    pub struct FramedRead<T, D> {
        // All behavior is delegated to the shared `FramedImpl`, specialized
        // with read-only (`ReadFrame`) state.
        #[pin]
        inner: FramedImpl<T, D, ReadFrame>,
    }
}
// ===== impl FramedRead =====
impl<T, D> FramedRead<T, D> {
    /// Creates a new `FramedRead` with the given `decoder`.
    pub fn new(inner: T, decoder: D) -> FramedRead<T, D> {
        FramedRead {
            inner: FramedImpl {
                inner,
                codec: decoder,
                state: Default::default(),
            },
        }
    }
    /// Creates a new `FramedRead` with the given `decoder` and a buffer of `capacity`
    /// initial size.
    pub fn with_capacity(inner: T, decoder: D, capacity: usize) -> FramedRead<T, D> {
        FramedRead {
            inner: FramedImpl {
                inner,
                codec: decoder,
                // Same as `ReadFrame::default()`, but with a caller-chosen
                // initial buffer capacity.
                state: ReadFrame {
                    eof: false,
                    is_readable: false,
                    buffer: BytesMut::with_capacity(capacity),
                    has_errored: false,
                },
            },
        }
    }
    /// Returns a reference to the underlying I/O stream wrapped by
    /// `FramedRead`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_ref(&self) -> &T {
        &self.inner.inner
    }
    /// Returns a mutable reference to the underlying I/O stream wrapped by
    /// `FramedRead`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner.inner
    }
    /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
    /// `FramedRead`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
        self.project().inner.project().inner
    }
    /// Consumes the `FramedRead`, returning its underlying I/O stream.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn into_inner(self) -> T {
        self.inner.inner
    }
    /// Returns a reference to the underlying decoder.
    pub fn decoder(&self) -> &D {
        &self.inner.codec
    }
    /// Returns a mutable reference to the underlying decoder.
    pub fn decoder_mut(&mut self) -> &mut D {
        &mut self.inner.codec
    }
    /// Maps the decoder `D` to `C`, preserving the read buffer
    /// wrapped by `Framed`.
    pub fn map_decoder<C, F>(self, map: F) -> FramedRead<T, C>
    where
        F: FnOnce(D) -> C,
    {
        // This could be potentially simplified once rust-lang/rust#86555 hits stable
        let FramedImpl {
            inner,
            state,
            codec,
        } = self.inner;
        FramedRead {
            inner: FramedImpl {
                inner,
                state,
                codec: map(codec),
            },
        }
    }
    /// Returns a mutable reference to the underlying decoder, obtained
    /// through a pinned reference to this `FramedRead`.
    pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D {
        self.project().inner.project().codec
    }
    /// Returns a reference to the read buffer.
    pub fn read_buffer(&self) -> &BytesMut {
        &self.inner.state.buffer
    }
    /// Returns a mutable reference to the read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.inner.state.buffer
    }
    /// Consumes the `FramedRead`, returning its underlying I/O stream, the buffer
    /// with unprocessed data, and the codec.
    pub fn into_parts(self) -> FramedParts<T, D> {
        FramedParts {
            io: self.inner.inner,
            codec: self.inner.codec,
            read_buf: self.inner.state.buffer,
            // A read-only transport has no pending writes.
            write_buf: BytesMut::new(),
            _priv: (),
        }
    }
}
// This impl just defers to the underlying FramedImpl
impl<T, D> Stream for FramedRead<T, D>
where
    T: AsyncRead,
    D: Decoder,
{
    type Item = Result<D::Item, D::Error>;

    /// Delegates frame decoding to the wrapped `FramedImpl`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.inner.poll_next(cx)
    }
}
// This impl just defers to the underlying T: Sink
impl<T, I, D> Sink<I> for FramedRead<T, D>
where
    T: Sink<I>,
{
    type Error = T::Error;

    /// Forwards readiness straight to the inner I/O object's `Sink` impl.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = self.project().inner.project().inner;
        io.poll_ready(cx)
    }

    /// Forwards the item straight to the inner I/O object's `Sink` impl.
    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let io = self.project().inner.project().inner;
        io.start_send(item)
    }

    /// Forwards flushing straight to the inner I/O object's `Sink` impl.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = self.project().inner.project().inner;
        io.poll_flush(cx)
    }

    /// Forwards closing straight to the inner I/O object's `Sink` impl.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = self.project().inner.project().inner;
        io.poll_close(cx)
    }
}
impl<T, D> fmt::Debug for FramedRead<T, D>
where
    T: fmt::Debug,
    D: fmt::Debug,
{
    /// Shows the transport, decoder, and the read-half state flags/buffer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("FramedRead");
        dbg.field("inner", &self.get_ref());
        dbg.field("decoder", &self.decoder());
        dbg.field("eof", &self.inner.state.eof);
        dbg.field("is_readable", &self.inner.state.is_readable);
        dbg.field("buffer", &self.read_buffer());
        dbg.finish()
    }
}

View File

@@ -0,0 +1,223 @@
use crate::codec::encoder::Encoder;
use crate::codec::framed_impl::{FramedImpl, WriteFrame};
use futures_core::Stream;
use tokio::io::AsyncWrite;
use bytes::BytesMut;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use super::FramedParts;
pin_project! {
    /// A [`Sink`] of frames encoded to an `AsyncWrite`.
    ///
    /// For examples of how to use `FramedWrite` with a codec, see the
    /// examples on the [`codec`] module.
    ///
    /// # Cancellation safety
    ///
    /// * [`futures_util::sink::SinkExt::send`]: if send is used as the event in a
    ///   `tokio::select!` statement and some other branch completes first, then it is
    ///   guaranteed that the message was not sent, but the message itself is lost.
    ///
    /// [`Sink`]: futures_sink::Sink
    /// [`codec`]: crate::codec
    /// [`futures_util::sink::SinkExt::send`]: futures_util::sink::SinkExt::send
    pub struct FramedWrite<T, E> {
        // All behavior is delegated to the shared `FramedImpl`, specialized
        // with write-only (`WriteFrame`) state.
        #[pin]
        inner: FramedImpl<T, E, WriteFrame>,
    }
}
impl<T, E> FramedWrite<T, E> {
    /// Creates a new `FramedWrite` with the given `encoder`.
    pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> {
        FramedWrite {
            inner: FramedImpl {
                inner,
                codec: encoder,
                state: WriteFrame::default(),
            },
        }
    }
    /// Creates a new `FramedWrite` with the given `encoder` and a buffer of `capacity`
    /// initial size.
    pub fn with_capacity(inner: T, encoder: E, capacity: usize) -> FramedWrite<T, E> {
        FramedWrite {
            inner: FramedImpl {
                inner,
                codec: encoder,
                // The chosen capacity also serves as the initial
                // backpressure boundary.
                state: WriteFrame {
                    buffer: BytesMut::with_capacity(capacity),
                    backpressure_boundary: capacity,
                },
            },
        }
    }
    /// Returns a reference to the underlying I/O stream wrapped by
    /// `FramedWrite`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_ref(&self) -> &T {
        &self.inner.inner
    }
    /// Returns a mutable reference to the underlying I/O stream wrapped by
    /// `FramedWrite`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner.inner
    }
    /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
    /// `FramedWrite`.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
        self.project().inner.project().inner
    }
    /// Consumes the `FramedWrite`, returning its underlying I/O stream.
    ///
    /// Note that care should be taken to not tamper with the underlying stream
    /// of data coming in as it may corrupt the stream of frames otherwise
    /// being worked with.
    pub fn into_inner(self) -> T {
        self.inner.inner
    }
    /// Returns a reference to the underlying encoder.
    pub fn encoder(&self) -> &E {
        &self.inner.codec
    }
    /// Returns a mutable reference to the underlying encoder.
    pub fn encoder_mut(&mut self) -> &mut E {
        &mut self.inner.codec
    }
    /// Maps the encoder `E` to `C`, preserving the write buffer
    /// wrapped by `Framed`.
    pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C>
    where
        F: FnOnce(E) -> C,
    {
        // This could be potentially simplified once rust-lang/rust#86555 hits stable
        let FramedImpl {
            inner,
            state,
            codec,
        } = self.inner;
        FramedWrite {
            inner: FramedImpl {
                inner,
                state,
                codec: map(codec),
            },
        }
    }
    /// Returns a mutable reference to the underlying encoder, obtained
    /// through a pinned reference to this `FramedWrite`.
    pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
        self.project().inner.project().codec
    }
    /// Returns a reference to the write buffer.
    pub fn write_buffer(&self) -> &BytesMut {
        &self.inner.state.buffer
    }
    /// Returns a mutable reference to the write buffer.
    pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.inner.state.buffer
    }
    /// Returns the current backpressure boundary: once the write buffer
    /// reaches this size, `poll_ready` flushes before accepting more items.
    pub fn backpressure_boundary(&self) -> usize {
        self.inner.state.backpressure_boundary
    }
    /// Updates the backpressure boundary.
    pub fn set_backpressure_boundary(&mut self, boundary: usize) {
        self.inner.state.backpressure_boundary = boundary;
    }
    /// Consumes the `FramedWrite`, returning its underlying I/O stream, the buffer
    /// with unprocessed data, and the codec.
    pub fn into_parts(self) -> FramedParts<T, E> {
        FramedParts {
            io: self.inner.inner,
            codec: self.inner.codec,
            // A write-only transport has no buffered reads.
            read_buf: BytesMut::new(),
            write_buf: self.inner.state.buffer,
            _priv: (),
        }
    }
}
// This impl just defers to the underlying FramedImpl
impl<T, I, E> Sink<I> for FramedWrite<T, E>
where
    T: AsyncWrite,
    E: Encoder<I>,
    E::Error: From<io::Error>,
{
    type Error = E::Error;

    /// Reports readiness, flushing first when the internal write buffer has
    /// grown past its backpressure boundary.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_ready(cx)
    }

    /// Encodes `item` into the internal write buffer; no I/O happens here.
    fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
        let this = self.project();
        this.inner.start_send(item)
    }

    /// Writes out buffered frames and flushes the underlying transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx)
    }

    /// Flushes buffered frames, then shuts the transport down.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_close(cx)
    }
}
// This impl just defers to the underlying T: Stream
impl<T, D> Stream for FramedWrite<T, D>
where
    T: Stream,
{
    type Item = T::Item;

    /// Forwards polling straight to the inner I/O object's own `Stream`
    /// implementation.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let io = self.project().inner.project().inner;
        io.poll_next(cx)
    }
}
impl<T, U> fmt::Debug for FramedWrite<T, U>
where
    T: fmt::Debug,
    U: fmt::Debug,
{
    /// Shows the transport, encoder, and pending write buffer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("FramedWrite");
        dbg.field("inner", &self.get_ref());
        dbg.field("encoder", &self.encoder());
        dbg.field("buffer", &self.inner.state.buffer);
        dbg.finish()
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,232 @@
use crate::codec::decoder::Decoder;
use crate::codec::encoder::Encoder;
use bytes::{Buf, BufMut, BytesMut};
use std::{cmp, fmt, io, str};
/// A simple [`Decoder`] and [`Encoder`] implementation that splits up data into lines.
///
/// This uses the `\n` character as the line ending on all platforms.
///
/// [`Decoder`]: crate::codec::Decoder
/// [`Encoder`]: crate::codec::Encoder
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct LinesCodec {
    // Stored index of the next index to examine for a `\n` character.
    // This is used to optimize searching.
    // For example, if `decode` was called with `abc`, it would hold `3`,
    // because that is the next index to examine.
    // The next time `decode` is called with `abcde\n`, the method will
    // only look at `de\n` before returning.
    next_index: usize,
    /// The maximum length for a given line. If `usize::MAX`, lines will be
    /// read until a `\n` character is reached.
    max_length: usize,
    /// Are we currently discarding the remainder of a line which was over
    /// the length limit?
    is_discarding: bool,
}
impl LinesCodec {
    /// Returns a `LinesCodec` for splitting up data into lines.
    ///
    /// # Note
    ///
    /// The returned `LinesCodec` will not have an upper bound on the length
    /// of a buffered line. See the documentation for [`new_with_max_length`]
    /// for information on why this could be a potential security risk.
    ///
    /// [`new_with_max_length`]: crate::codec::LinesCodec::new_with_max_length()
    pub fn new() -> LinesCodec {
        LinesCodec {
            max_length: usize::MAX,
            next_index: 0,
            is_discarding: false,
        }
    }

    /// Returns a `LinesCodec` with a maximum line length limit.
    ///
    /// If this is set, calls to `LinesCodec::decode` will return a
    /// [`LinesCodecError`] when a line exceeds the length limit. Subsequent calls
    /// will discard up to `limit` bytes from that line until a newline
    /// character is reached, returning `None` until the line over the limit
    /// has been fully discarded. After that point, calls to `decode` will
    /// function as normal.
    ///
    /// # Note
    ///
    /// Setting a length limit is highly recommended for any `LinesCodec` which
    /// will be exposed to untrusted input. Otherwise, the size of the buffer
    /// that holds the line currently being read is unbounded. An attacker could
    /// exploit this unbounded buffer by sending an unbounded amount of input
    /// without any `\n` characters, causing unbounded memory consumption.
    ///
    /// [`LinesCodecError`]: crate::codec::LinesCodecError
    pub fn new_with_max_length(max_length: usize) -> Self {
        LinesCodec {
            max_length,
            next_index: 0,
            is_discarding: false,
        }
    }

    /// Returns the maximum line length when decoding.
    ///
    /// ```
    /// use tokio_util::codec::LinesCodec;
    ///
    /// let codec = LinesCodec::new();
    /// assert_eq!(codec.max_length(), usize::MAX);
    /// ```
    /// ```
    /// use tokio_util::codec::LinesCodec;
    ///
    /// let codec = LinesCodec::new_with_max_length(256);
    /// assert_eq!(codec.max_length(), 256);
    /// ```
    pub fn max_length(&self) -> usize {
        self.max_length
    }
}
/// Interprets `buf` as UTF-8, mapping failure to an `InvalidData` I/O error.
fn utf8(buf: &[u8]) -> Result<&str, io::Error> {
    match str::from_utf8(buf) {
        Ok(s) => Ok(s),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Unable to decode input as UTF8",
        )),
    }
}
/// Trims a single trailing `\r` (from a CRLF line ending), if present.
fn without_carriage_return(s: &[u8]) -> &[u8] {
    s.strip_suffix(b"\r").unwrap_or(s)
}
impl Decoder for LinesCodec {
    type Item = String;
    type Error = LinesCodecError;

    /// Extracts the next `\n`-terminated line from `buf`, stripping the
    /// terminator (and a preceding `\r`, if any). Enforces `max_length`,
    /// switching into discard mode when a line exceeds it.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
        loop {
            // Determine how far into the buffer we'll search for a newline. If
            // there's no max_length set, we'll read to the end of the buffer.
            let read_to = cmp::min(self.max_length.saturating_add(1), buf.len());
            let newline_offset = buf[self.next_index..read_to]
                .iter()
                .position(|b| *b == b'\n');
            match (self.is_discarding, newline_offset) {
                (true, Some(offset)) => {
                    // If we found a newline, discard up to that offset and
                    // then stop discarding. On the next iteration, we'll try
                    // to read a line normally.
                    buf.advance(offset + self.next_index + 1);
                    self.is_discarding = false;
                    self.next_index = 0;
                }
                (true, None) => {
                    // Otherwise, we didn't find a newline, so we'll discard
                    // everything we read. On the next iteration, we'll continue
                    // discarding up to max_len bytes unless we find a newline.
                    buf.advance(read_to);
                    self.next_index = 0;
                    if buf.is_empty() {
                        return Ok(None);
                    }
                }
                (false, Some(offset)) => {
                    // Found a line!
                    let newline_index = offset + self.next_index;
                    self.next_index = 0;
                    let line = buf.split_to(newline_index + 1);
                    // Drop the trailing `\n`, then an optional `\r`.
                    let line = &line[..line.len() - 1];
                    let line = without_carriage_return(line);
                    let line = utf8(line)?;
                    return Ok(Some(line.to_string()));
                }
                (false, None) if buf.len() > self.max_length => {
                    // Reached the maximum length without finding a
                    // newline, return an error and start discarding on the
                    // next call.
                    self.is_discarding = true;
                    return Err(LinesCodecError::MaxLineLengthExceeded);
                }
                (false, None) => {
                    // We didn't find a line or reach the length limit, so the next
                    // call will resume searching at the current offset.
                    self.next_index = read_to;
                    return Ok(None);
                }
            }
        }
    }

    /// Like `decode`, but on EOF also returns any trailing data that lacks a
    /// terminating newline (except a lone `\r`, which is dropped).
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<String>, LinesCodecError> {
        Ok(match self.decode(buf)? {
            Some(frame) => Some(frame),
            None => {
                self.next_index = 0;
                // No terminating newline - return remaining data, if any
                if buf.is_empty() || buf == &b"\r"[..] {
                    None
                } else {
                    let line = buf.split_to(buf.len());
                    let line = without_carriage_return(&line);
                    let line = utf8(line)?;
                    Some(line.to_string())
                }
            }
        })
    }
}
impl<T> Encoder<T> for LinesCodec
where
    T: AsRef<str>,
{
    type Error = LinesCodecError;

    /// Appends `line` followed by a `\n` terminator to `buf`.
    fn encode(&mut self, line: T, buf: &mut BytesMut) -> Result<(), LinesCodecError> {
        let line = line.as_ref();
        // One reservation covers the line plus its newline terminator.
        buf.reserve(line.len() + 1);
        buf.put_slice(line.as_bytes());
        buf.put_u8(b'\n');
        Ok(())
    }
}
impl Default for LinesCodec {
fn default() -> Self {
Self::new()
}
}
/// An error occurred while encoding or decoding a line.
#[derive(Debug)]
pub enum LinesCodecError {
    /// The maximum line length was exceeded.
    // Returned by `decode` when a line grows past `max_length`; subsequent
    // calls discard the rest of the over-long line.
    MaxLineLengthExceeded,
    /// An IO error occurred.
    Io(io::Error),
}
impl fmt::Display for LinesCodecError {
    /// Fixed message for the length-limit case; I/O errors display as
    /// themselves.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LinesCodecError::MaxLineLengthExceeded => f.write_str("max line length exceeded"),
            LinesCodecError::Io(e) => fmt::Display::fmt(e, f),
        }
    }
}
impl From<io::Error> for LinesCodecError {
fn from(e: io::Error) -> LinesCodecError {
LinesCodecError::Io(e)
}
}
impl std::error::Error for LinesCodecError {}

356
vendor/tokio-util/src/codec/mod.rs vendored Normal file
View File

@@ -0,0 +1,356 @@
//! Adaptors from `AsyncRead`/`AsyncWrite` to Stream/Sink
//!
//! Raw I/O objects work with byte sequences, but higher-level code usually
//! wants to batch these into meaningful chunks, called "frames".
//!
//! This module contains adapters to go from streams of bytes, [`AsyncRead`] and
//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`].
//! Framed streams are also known as transports.
//!
//! # Example encoding using `LinesCodec`
//!
//! The following example demonstrates how to use a codec such as [`LinesCodec`] to
//! write framed data. [`FramedWrite`] can be used to achieve this. Data sent to
//! [`FramedWrite`] are first framed according to a specific codec, and then sent to
//! an implementor of [`AsyncWrite`].
//!
//! ```
//! use futures::sink::SinkExt;
//! use tokio_util::codec::LinesCodec;
//! use tokio_util::codec::FramedWrite;
//!
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! let buffer = Vec::new();
//! let messages = vec!["Hello", "World"];
//! let encoder = LinesCodec::new();
//!
//! // FramedWrite is a sink which means you can send values into it
//! // asynchronously.
//! let mut writer = FramedWrite::new(buffer, encoder);
//!
//! // To be able to send values into a FramedWrite, you need to bring the
//! // `SinkExt` trait into scope.
//! writer.send(messages[0]).await.unwrap();
//! writer.send(messages[1]).await.unwrap();
//!
//! let buffer = writer.get_ref();
//!
//! assert_eq!(buffer.as_slice(), "Hello\nWorld\n".as_bytes());
//! # }
//!```
//!
//! # Example decoding using `LinesCodec`
//! The following example demonstrates how to use a codec such as [`LinesCodec`] to
//! read a stream of framed data. [`FramedRead`] can be used to achieve this. [`FramedRead`]
//! will keep reading from an [`AsyncRead`] implementor until a whole frame, according to a codec,
//! can be parsed.
//!
//!```
//! use tokio_stream::StreamExt;
//! use tokio_util::codec::LinesCodec;
//! use tokio_util::codec::FramedRead;
//!
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! let message = "Hello\nWorld".as_bytes();
//! let decoder = LinesCodec::new();
//!
//! // FramedRead can be used to read a stream of values that are framed according to
//! // a codec. FramedRead will read from its input (here `buffer`) until a whole frame
//! // can be parsed.
//! let mut reader = FramedRead::new(message, decoder);
//!
//! // To read values from a FramedRead, you need to bring the
//! // `StreamExt` trait into scope.
//! let frame1 = reader.next().await.unwrap().unwrap();
//! let frame2 = reader.next().await.unwrap().unwrap();
//!
//! assert!(reader.next().await.is_none());
//! assert_eq!(frame1, "Hello");
//! assert_eq!(frame2, "World");
//! # }
//! ```
//!
//! # The Decoder trait
//!
//! A [`Decoder`] is used together with [`FramedRead`] or [`Framed`] to turn an
//! [`AsyncRead`] into a [`Stream`]. The job of the decoder trait is to specify
//! how sequences of bytes are turned into a sequence of frames, and to
//! determine where the boundaries between frames are. The job of the
//! `FramedRead` is to repeatedly switch between reading more data from the IO
//! resource, and asking the decoder whether we have received enough data to
//! decode another frame of data.
//!
//! The main method on the `Decoder` trait is the [`decode`] method. This method
//! takes as argument the data that has been read so far, and when it is called,
//! it will be in one of the following situations:
//!
//! 1. The buffer contains less than a full frame.
//! 2. The buffer contains exactly a full frame.
//! 3. The buffer contains more than a full frame.
//!
//! In the first situation, the decoder should return `Ok(None)`.
//!
//! In the second situation, the decoder should clear the provided buffer and
//! return `Ok(Some(the_decoded_frame))`.
//!
//! In the third situation, the decoder should use a method such as [`split_to`]
//! or [`advance`] to modify the buffer such that the frame is removed from the
//! buffer, but any data in the buffer after that frame should still remain in
//! the buffer. The decoder should also return `Ok(Some(the_decoded_frame))` in
//! this case.
//!
//! Finally the decoder may return an error if the data is invalid in some way.
//! The decoder should _not_ return an error just because it has yet to receive
//! a full frame.
//!
//! It is guaranteed that, from one call to `decode` to another, the provided
//! buffer will contain the exact same data as before, except that if more data
//! has arrived through the IO resource, that data will have been appended to
//! the buffer. This means that reading frames from a `FramedRead` is
//! essentially equivalent to the following loop:
//!
//! ```no_run
//! use tokio::io::AsyncReadExt;
//! # // This uses async_stream to create an example that compiles.
//! # fn foo() -> impl futures_core::Stream<Item = std::io::Result<bytes::BytesMut>> { async_stream::try_stream! {
//! # use tokio_util::codec::Decoder;
//! # let mut decoder = tokio_util::codec::BytesCodec::new();
//! # let io_resource = &mut &[0u8, 1, 2, 3][..];
//!
//! let mut buf = bytes::BytesMut::new();
//! loop {
//! // The read_buf call will append to buf rather than overwrite existing data.
//! let len = io_resource.read_buf(&mut buf).await?;
//!
//! if len == 0 {
//! while let Some(frame) = decoder.decode_eof(&mut buf)? {
//! yield frame;
//! }
//! break;
//! }
//!
//! while let Some(frame) = decoder.decode(&mut buf)? {
//! yield frame;
//! }
//! }
//! # }}
//! ```
//! The example above uses `yield` whenever the `Stream` produces an item.
//!
//! ## Example decoder
//!
//! As an example, consider a protocol that can be used to send strings where
//! each frame is a four byte integer that contains the length of the frame,
//! followed by that many bytes of string data. The decoder fails with an error
//! if the string data is not valid utf-8 or too long.
//!
//! Such a decoder can be written like this:
//! ```
//! use tokio_util::codec::Decoder;
//! use bytes::{BytesMut, Buf};
//!
//! struct MyStringDecoder {}
//!
//! const MAX: usize = 8 * 1024 * 1024;
//!
//! impl Decoder for MyStringDecoder {
//! type Item = String;
//! type Error = std::io::Error;
//!
//! fn decode(
//! &mut self,
//! src: &mut BytesMut
//! ) -> Result<Option<Self::Item>, Self::Error> {
//! if src.len() < 4 {
//! // Not enough data to read length marker.
//! return Ok(None);
//! }
//!
//! // Read length marker.
//! let mut length_bytes = [0u8; 4];
//! length_bytes.copy_from_slice(&src[..4]);
//! let length = u32::from_le_bytes(length_bytes) as usize;
//!
//! // Check that the length is not too large to avoid a denial of
//! // service attack where the server runs out of memory.
//! if length > MAX {
//! return Err(std::io::Error::new(
//! std::io::ErrorKind::InvalidData,
//! format!("Frame of length {} is too large.", length)
//! ));
//! }
//!
//! if src.len() < 4 + length {
//! // The full string has not yet arrived.
//! //
//! // We reserve more space in the buffer. This is not strictly
//! // necessary, but is a good idea performance-wise.
//! src.reserve(4 + length - src.len());
//!
//! // We inform the Framed that we need more bytes to form the next
//! // frame.
//! return Ok(None);
//! }
//!
//! // Use advance to modify src such that it no longer contains
//! // this frame.
//! let data = src[4..4 + length].to_vec();
//! src.advance(4 + length);
//!
//! // Convert the data to a string, or fail if it is not valid utf-8.
//! match String::from_utf8(data) {
//! Ok(string) => Ok(Some(string)),
//! Err(utf8_error) => {
//! Err(std::io::Error::new(
//! std::io::ErrorKind::InvalidData,
//! utf8_error.utf8_error(),
//! ))
//! },
//! }
//! }
//! }
//! ```
//!
//! # The Encoder trait
//!
//! An [`Encoder`] is used together with [`FramedWrite`] or [`Framed`] to turn
//! an [`AsyncWrite`] into a [`Sink`]. The job of the encoder trait is to
//! specify how frames are turned into a sequence of bytes. The job of the
//! `FramedWrite` is to take the resulting sequence of bytes and write it to the
//! IO resource.
//!
//! The main method on the `Encoder` trait is the [`encode`] method. This method
//! takes an item that is being written, and a buffer to write the item to. The
//! buffer may already contain data, and in this case, the encoder should append
//! the new frame to the buffer rather than overwrite the existing data.
//!
//! It is guaranteed that, from one call to `encode` to another, the provided
//! buffer will contain the exact same data as before, except that some of the
//! data may have been removed from the front of the buffer. Writing to a
//! `FramedWrite` is essentially equivalent to the following loop:
//!
//! ```no_run
//! use tokio::io::AsyncWriteExt;
//! use bytes::Buf; // for advance
//! # use tokio_util::codec::Encoder;
//! # async fn next_frame() -> bytes::Bytes { bytes::Bytes::new() }
//! # async fn no_more_frames() { }
//! # #[tokio::main] async fn main() -> std::io::Result<()> {
//! # let mut io_resource = tokio::io::sink();
//! # let mut encoder = tokio_util::codec::BytesCodec::new();
//!
//! const MAX: usize = 8192;
//!
//! let mut buf = bytes::BytesMut::new();
//! loop {
//! tokio::select! {
//! num_written = io_resource.write(&buf), if !buf.is_empty() => {
//! buf.advance(num_written?);
//! },
//! frame = next_frame(), if buf.len() < MAX => {
//! encoder.encode(frame, &mut buf)?;
//! },
//! _ = no_more_frames() => {
//! io_resource.write_all(&buf).await?;
//! io_resource.shutdown().await?;
//! return Ok(());
//! },
//! }
//! }
//! # }
//! ```
//! Here the `next_frame` method corresponds to any frames you write to the
//! `FramedWrite`. The `no_more_frames` method corresponds to closing the
//! `FramedWrite` with [`SinkExt::close`].
//!
//! ## Example encoder
//!
//! As an example, consider a protocol that can be used to send strings where
//! each frame is a four byte integer that contains the length of the frame,
//! followed by that many bytes of string data. The encoder will fail if the
//! string is too long.
//!
//! Such an encoder can be written like this:
//! ```
//! use tokio_util::codec::Encoder;
//! use bytes::BytesMut;
//!
//! struct MyStringEncoder {}
//!
//! const MAX: usize = 8 * 1024 * 1024;
//!
//! impl Encoder<String> for MyStringEncoder {
//! type Error = std::io::Error;
//!
//! fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> {
//! // Don't send a string if it is longer than the other end will
//! // accept.
//! if item.len() > MAX {
//! return Err(std::io::Error::new(
//! std::io::ErrorKind::InvalidData,
//! format!("Frame of length {} is too large.", item.len())
//! ));
//! }
//!
//! // Convert the length into a byte array.
//! // The cast to u32 cannot overflow due to the length check above.
//! let len_slice = u32::to_le_bytes(item.len() as u32);
//!
//! // Reserve space in the buffer.
//! dst.reserve(4 + item.len());
//!
//! // Write the length and string to the buffer.
//! dst.extend_from_slice(&len_slice);
//! dst.extend_from_slice(item.as_bytes());
//! Ok(())
//! }
//! }
//! ```
//!
//! [`AsyncRead`]: tokio::io::AsyncRead
//! [`AsyncWrite`]: tokio::io::AsyncWrite
//! [`Stream`]: futures_core::Stream
//! [`Sink`]: futures_sink::Sink
//! [`SinkExt`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html
//! [`SinkExt::close`]: https://docs.rs/futures/0.3/futures/sink/trait.SinkExt.html#method.close
//! [`FramedRead`]: struct@crate::codec::FramedRead
//! [`FramedWrite`]: struct@crate::codec::FramedWrite
//! [`Framed`]: struct@crate::codec::Framed
//! [`Decoder`]: trait@crate::codec::Decoder
//! [`decode`]: fn@crate::codec::Decoder::decode
//! [`encode`]: fn@crate::codec::Encoder::encode
//! [`split_to`]: fn@bytes::BytesMut::split_to
//! [`advance`]: fn@bytes::Buf::advance
mod bytes_codec;
pub use self::bytes_codec::BytesCodec;
mod decoder;
pub use self::decoder::Decoder;
mod encoder;
pub use self::encoder::Encoder;
mod framed_impl;
#[allow(unused_imports)]
pub(crate) use self::framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame};
mod framed;
pub use self::framed::{Framed, FramedParts};
mod framed_read;
pub use self::framed_read::FramedRead;
mod framed_write;
pub use self::framed_write::FramedWrite;
pub mod length_delimited;
pub use self::length_delimited::{LengthDelimitedCodec, LengthDelimitedCodecError};
mod lines_codec;
pub use self::lines_codec::{LinesCodec, LinesCodecError};
mod any_delimiter_codec;
pub use self::any_delimiter_codec::{AnyDelimiterCodec, AnyDelimiterCodecError};

386
vendor/tokio-util/src/compat.rs vendored Normal file
View File

@@ -0,0 +1,386 @@
//! Compatibility between the `tokio::io` and `futures-io` versions of the
//! `AsyncRead` and `AsyncWrite` traits.
//!
//! ## Bridging Tokio and Futures I/O with `compat()`
//!
//! The [`compat()`] function provides a compatibility layer that allows types implementing
//! [`tokio::io::AsyncRead`] or [`tokio::io::AsyncWrite`] to be used as their
//! [`futures::io::AsyncRead`] or [`futures::io::AsyncWrite`] counterparts — and vice versa.
//!
//! This is especially useful when working with libraries that expect I/O types from one ecosystem
//! (usually `futures`) but you are using types from the other (usually `tokio`).
//!
//! ## Compatibility Overview
//!
//! | Inner Type Implements... | `Compat<T>` Implements... |
//! |-----------------------------|-----------------------------|
//! | [`tokio::io::AsyncRead`] | [`futures::io::AsyncRead`] |
//! | [`futures::io::AsyncRead`] | [`tokio::io::AsyncRead`] |
//! | [`tokio::io::AsyncWrite`] | [`futures::io::AsyncWrite`] |
//! | [`futures::io::AsyncWrite`] | [`tokio::io::AsyncWrite`] |
//!
//! ## Feature Flag
//!
//! This functionality is available through the `compat` feature flag:
//!
//! ```toml
//! tokio-util = { version = "...", features = ["compat"] }
//! ```
//!
//! ## Example 1: Tokio -> Futures (`AsyncRead`)
//!
//! This example demonstrates sending data over a [`tokio::net::TcpStream`] and using
//! [`futures::io::AsyncReadExt::read`] from the `futures` crate to read it after adapting the
//! stream via [`compat()`].
//!
//! ```no_run
//! # #[cfg(not(target_family = "wasm"))]
//! # {
//! use tokio::net::{TcpListener, TcpStream};
//! use tokio::io::AsyncWriteExt;
//! use tokio_util::compat::TokioAsyncReadCompatExt;
//! use futures::io::AsyncReadExt;
//!
//! #[tokio::main]
//! async fn main() -> std::io::Result<()> {
//! let listener = TcpListener::bind("127.0.0.1:8081").await?;
//!
//! tokio::spawn(async {
//! let mut client = TcpStream::connect("127.0.0.1:8081").await.unwrap();
//! client.write_all(b"Hello World").await.unwrap();
//! });
//!
//! let (stream, _) = listener.accept().await?;
//!
//! // Adapt `tokio::TcpStream` to be used with `futures::io::AsyncReadExt`
//! let mut compat_stream = stream.compat();
//! let mut buffer = [0; 20];
//! let n = compat_stream.read(&mut buffer).await?;
//! println!("Received: {}", String::from_utf8_lossy(&buffer[..n]));
//!
//! Ok(())
//! }
//! # }
//! ```
//!
//! ## Example 2: Futures -> Tokio (`AsyncRead`)
//!
//! The reverse is also possible: you can take a [`futures::io::AsyncRead`] (e.g. a cursor) and
//! adapt it to be used with [`tokio::io::AsyncReadExt::read_to_end`]
//!
//! ```
//! # #[cfg(not(target_family = "wasm"))]
//! # {
//! use futures::io::Cursor;
//! use tokio_util::compat::FuturesAsyncReadCompatExt;
//! use tokio::io::AsyncReadExt;
//!
//! fn main() {
//! let future = async {
//! let reader = Cursor::new(b"Hello from futures");
//! let mut compat_reader = reader.compat();
//! let mut buf = Vec::new();
//! compat_reader.read_to_end(&mut buf).await.unwrap();
//! assert_eq!(&buf, b"Hello from futures");
//! };
//!
//! // Run the future inside a Tokio runtime
//! tokio::runtime::Runtime::new().unwrap().block_on(future);
//! }
//! # }
//! ```
//!
//! ## Common Use Cases
//!
//! - Using `tokio` sockets with `async-tungstenite`, `async-compression`, or `futures-rs`-based
//! libraries.
//! - Bridging I/O interfaces between mixed-ecosystem libraries.
//! - Avoiding rewrites or duplication of I/O code in async environments.
//!
//! ## See Also
//!
//! - [`Compat`] type
//! - [`TokioAsyncReadCompatExt`]
//! - [`FuturesAsyncReadCompatExt`]
//! - [`tokio::io`]
//! - [`futures::io`]
//!
//! [`futures::io`]: https://docs.rs/futures/latest/futures/io/
//! [`futures::io::AsyncRead`]: https://docs.rs/futures/latest/futures/io/trait.AsyncRead.html
//! [`futures::io::AsyncWrite`]: https://docs.rs/futures/latest/futures/io/trait.AsyncWrite.html
//! [`futures::io::AsyncReadExt::read`]: https://docs.rs/futures/latest/futures/io/trait.AsyncReadExt.html#method.read
//! [`compat()`]: TokioAsyncReadCompatExt::compat
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
pin_project! {
/// A compatibility layer that allows conversion between the
/// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
#[derive(Copy, Clone, Debug)]
pub struct Compat<T> {
#[pin]
inner: T,
seek_pos: Option<io::SeekFrom>,
}
}
/// Extension trait that allows converting a type implementing
/// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
/// Wraps `self` with a compatibility layer that implements
/// `tokio_io::AsyncRead`.
fn compat(self) -> Compat<Self>
where
Self: Sized,
{
Compat::new(self)
}
}
impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}
/// Extension trait that allows converting a type implementing
/// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
/// Wraps `self` with a compatibility layer that implements
/// `tokio::io::AsyncWrite`.
fn compat_write(self) -> Compat<Self>
where
Self: Sized,
{
Compat::new(self)
}
}
impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
/// Extension trait that allows converting a type implementing
/// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
/// Wraps `self` with a compatibility layer that implements
/// `futures_io::AsyncRead`.
fn compat(self) -> Compat<Self>
where
Self: Sized,
{
Compat::new(self)
}
}
impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
/// Extension trait that allows converting a type implementing
/// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
/// Wraps `self` with a compatibility layer that implements
/// `futures_io::AsyncWrite`.
fn compat_write(self) -> Compat<Self>
where
Self: Sized,
{
Compat::new(self)
}
}
impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
// === impl Compat ===
impl<T> Compat<T> {
fn new(inner: T) -> Self {
Self {
inner,
seek_pos: None,
}
}
/// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
/// contained within.
pub fn get_ref(&self) -> &T {
&self.inner
}
/// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
/// contained within.
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
/// Returns the wrapped item.
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T> tokio::io::AsyncRead for Compat<T>
where
T: futures_io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
// We can't trust the inner type to not peak at the bytes,
// so we must defensively initialize the buffer.
let slice = buf.initialize_unfilled();
let n = ready!(futures_io::AsyncRead::poll_read(
self.project().inner,
cx,
slice
))?;
buf.advance(n);
Poll::Ready(Ok(()))
}
}
impl<T> futures_io::AsyncRead for Compat<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
slice: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut buf = tokio::io::ReadBuf::new(slice);
ready!(tokio::io::AsyncRead::poll_read(
self.project().inner,
cx,
&mut buf
))?;
Poll::Ready(Ok(buf.filled().len()))
}
}
impl<T> tokio::io::AsyncBufRead for Compat<T>
where
T: futures_io::AsyncBufRead,
{
fn poll_fill_buf<'a>(
self: Pin<&'a mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<&'a [u8]>> {
futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
futures_io::AsyncBufRead::consume(self.project().inner, amt)
}
}
impl<T> futures_io::AsyncBufRead for Compat<T>
where
T: tokio::io::AsyncBufRead,
{
fn poll_fill_buf<'a>(
self: Pin<&'a mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<&'a [u8]>> {
tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
}
fn consume(self: Pin<&mut Self>, amt: usize) {
tokio::io::AsyncBufRead::consume(self.project().inner, amt)
}
}
impl<T> tokio::io::AsyncWrite for Compat<T>
where
T: futures_io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
futures_io::AsyncWrite::poll_close(self.project().inner, cx)
}
}
impl<T> futures_io::AsyncWrite for Compat<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}
}
impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
fn poll_seek(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
pos: io::SeekFrom,
) -> Poll<io::Result<u64>> {
if self.seek_pos != Some(pos) {
// Ensure previous seeks have finished before starting a new one
ready!(self.as_mut().project().inner.poll_complete(cx))?;
self.as_mut().project().inner.start_seek(pos)?;
*self.as_mut().project().seek_pos = Some(pos);
}
let res = ready!(self.as_mut().project().inner.poll_complete(cx));
*self.as_mut().project().seek_pos = None;
Poll::Ready(res)
}
}
impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
*self.as_mut().project().seek_pos = Some(pos);
Ok(())
}
fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let pos = match self.seek_pos {
None => {
// tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
// We don't have to guarantee that the value returned by
// poll_complete called without start_seek is correct,
// so we'll return 0.
return Poll::Ready(Ok(0));
}
Some(pos) => pos,
};
let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
*self.as_mut().project().seek_pos = None;
Poll::Ready(res)
}
}
#[cfg(unix)]
impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
self.inner.as_raw_fd()
}
}
#[cfg(windows)]
impl<T: std::os::windows::io::AsRawHandle> std::os::windows::io::AsRawHandle for Compat<T> {
fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
self.inner.as_raw_handle()
}
}

199
vendor/tokio-util/src/context.rs vendored Normal file
View File

@@ -0,0 +1,199 @@
//! Tokio context aware futures utilities.
//!
//! This module includes utilities around integrating tokio with other runtimes
//! by allowing the context to be attached to futures. This allows spawning
//! futures on other executors while still using tokio to drive them. This
//! can be useful if you need to use a tokio based library in an executor/runtime
//! that does not provide a tokio context.
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::runtime::{Handle, Runtime};
pin_project! {
/// `TokioContext` allows running futures that must be inside Tokio's
/// context on a non-Tokio runtime.
///
/// It contains a [`Handle`] to the runtime. A handle to the runtime can be
/// obtain by calling the [`Runtime::handle()`] method.
///
/// Note that the `TokioContext` wrapper only works if the `Runtime` it is
/// connected to has not yet been destroyed. You must keep the `Runtime`
/// alive until the future has finished executing.
///
/// **Warning:** If `TokioContext` is used together with a [current thread]
/// runtime, that runtime must be inside a call to `block_on` for the
/// wrapped future to work. For this reason, it is recommended to use a
/// [multi thread] runtime, even if you configure it to only spawn one
/// worker thread.
///
/// # Examples
///
/// This example creates two runtimes, but only [enables time] on one of
/// them. It then uses the context of the runtime with the timer enabled to
/// execute a [`sleep`] future on the runtime with timing disabled.
/// ```
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use tokio::time::{sleep, Duration};
/// use tokio_util::context::RuntimeExt;
///
/// // This runtime has timers enabled.
/// let rt = tokio::runtime::Builder::new_multi_thread()
/// .enable_all()
/// .build()
/// .unwrap();
///
/// // This runtime has timers disabled.
/// let rt2 = tokio::runtime::Builder::new_multi_thread()
/// .build()
/// .unwrap();
///
/// // Wrap the sleep future in the context of rt.
/// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
///
/// // Execute the future on rt2.
/// rt2.block_on(fut);
/// # }
/// ```
///
/// [`Handle`]: struct@tokio::runtime::Handle
/// [`Runtime::handle()`]: fn@tokio::runtime::Runtime::handle
/// [`RuntimeExt`]: trait@crate::context::RuntimeExt
/// [`new_static`]: fn@Self::new_static
/// [`sleep`]: fn@tokio::time::sleep
/// [current thread]: fn@tokio::runtime::Builder::new_current_thread
/// [enables time]: fn@tokio::runtime::Builder::enable_time
/// [multi thread]: fn@tokio::runtime::Builder::new_multi_thread
pub struct TokioContext<F> {
#[pin]
inner: F,
handle: Handle,
}
}
impl<F> TokioContext<F> {
/// Associate the provided future with the context of the runtime behind
/// the provided `Handle`.
///
/// This constructor uses a `'static` lifetime to opt-out of checking that
/// the runtime still exists.
///
/// # Examples
///
/// This is the same as the example above, but uses the `new` constructor
/// rather than [`RuntimeExt::wrap`].
///
/// [`RuntimeExt::wrap`]: fn@RuntimeExt::wrap
///
/// ```
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use tokio::time::{sleep, Duration};
/// use tokio_util::context::TokioContext;
///
/// // This runtime has timers enabled.
/// let rt = tokio::runtime::Builder::new_multi_thread()
/// .enable_all()
/// .build()
/// .unwrap();
///
/// // This runtime has timers disabled.
/// let rt2 = tokio::runtime::Builder::new_multi_thread()
/// .build()
/// .unwrap();
///
/// let fut = TokioContext::new(
/// async { sleep(Duration::from_millis(2)).await },
/// rt.handle().clone(),
/// );
///
/// // Execute the future on rt2.
/// rt2.block_on(fut);
/// # }
/// ```
pub fn new(future: F, handle: Handle) -> TokioContext<F> {
TokioContext {
inner: future,
handle,
}
}
/// Obtain a reference to the handle inside this `TokioContext`.
pub fn handle(&self) -> &Handle {
&self.handle
}
/// Remove the association between the Tokio runtime and the wrapped future.
pub fn into_inner(self) -> F {
self.inner
}
}
impl<F: Future> Future for TokioContext<F> {
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let me = self.project();
let handle = me.handle;
let fut = me.inner;
let _enter = handle.enter();
fut.poll(cx)
}
}
/// Extension trait that simplifies bundling a `Handle` with a `Future`.
pub trait RuntimeExt {
/// Create a [`TokioContext`] that wraps the provided future and runs it in
/// this runtime's context.
///
/// # Examples
///
/// This example creates two runtimes, but only [enables time] on one of
/// them. It then uses the context of the runtime with the timer enabled to
/// execute a [`sleep`] future on the runtime with timing disabled.
///
/// ```
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use tokio::time::{sleep, Duration};
/// use tokio_util::context::RuntimeExt;
///
/// // This runtime has timers enabled.
/// let rt = tokio::runtime::Builder::new_multi_thread()
/// .enable_all()
/// .build()
/// .unwrap();
///
/// // This runtime has timers disabled.
/// let rt2 = tokio::runtime::Builder::new_multi_thread()
/// .build()
/// .unwrap();
///
/// // Wrap the sleep future in the context of rt.
/// let fut = rt.wrap(async { sleep(Duration::from_millis(2)).await });
///
/// // Execute the future on rt2.
/// rt2.block_on(fut);
/// # }
/// ```
///
/// [`TokioContext`]: struct@crate::context::TokioContext
/// [`sleep`]: fn@tokio::time::sleep
/// [enables time]: fn@tokio::runtime::Builder::enable_time
fn wrap<F: Future>(&self, fut: F) -> TokioContext<F>;
}
impl RuntimeExt for Runtime {
fn wrap<F: Future>(&self, fut: F) -> TokioContext<F> {
TokioContext {
inner: fut,
handle: self.handle().clone(),
}
}
}

236
vendor/tokio-util/src/either.rs vendored Normal file
View File

@@ -0,0 +1,236 @@
//! Module defining an Either type.
use std::{
future::Future,
io::SeekFrom,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result};
/// Combines two different futures, streams, or sinks having the same associated types into a single type.
///
/// This type implements common asynchronous traits such as [`Future`] and those in Tokio.
///
/// [`Future`]: std::future::Future
///
/// # Example
///
/// The following code will not work:
///
/// ```compile_fail
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
/// #[tokio::main]
/// async fn main() {
/// let result = if some_condition() {
/// some_async_function()
/// } else {
/// other_async_function() // <- Will print: "`if` and `else` have incompatible types"
/// };
///
/// println!("Result is {}", result.await);
/// }
/// ```
///
// This is because although the output types for both futures is the same, the exact future
// types are different, but the compiler must be able to choose a single type for the
// `result` variable.
///
/// When the output type is the same, we can wrap each future in `Either` to avoid the
/// issue:
///
/// ```
/// use tokio_util::either::Either;
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// let result = if some_condition() {
/// Either::Left(some_async_function())
/// } else {
/// Either::Right(other_async_function())
/// };
///
/// let value = result.await;
/// println!("Result is {}", value);
/// # assert_eq!(value, 10);
/// # }
/// ```
#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
#[derive(Debug, Clone)]
pub enum Either<L, R> {
Left(L),
Right(R),
}
/// A small helper macro which reduces amount of boilerplate in the actual trait method implementation.
/// It takes an invocation of method as an argument (e.g. `self.poll(cx)`), and redirects it to either
/// enum variant held in `self`.
macro_rules! delegate_call {
($self:ident.$method:ident($($args:ident),+)) => {
unsafe {
match $self.get_unchecked_mut() {
Self::Left(l) => Pin::new_unchecked(l).$method($($args),+),
Self::Right(r) => Pin::new_unchecked(r).$method($($args),+),
}
}
}
}
impl<L, R, O> Future for Either<L, R>
where
L: Future<Output = O>,
R: Future<Output = O>,
{
type Output = O;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
delegate_call!(self.poll(cx))
}
}
impl<L, R> AsyncRead for Either<L, R>
where
L: AsyncRead,
R: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
delegate_call!(self.poll_read(cx, buf))
}
}
impl<L, R> AsyncBufRead for Either<L, R>
where
L: AsyncBufRead,
R: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
delegate_call!(self.poll_fill_buf(cx))
}
fn consume(self: Pin<&mut Self>, amt: usize) {
delegate_call!(self.consume(amt));
}
}
impl<L, R> AsyncSeek for Either<L, R>
where
L: AsyncSeek,
R: AsyncSeek,
{
fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> {
delegate_call!(self.start_seek(position))
}
fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
delegate_call!(self.poll_complete(cx))
}
}
impl<L, R> AsyncWrite for Either<L, R>
where
L: AsyncWrite,
R: AsyncWrite,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
delegate_call!(self.poll_write(cx, buf))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
delegate_call!(self.poll_flush(cx))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
delegate_call!(self.poll_shutdown(cx))
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::result::Result<usize, std::io::Error>> {
delegate_call!(self.poll_write_vectored(cx, bufs))
}
fn is_write_vectored(&self) -> bool {
match self {
Self::Left(l) => l.is_write_vectored(),
Self::Right(r) => r.is_write_vectored(),
}
}
}
impl<L, R> futures_core::stream::Stream for Either<L, R>
where
L: futures_core::stream::Stream,
R: futures_core::stream::Stream<Item = L::Item>,
{
type Item = L::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
delegate_call!(self.poll_next(cx))
}
}
impl<L, R, Item, Error> futures_sink::Sink<Item> for Either<L, R>
where
L: futures_sink::Sink<Item, Error = Error>,
R: futures_sink::Sink<Item, Error = Error>,
{
type Error = Error;
fn poll_ready(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
delegate_call!(self.poll_ready(cx))
}
fn start_send(self: Pin<&mut Self>, item: Item) -> std::result::Result<(), Self::Error> {
delegate_call!(self.start_send(item))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
delegate_call!(self.poll_flush(cx))
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
delegate_call!(self.poll_close(cx))
}
}
#[cfg(all(test, not(loom)))]
mod tests {
use super::*;
use tokio::io::{repeat, AsyncReadExt, Repeat};
use tokio_stream::{once, Once, StreamExt};
#[tokio::test]
async fn either_is_stream() {
let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1));
assert_eq!(Some(1u32), either.next().await);
}
#[tokio::test]
async fn either_is_async_read() {
let mut buffer = [0; 3];
let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101));
either.read_exact(&mut buffer).await.unwrap();
assert_eq!(buffer, [0b101, 0b101, 0b101]);
}
}

137
vendor/tokio-util/src/future.rs vendored Normal file
View File

@@ -0,0 +1,137 @@
//! An extension trait for Futures that provides a variety of convenient adapters.
mod with_cancellation_token;
use with_cancellation_token::{WithCancellationTokenFuture, WithCancellationTokenFutureOwned};
use std::future::Future;
use crate::sync::CancellationToken;
/// A trait which contains a variety of convenient adapters and utilities for `Future`s.
pub trait FutureExt: Future {
    // The timeout adapters below are only compiled in when the "time"
    // feature is enabled (see the `cfg_time!` macro).
    cfg_time! {
        /// A wrapper around [`tokio::time::timeout`], with the advantage that it is easier to write
        /// fluent call chains.
        ///
        /// # Examples
        ///
        /// ```rust
        /// use tokio::{sync::oneshot, time::Duration};
        /// use tokio_util::future::FutureExt;
        ///
        /// # async fn dox() {
        /// let (_tx, rx) = oneshot::channel::<()>();
        ///
        /// let res = rx.timeout(Duration::from_millis(10)).await;
        /// assert!(res.is_err());
        /// # }
        /// ```
        #[track_caller]
        fn timeout(self, timeout: std::time::Duration) -> tokio::time::Timeout<Self>
        where
            Self: Sized,
        {
            tokio::time::timeout(timeout, self)
        }

        /// A wrapper around [`tokio::time::timeout_at`], with the advantage that it is easier to write
        /// fluent call chains.
        ///
        /// # Examples
        ///
        /// ```rust
        /// use tokio::{sync::oneshot, time::{Duration, Instant}};
        /// use tokio_util::future::FutureExt;
        ///
        /// # async fn dox() {
        /// let (_tx, rx) = oneshot::channel::<()>();
        /// let deadline = Instant::now() + Duration::from_millis(10);
        ///
        /// let res = rx.timeout_at(deadline).await;
        /// assert!(res.is_err());
        /// # }
        /// ```
        fn timeout_at(self, deadline: tokio::time::Instant) -> tokio::time::Timeout<Self>
        where
            Self: Sized,
        {
            tokio::time::timeout_at(deadline, self)
        }
    }

    /// Similar to [`CancellationToken::run_until_cancelled`],
    /// but with the advantage that it is easier to write fluent call chains.
    ///
    /// # Fairness
    ///
    /// Calling this on an already-cancelled token directly returns `None`.
    /// For all subsequent polls, in case of concurrent completion and
    /// cancellation, this is biased towards the `self` future completion.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tokio::sync::oneshot;
    /// use tokio_util::future::FutureExt;
    /// use tokio_util::sync::CancellationToken;
    ///
    /// # async fn dox() {
    /// let (_tx, rx) = oneshot::channel::<()>();
    /// let token = CancellationToken::new();
    /// let token_clone = token.clone();
    /// tokio::spawn(async move {
    ///     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    ///     token.cancel();
    /// });
    /// assert!(rx.with_cancellation_token(&token_clone).await.is_none())
    /// # }
    /// ```
    fn with_cancellation_token(
        self,
        cancellation_token: &CancellationToken,
    ) -> WithCancellationTokenFuture<'_, Self>
    where
        Self: Sized,
    {
        WithCancellationTokenFuture::new(cancellation_token, self)
    }

    /// Similar to [`CancellationToken::run_until_cancelled_owned`],
    /// but with the advantage that it is easier to write fluent call chains.
    ///
    /// # Fairness
    ///
    /// Calling this on an already-cancelled token directly returns `None`.
    /// For all subsequent polls, in case of concurrent completion and
    /// cancellation, this is biased towards the `self` future completion.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use tokio::sync::oneshot;
    /// use tokio_util::future::FutureExt;
    /// use tokio_util::sync::CancellationToken;
    ///
    /// # async fn dox() {
    /// let (_tx, rx) = oneshot::channel::<()>();
    /// let token = CancellationToken::new();
    /// let token_clone = token.clone();
    /// tokio::spawn(async move {
    ///     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    ///     token.cancel();
    /// });
    /// assert!(rx.with_cancellation_token_owned(token_clone).await.is_none())
    /// # }
    /// ```
    fn with_cancellation_token_owned(
        self,
        cancellation_token: CancellationToken,
    ) -> WithCancellationTokenFutureOwned<Self>
    where
        Self: Sized,
    {
        WithCancellationTokenFutureOwned::new(cancellation_token, self)
    }
}

// Blanket implementation: every `Future` (sized or not) gets these adapters.
impl<T: Future + ?Sized> FutureExt for T {}

View File

@@ -0,0 +1,79 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use pin_project_lite::pin_project;
use crate::sync::{CancellationToken, RunUntilCancelledFuture, RunUntilCancelledFutureOwned};
pin_project! {
    /// A [`Future`] that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled or a given [`Future`] gets resolved.
    ///
    /// This future is immediately resolved if the corresponding [`CancellationToken`]
    /// is already cancelled, otherwise, in case of concurrent completion and
    /// cancellation, this is biased towards the future completion.
    #[must_use = "futures do nothing unless polled"]
    pub struct WithCancellationTokenFuture<'a, F: Future> {
        // `None` when the token was already cancelled at construction time;
        // polling then resolves immediately with `None`.
        #[pin]
        run_until_cancelled: Option<RunUntilCancelledFuture<'a, F>>
    }
}
impl<'a, F: Future> WithCancellationTokenFuture<'a, F> {
    /// Wraps `future` so that it runs until `cancellation_token` fires.
    ///
    /// When the token is already cancelled, no inner future is stored and
    /// the resulting future resolves immediately with `None`.
    pub(crate) fn new(cancellation_token: &'a CancellationToken, future: F) -> Self {
        let run_until_cancelled = if cancellation_token.is_cancelled() {
            None
        } else {
            Some(RunUntilCancelledFuture::new(cancellation_token, future))
        };
        Self { run_until_cancelled }
    }
}
impl<'a, F: Future> Future for WithCancellationTokenFuture<'a, F> {
    type Output = Option<F::Output>;

    /// Delegates to the inner `RunUntilCancelledFuture` when one is present;
    /// otherwise the token was cancelled up front, so resolve with `None`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if let Some(inner) = self.project().run_until_cancelled.as_pin_mut() {
            inner.poll(cx)
        } else {
            Poll::Ready(None)
        }
    }
}
pin_project! {
    /// A [`Future`] that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled or a given [`Future`] gets resolved.
    ///
    /// This future is immediately resolved if the corresponding [`CancellationToken`]
    /// is already cancelled, otherwise, in case of concurrent completion and
    /// cancellation, this is biased towards the future completion.
    #[must_use = "futures do nothing unless polled"]
    pub struct WithCancellationTokenFutureOwned<F: Future> {
        // `None` when the token was already cancelled at construction time;
        // polling then resolves immediately with `None`.
        #[pin]
        run_until_cancelled: Option<RunUntilCancelledFutureOwned<F>>
    }
}
impl<F: Future> WithCancellationTokenFutureOwned<F> {
    /// Wraps `future` so that it runs until the owned `cancellation_token`
    /// fires.
    ///
    /// When the token is already cancelled, no inner future is stored and
    /// the resulting future resolves immediately with `None`.
    pub(crate) fn new(cancellation_token: CancellationToken, future: F) -> Self {
        let run_until_cancelled = if cancellation_token.is_cancelled() {
            None
        } else {
            Some(RunUntilCancelledFutureOwned::new(cancellation_token, future))
        };
        Self { run_until_cancelled }
    }
}
impl<F: Future> Future for WithCancellationTokenFutureOwned<F> {
    type Output = Option<F::Output>;

    /// Delegates to the inner `RunUntilCancelledFutureOwned` when one is
    /// present; otherwise the token was cancelled up front, so resolve with
    /// `None`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if let Some(inner) = self.project().run_until_cancelled.as_pin_mut() {
            inner.poll(cx)
        } else {
            Poll::Ready(None)
        }
    }
}

View File

@@ -0,0 +1,76 @@
use bytes::Bytes;
use futures_core::stream::Stream;
use futures_sink::Sink;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
pin_project! {
    /// A helper that wraps a [`Sink`]`<`[`Bytes`]`>` and converts it into a
    /// [`Sink`]`<&'a [u8]>` by copying each byte slice into an owned [`Bytes`].
    ///
    /// See the documentation for [`SinkWriter`] for an example.
    ///
    /// [`Bytes`]: bytes::Bytes
    /// [`SinkWriter`]: crate::io::SinkWriter
    /// [`Sink`]: futures_sink::Sink
    #[derive(Debug)]
    pub struct CopyToBytes<S> {
        // The wrapped sink; also forwarded to directly when `S` is a `Stream`.
        #[pin]
        inner: S,
    }
}
impl<S> CopyToBytes<S> {
/// Creates a new [`CopyToBytes`].
pub fn new(inner: S) -> Self {
Self { inner }
}
/// Gets a reference to the underlying sink.
pub fn get_ref(&self) -> &S {
&self.inner
}
/// Gets a mutable reference to the underlying sink.
pub fn get_mut(&mut self) -> &mut S {
&mut self.inner
}
/// Consumes this [`CopyToBytes`], returning the underlying sink.
pub fn into_inner(self) -> S {
self.inner
}
}
impl<'a, S> Sink<&'a [u8]> for CopyToBytes<S>
where
    S: Sink<Bytes>,
{
    type Error = S::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_ready(cx)
    }

    // Copy the borrowed slice into an owned `Bytes` before forwarding it.
    fn start_send(self: Pin<&mut Self>, item: &'a [u8]) -> Result<(), Self::Error> {
        let owned = Bytes::copy_from_slice(item);
        self.project().inner.start_send(owned)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_close(cx)
    }
}
impl<S: Stream> Stream for CopyToBytes<S> {
    type Item = S::Item;

    // When the wrapped value is also a stream, pass polling straight through.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.inner.poll_next(cx)
    }
}

179
vendor/tokio-util/src/io/inspect.rs vendored Normal file
View File

@@ -0,0 +1,179 @@
use pin_project_lite::pin_project;
use std::io::{IoSlice, Result};
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pin_project! {
    /// An adapter that lets you inspect the data that's being read.
    ///
    /// This is useful for things like hashing data as it's read in.
    pub struct InspectReader<R, F> {
        // Underlying reader; all I/O is delegated to it.
        #[pin]
        reader: R,
        // Callback invoked with each chunk of newly read bytes.
        f: F,
    }
}
impl<R, F> InspectReader<R, F> {
    /// Wraps `reader`, arranging for `f` to observe the new data supplied by
    /// each read call.
    ///
    /// The closure will only be called with an empty slice if the inner reader
    /// returns without reading data into the buffer. This happens at EOF, or if
    /// `poll_read` is called with a zero-size buffer.
    pub fn new(reader: R, f: F) -> InspectReader<R, F>
    where
        R: AsyncRead,
        F: FnMut(&[u8]),
    {
        Self { reader, f }
    }

    /// Consumes the adapter, returning the reader it was wrapping.
    pub fn into_inner(self) -> R {
        self.reader
    }
}
impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let this = self.project();
        // Remember how much of `buf` was filled before this read so that only
        // the freshly read bytes are handed to the callback.
        let already_filled = buf.filled().len();
        ready!(this.reader.poll_read(cx, buf))?;
        let new_bytes = &buf.filled()[already_filled..];
        (this.f)(new_bytes);
        Poll::Ready(Ok(()))
    }
}
impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> {
    // Writes are not inspected: every write operation is forwarded to the
    // wrapped value unchanged.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, std::io::Error>> {
        let this = self.project();
        this.reader.poll_write(cx, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        let this = self.project();
        this.reader.poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), std::io::Error>> {
        let this = self.project();
        this.reader.poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize>> {
        let this = self.project();
        this.reader.poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.reader.is_write_vectored()
    }
}
pin_project! {
    /// An adapter that lets you inspect the data that's being written.
    ///
    /// This is useful for things like hashing data as it's written out.
    pub struct InspectWriter<W, F> {
        // Underlying writer; all I/O is delegated to it.
        #[pin]
        writer: W,
        // Callback invoked with each chunk of successfully written bytes.
        f: F,
    }
}
impl<W, F> InspectWriter<W, F> {
    /// Wraps `writer`, arranging for `f` to observe the data successfully
    /// written by each write call.
    ///
    /// The closure `f` will never be called with an empty slice. A vectored
    /// write can result in multiple calls to `f` - at most one call to `f` per
    /// buffer supplied to `poll_write_vectored`.
    pub fn new(writer: W, f: F) -> InspectWriter<W, F>
    where
        W: AsyncWrite,
        F: FnMut(&[u8]),
    {
        Self { writer, f }
    }

    /// Consumes the adapter, returning the writer it was wrapping.
    pub fn into_inner(self) -> W {
        self.writer
    }
}
impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
        let me = self.project();
        let res = me.writer.poll_write(cx, buf);
        // Only the bytes the inner writer actually accepted are reported, and
        // `f` is never called with an empty slice (see `InspectWriter::new`).
        if let Poll::Ready(Ok(count)) = res {
            if count != 0 {
                (me.f)(&buf[..count]);
            }
        }
        res
    }
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let me = self.project();
        me.writer.poll_flush(cx)
    }
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
        let me = self.project();
        me.writer.poll_shutdown(cx)
    }
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize>> {
        let me = self.project();
        let res = me.writer.poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(mut count)) = res {
            // Attribute the accepted byte count to the supplied buffers in
            // order, calling `f` at most once per buffer and skipping empty
            // buffers, until the whole count has been accounted for.
            for buf in bufs {
                if count == 0 {
                    break;
                }
                let size = count.min(buf.len());
                if size != 0 {
                    (me.f)(&buf[..size]);
                    count -= size;
                }
            }
        }
        res
    }
    fn is_write_vectored(&self) -> bool {
        self.writer.is_write_vectored()
    }
}
impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> {
    // Reads are not inspected; forward them to the wrapped value unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();
        this.writer.poll_read(cx, buf)
    }
}

35
vendor/tokio-util/src/io/mod.rs vendored Normal file
View File

@@ -0,0 +1,35 @@
//! Helpers for IO related tasks.
//!
//! The stream types are often used in combination with hyper or reqwest, as they
//! allow converting between a hyper [`Body`] and [`AsyncRead`].
//!
//! The [`SyncIoBridge`] type converts from the world of async I/O
//! to synchronous I/O; this may often come up when using synchronous APIs
//! inside [`tokio::task::spawn_blocking`].
//!
//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
//! [`AsyncRead`]: tokio::io::AsyncRead
mod copy_to_bytes;
mod inspect;
mod read_buf;
mod reader_stream;
pub mod simplex;
mod sink_writer;
mod stream_reader;
cfg_io_util! {
mod read_arc;
pub use self::read_arc::read_exact_arc;
mod sync_bridge;
pub use self::sync_bridge::SyncIoBridge;
}
pub use self::copy_to_bytes::CopyToBytes;
pub use self::inspect::{InspectReader, InspectWriter};
pub use self::read_buf::read_buf;
pub use self::reader_stream::ReaderStream;
pub use self::sink_writer::SinkWriter;
pub use self::stream_reader::StreamReader;
pub use crate::util::{poll_read_buf, poll_write_buf};

44
vendor/tokio-util/src/io/read_arc.rs vendored Normal file
View File

@@ -0,0 +1,44 @@
use std::io;
use std::mem::MaybeUninit;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt};
/// Read data from an `AsyncRead` into an `Arc`.
///
/// This uses `Arc::new_uninit_slice` and reads into the resulting uninitialized `Arc`.
///
/// # Example
///
/// ```
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
/// use tokio_util::io::read_exact_arc;
///
/// let read = tokio::io::repeat(42);
///
/// let arc = read_exact_arc(read, 4).await?;
///
/// assert_eq!(&arc[..], &[42; 4]);
/// # Ok(())
/// # }
/// ```
pub async fn read_exact_arc<R: AsyncRead>(read: R, len: usize) -> io::Result<Arc<[u8]>> {
    tokio::pin!(read);
    // TODO(MSRV 1.82): When bumping MSRV, switch to `Arc::new_uninit_slice(len)`. The following is
    // equivalent, and generates the same assembly, but works without requiring MSRV 1.82.
    let arc: Arc<[MaybeUninit<u8>]> = (0..len).map(|_| MaybeUninit::uninit()).collect();
    // TODO(MSRV future): Use `Arc::get_mut_unchecked` once it's stabilized.
    // SAFETY: We're the only owner of the `Arc`, and we keep the `Arc` valid throughout this loop
    // as we write through this reference.
    let mut buf = unsafe { &mut *(Arc::as_ptr(&arc) as *mut [MaybeUninit<u8>]) };
    // Each successful `read_buf` advances `buf` past the bytes it filled, so
    // the loop terminates once the whole slice has been initialized.
    while !buf.is_empty() {
        if read.read_buf(&mut buf).await? == 0 {
            // The reader ran dry before `len` bytes arrived.
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "early eof"));
        }
    }
    // TODO(MSRV 1.82): When bumping MSRV, switch to `arc.assume_init()`. The following is
    // equivalent, and generates the same assembly, but works without requiring MSRV 1.82.
    // SAFETY: This changes `[MaybeUninit<u8>]` to `[u8]`, and we've initialized all the bytes in
    // the loop above.
    Ok(unsafe { Arc::from_raw(Arc::into_raw(arc) as *const [u8]) })
}

65
vendor/tokio-util/src/io/read_buf.rs vendored Normal file
View File

@@ -0,0 +1,65 @@
use bytes::BufMut;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
/// Read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
///
/// [`BufMut`]: bytes::BufMut
///
/// # Example
///
/// ```
/// use bytes::{Bytes, BytesMut};
/// use tokio_stream as stream;
/// use tokio::io::Result;
/// use tokio_util::io::{StreamReader, read_buf};
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a reader from an iterator. This particular reader will always be
/// // ready.
/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
///
/// let mut buf = BytesMut::new();
/// let mut reads = 0;
///
/// loop {
/// reads += 1;
/// let n = read_buf(&mut read, &mut buf).await?;
///
/// if n == 0 {
/// break;
/// }
/// }
///
/// // one or more reads might be necessary.
/// assert!(reads >= 1);
/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
/// # Ok(())
/// # }
/// ```
pub async fn read_buf<R, B>(read: &mut R, buf: &mut B) -> io::Result<usize>
where
    R: AsyncRead + Unpin,
    B: BufMut,
{
    // A tiny one-shot future that forwards every poll to `poll_read_buf`.
    struct PollReadBuf<'a, R, B> {
        reader: &'a mut R,
        buf: &'a mut B,
    }

    impl<'a, R, B> Future for PollReadBuf<'a, R, B>
    where
        R: AsyncRead + Unpin,
        B: BufMut,
    {
        type Output = io::Result<usize>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let me = &mut *self;
            crate::util::poll_read_buf(Pin::new(&mut *me.reader), cx, me.buf)
        }
    }

    PollReadBuf { reader: read, buf }.await
}

View File

@@ -0,0 +1,123 @@
use bytes::{Bytes, BytesMut};
use futures_core::stream::Stream;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
const DEFAULT_CAPACITY: usize = 4096;
pin_project! {
    /// Convert an [`AsyncRead`] into a [`Stream`] of byte chunks.
    ///
    /// This stream is fused. It performs the inverse operation of
    /// [`StreamReader`].
    ///
    /// # Example
    ///
    /// ```
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() -> std::io::Result<()> {
    /// use tokio_stream::StreamExt;
    /// use tokio_util::io::ReaderStream;
    ///
    /// // Create a stream of data.
    /// let data = b"hello, world!";
    /// let mut stream = ReaderStream::new(&data[..]);
    ///
    /// // Read all of the chunks into a vector.
    /// let mut stream_contents = Vec::new();
    /// while let Some(chunk) = stream.next().await {
    ///     stream_contents.extend_from_slice(&chunk?);
    /// }
    ///
    /// // Once the chunks are concatenated, we should have the
    /// // original data.
    /// assert_eq!(stream_contents, data);
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`StreamReader`]: crate::io::StreamReader
    /// [`Stream`]: futures_core::Stream
    #[derive(Debug)]
    pub struct ReaderStream<R> {
        // Reader itself.
        //
        // This value is `None` if the stream has terminated.
        #[pin]
        reader: Option<R>,
        // Working buffer, used to optimize allocations.
        buf: BytesMut,
        // Number of bytes reserved in `buf` whenever it has no spare
        // capacity left before a read (see `poll_next`).
        capacity: usize,
    }
}
impl<R: AsyncRead> ReaderStream<R> {
    /// Convert an [`AsyncRead`] into a [`Stream`] with item type
    /// `Result<Bytes, std::io::Error>`.
    ///
    /// Currently, the default capacity 4096 bytes (4 KiB).
    /// This capacity is not part of the semver contract
    /// and may be tweaked in future releases without
    /// requiring a major version bump.
    ///
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`Stream`]: futures_core::Stream
    pub fn new(reader: R) -> Self {
        Self {
            reader: Some(reader),
            // Start with an empty buffer; space is reserved lazily on the
            // first poll.
            buf: BytesMut::new(),
            capacity: DEFAULT_CAPACITY,
        }
    }

    /// Convert an [`AsyncRead`] into a [`Stream`] with item type
    /// `Result<Bytes, std::io::Error>`,
    /// with a specific read buffer initial capacity.
    ///
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`Stream`]: futures_core::Stream
    pub fn with_capacity(reader: R, capacity: usize) -> Self {
        Self {
            reader: Some(reader),
            buf: BytesMut::with_capacity(capacity),
            capacity,
        }
    }
}
impl<R: AsyncRead> Stream for ReaderStream<R> {
    type Item = std::io::Result<Bytes>;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;
        let mut this = self.as_mut().project();
        // `reader` is `None` once the stream has terminated (EOF or error);
        // the stream is fused, so it keeps returning `None` afterwards.
        let reader = match this.reader.as_pin_mut() {
            Some(r) => r,
            None => return Poll::Ready(None),
        };
        // Re-grow the working buffer if the previous `split` handed out all
        // of its capacity.
        if this.buf.capacity() == 0 {
            this.buf.reserve(*this.capacity);
        }
        match poll_read_buf(reader, cx, &mut this.buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => {
                // Drop the reader to fuse the stream, then surface the error.
                self.project().reader.set(None);
                Poll::Ready(Some(Err(err)))
            }
            Poll::Ready(Ok(0)) => {
                // EOF: drop the reader and terminate the stream.
                self.project().reader.set(None);
                Poll::Ready(None)
            }
            Poll::Ready(Ok(_)) => {
                // Hand off the filled portion as an immutable chunk; `buf`
                // keeps any remaining spare capacity for the next read.
                let chunk = this.buf.split();
                Poll::Ready(Some(Ok(chunk.freeze())))
            }
        }
    }
}

357
vendor/tokio-util/src/io/simplex.rs vendored Normal file
View File

@@ -0,0 +1,357 @@
//! Unidirectional byte-oriented channel.
use crate::util::poll_proceed;
use bytes::Buf;
use bytes::BytesMut;
use futures_core::ready;
use std::io::Error as IoError;
use std::io::ErrorKind as IoErrorKind;
use std::io::IoSlice;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
type IoResult<T> = Result<T, IoError>;
const CLOSED_ERROR_MSG: &str = "simplex has been closed";
/// State shared by [`Sender`] and [`Receiver`] behind an `Arc<Mutex<_>>`.
#[derive(Debug)]
struct Inner {
    /// `poll_write` will return [`Poll::Pending`] if the backpressure boundary is reached
    backpressure_boundary: usize,
    /// either [`Sender`] or [`Receiver`] is closed
    is_closed: bool,
    /// Waker used to wake the [`Receiver`]
    receiver_waker: Option<Waker>,
    /// Waker used to wake the [`Sender`]
    sender_waker: Option<Waker>,
    /// Buffer used to read and write data
    buf: BytesMut,
}
impl Inner {
    /// Creates channel state whose backpressure boundary equals `capacity`.
    fn with_capacity(capacity: usize) -> Self {
        Self {
            backpressure_boundary: capacity,
            is_closed: false,
            receiver_waker: None,
            sender_waker: None,
            buf: BytesMut::with_capacity(capacity),
        }
    }

    /// Stores `waker` as the receiver waker and returns any displaced waker
    /// (to be dropped outside the lock). The clone is skipped when the stored
    /// waker would already wake the same task.
    fn register_receiver_waker(&mut self, waker: &Waker) -> Option<Waker> {
        let same_task = self
            .receiver_waker
            .as_ref()
            .map_or(false, |old| old.will_wake(waker));
        if same_task {
            None
        } else {
            self.receiver_waker.replace(waker.clone())
        }
    }

    /// Stores `waker` as the sender waker and returns any displaced waker
    /// (to be dropped outside the lock). The clone is skipped when the stored
    /// waker would already wake the same task.
    fn register_sender_waker(&mut self, waker: &Waker) -> Option<Waker> {
        let same_task = self
            .sender_waker
            .as_ref()
            .map_or(false, |old| old.will_wake(waker));
        if same_task {
            None
        } else {
            self.sender_waker.replace(waker.clone())
        }
    }

    fn take_receiver_waker(&mut self) -> Option<Waker> {
        self.receiver_waker.take()
    }

    fn take_sender_waker(&mut self) -> Option<Waker> {
        self.sender_waker.take()
    }

    fn is_closed(&self) -> bool {
        self.is_closed
    }

    /// Marks the channel closed from the receive side and hands back the
    /// sender waker (if any) so the caller can wake it after unlocking.
    fn close_receiver(&mut self) -> Option<Waker> {
        self.is_closed = true;
        self.take_sender_waker()
    }

    /// Marks the channel closed from the send side and hands back the
    /// receiver waker (if any) so the caller can wake it after unlocking.
    fn close_sender(&mut self) -> Option<Waker> {
        self.is_closed = true;
        self.take_receiver_waker()
    }
}
/// Receiver of the simplex channel.
///
/// # Cancellation safety
///
/// The `Receiver` is cancel safe. If it is used as the event in a
/// [`tokio::select!`](macro@tokio::select) statement and some other branch
/// completes first, it is guaranteed that no bytes were received on this
/// channel.
///
/// You can still read the remaining data from the buffer
/// even if the write half has been dropped.
/// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details.
#[derive(Debug)]
pub struct Receiver {
    // Channel state shared with the `Sender` side.
    inner: Arc<Mutex<Inner>>,
}
impl Drop for Receiver {
    /// This also wakes up the [`Sender`].
    fn drop(&mut self) {
        // Close under the lock; the temporary guard is released at the end of
        // the statement, so the wake-up below happens outside the lock.
        let waker = self.inner.lock().unwrap().close_receiver();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
impl AsyncRead for Receiver {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<IoResult<()>> {
        // Cooperative-scheduling budget check (see `poll_proceed`).
        let coop = ready!(poll_proceed(cx));
        let mut inner = self.inner.lock().unwrap();
        let to_read = buf.remaining().min(inner.buf.remaining());
        if to_read == 0 {
            // EOF when the peer closed, or nothing to do for a zero-sized
            // destination buffer.
            if inner.is_closed() || buf.remaining() == 0 {
                return Poll::Ready(Ok(()));
            }
            // Nothing buffered yet: park this task and give the sender a
            // chance to run.
            let old_waker = inner.register_receiver_waker(cx.waker());
            let maybe_waker = inner.take_sender_waker();
            // unlock before waking up and dropping old waker
            drop(inner);
            drop(old_waker);
            if let Some(waker) = maybe_waker {
                waker.wake();
            }
            return Poll::Pending;
        }
        // this is to avoid starving other tasks
        coop.made_progress();
        // Copy out of the shared buffer and discard the consumed prefix.
        buf.put_slice(&inner.buf[..to_read]);
        inner.buf.advance(to_read);
        let waker = inner.take_sender_waker();
        drop(inner); // unlock before waking up
        if let Some(waker) = waker {
            waker.wake();
        }
        Poll::Ready(Ok(()))
    }
}
/// Sender of the simplex channel.
///
/// # Cancellation safety
///
/// The `Sender` is cancel safe. If it is used as the event in a
/// [`tokio::select!`](macro@tokio::select) statement and some other branch
/// completes first, it is guaranteed that no bytes were sent on this
/// channel.
///
/// # Shutdown
///
/// See [`Sender::poll_shutdown`].
#[derive(Debug)]
pub struct Sender {
    // Channel state shared with the `Receiver` side.
    inner: Arc<Mutex<Inner>>,
}
impl Drop for Sender {
    /// This also wakes up the [`Receiver`].
    fn drop(&mut self) {
        // Close under the lock; the temporary guard is released at the end of
        // the statement, so the wake-up below happens outside the lock.
        let waker = self.inner.lock().unwrap().close_sender();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
impl AsyncWrite for Sender {
    /// # Errors
    ///
    /// This method will return [`IoErrorKind::BrokenPipe`]
    /// if the channel has been closed.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
        // Cooperative-scheduling budget check (see `poll_proceed`).
        let coop = ready!(poll_proceed(cx));
        let mut inner = self.inner.lock().unwrap();
        if inner.is_closed() {
            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
        }
        // Space left before the backpressure boundary is hit.
        let free = inner
            .backpressure_boundary
            .checked_sub(inner.buf.len())
            .expect("backpressure boundary overflow");
        let to_write = buf.len().min(free);
        if to_write == 0 {
            if buf.is_empty() {
                return Poll::Ready(Ok(0));
            }
            // Buffer is full: park this task and wake the receiver so it can
            // drain the buffer.
            let old_waker = inner.register_sender_waker(cx.waker());
            let waker = inner.take_receiver_waker();
            // unlock before waking up and dropping old waker
            drop(inner);
            drop(old_waker);
            if let Some(waker) = waker {
                waker.wake();
            }
            return Poll::Pending;
        }
        // this is to avoid starving other tasks
        coop.made_progress();
        inner.buf.extend_from_slice(&buf[..to_write]);
        let waker = inner.take_receiver_waker();
        drop(inner); // unlock before waking up
        if let Some(waker) = waker {
            waker.wake();
        }
        Poll::Ready(Ok(to_write))
    }
    /// # Errors
    ///
    /// This method will return [`IoErrorKind::BrokenPipe`]
    /// if the channel has been closed.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        // No data is held outside the shared buffer, so flushing is a no-op
        // apart from the closed check.
        let inner = self.inner.lock().unwrap();
        if inner.is_closed() {
            Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
        } else {
            Poll::Ready(Ok(()))
        }
    }
    /// After returns [`Poll::Ready`], all the following call to
    /// [`Sender::poll_write`] and [`Sender::poll_flush`]
    /// will return error.
    ///
    /// The [`Receiver`] can still be used to read remaining data
    /// until all bytes have been consumed.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        // Mark the channel closed under the lock, then wake the receiver
        // after the lock has been released.
        let maybe_waker = {
            let mut inner = self.inner.lock().unwrap();
            inner.close_sender()
        };
        if let Some(waker) = maybe_waker {
            waker.wake();
        }
        Poll::Ready(Ok(()))
    }
    fn is_write_vectored(&self) -> bool {
        true
    }
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, IoError>> {
        // Cooperative-scheduling budget check (see `poll_proceed`).
        let coop = ready!(poll_proceed(cx));
        let mut inner = self.inner.lock().unwrap();
        if inner.is_closed() {
            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
        }
        // Space left before the backpressure boundary is hit.
        let free = inner
            .backpressure_boundary
            .checked_sub(inner.buf.len())
            .expect("backpressure boundary overflow");
        if free == 0 {
            // Buffer is full: park this task and wake the receiver.
            let old_waker = inner.register_sender_waker(cx.waker());
            let maybe_waker = inner.take_receiver_waker();
            // unlock before waking up and dropping old waker
            drop(inner);
            drop(old_waker);
            if let Some(waker) = maybe_waker {
                waker.wake();
            }
            return Poll::Pending;
        }
        // this is to avoid starving other tasks
        coop.made_progress();
        // Copy from each slice in order until the free space is exhausted.
        let mut rem = free;
        for buf in bufs {
            if rem == 0 {
                break;
            }
            let to_write = buf.len().min(rem);
            if to_write == 0 {
                // `rem` is non-zero here, so the slice itself must be empty.
                assert_ne!(rem, 0);
                assert_eq!(buf.len(), 0);
                continue;
            }
            inner.buf.extend_from_slice(&buf[..to_write]);
            rem -= to_write;
        }
        let waker = inner.take_receiver_waker();
        drop(inner); // unlock before waking up
        if let Some(waker) = waker {
            waker.wake();
        }
        Poll::Ready(Ok(free - rem))
    }
}
/// Create a simplex channel.
///
/// The `capacity` parameter specifies the maximum number of bytes that can be
/// stored in the channel without making the [`Sender::poll_write`]
/// return [`Poll::Pending`].
///
/// # Panics
///
/// This function will panic if `capacity` is zero.
pub fn new(capacity: usize) -> (Sender, Receiver) {
    assert_ne!(capacity, 0, "capacity must be greater than zero");
    // Both halves share the same mutex-protected state.
    let shared = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
    let receiver = Receiver {
        inner: Arc::clone(&shared),
    };
    let sender = Sender { inner: shared };
    (sender, receiver)
}

134
vendor/tokio-util/src/io/sink_writer.rs vendored Normal file
View File

@@ -0,0 +1,134 @@
use futures_sink::Sink;
use futures_core::stream::Stream;
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
pin_project! {
    /// Convert a [`Sink`] of byte chunks into an [`AsyncWrite`].
    ///
    /// Whenever you write to this [`SinkWriter`], the supplied bytes are
    /// forwarded to the inner [`Sink`]. When `shutdown` is called on this
    /// [`SinkWriter`], the inner sink is closed.
    ///
    /// This adapter takes a `Sink<&[u8]>` and provides an [`AsyncWrite`] impl
    /// for it. Because of the lifetime, this trait is relatively rarely
    /// implemented. The main ways to get a `Sink<&[u8]>` that you can use with
    /// this type are:
    ///
    ///  * With the codec module by implementing the [`Encoder`]`<&[u8]>` trait.
    ///  * By wrapping a `Sink<Bytes>` in a [`CopyToBytes`].
    ///  * Manually implementing `Sink<&[u8]>` directly.
    ///
    /// The opposite conversion of implementing `Sink<_>` for an [`AsyncWrite`]
    /// is done using the [`codec`] module.
    ///
    /// # Example
    ///
    /// ```
    /// use bytes::Bytes;
    /// use futures_util::SinkExt;
    /// use std::io::{Error, ErrorKind};
    /// use tokio::io::AsyncWriteExt;
    /// use tokio_util::io::{SinkWriter, CopyToBytes};
    /// use tokio_util::sync::PollSender;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() -> Result<(), Error> {
    /// // We use an mpsc channel as an example of a `Sink<Bytes>`.
    /// let (tx, mut rx) = tokio::sync::mpsc::channel::<Bytes>(1);
    /// let sink = PollSender::new(tx).sink_map_err(|_| Error::from(ErrorKind::BrokenPipe));
    ///
    /// // Wrap it in `CopyToBytes` to get a `Sink<&[u8]>`.
    /// let mut writer = SinkWriter::new(CopyToBytes::new(sink));
    ///
    /// // Write data to our interface...
    /// let data: [u8; 4] = [1, 2, 3, 4];
    /// let _ = writer.write(&data).await?;
    ///
    /// // ... and receive it.
    /// assert_eq!(data.as_slice(), &*rx.recv().await.unwrap());
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`AsyncWrite`]: tokio::io::AsyncWrite
    /// [`CopyToBytes`]: crate::io::CopyToBytes
    /// [`Encoder`]: crate::codec::Encoder
    /// [`Sink`]: futures_sink::Sink
    /// [`codec`]: crate::codec
    #[derive(Debug)]
    pub struct SinkWriter<S> {
        // The wrapped sink; also forwarded to directly when `S` is a
        // `Stream` or `AsyncRead`.
        #[pin]
        inner: S,
    }
}
impl<S> SinkWriter<S> {
/// Creates a new [`SinkWriter`].
pub fn new(sink: S) -> Self {
Self { inner: sink }
}
/// Gets a reference to the underlying sink.
pub fn get_ref(&self) -> &S {
&self.inner
}
/// Gets a mutable reference to the underlying sink.
pub fn get_mut(&mut self) -> &mut S {
&mut self.inner
}
/// Consumes this [`SinkWriter`], returning the underlying sink.
pub fn into_inner(self) -> S {
self.inner
}
}
impl<S, E> AsyncWrite for SinkWriter<S>
where
    for<'a> S: Sink<&'a [u8], Error = E>,
    E: Into<io::Error>,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let mut this = self.project();
        // The sink must report readiness before an item may be sent into it.
        ready!(this.inner.as_mut().poll_ready(cx).map_err(Into::into))?;
        let sent = this.inner.as_mut().start_send(buf);
        Poll::Ready(sent.map(|()| buf.len()).map_err(Into::into))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx).map_err(Into::into)
    }

    // Shutting down the writer closes the underlying sink.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        this.inner.poll_close(cx).map_err(Into::into)
    }
}
impl<S: Stream> Stream for SinkWriter<S> {
    type Item = S::Item;

    // When the wrapped value is also a stream, pass polling straight through.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        this.inner.poll_next(cx)
    }
}
impl<S: AsyncRead> AsyncRead for SinkWriter<S> {
    // When the wrapped value can also read, forward reads unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();
        this.inner.poll_read(cx, buf)
    }
}

View File

@@ -0,0 +1,346 @@
use bytes::Buf;
use futures_core::stream::Stream;
use futures_sink::Sink;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
/// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
///
/// This type performs the inverse operation of [`ReaderStream`].
///
/// This type also implements the [`AsyncBufRead`] trait, so you can use it
/// to read a `Stream` of byte chunks line-by-line. See the examples below.
///
/// # Example
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::{AsyncReadExt, Result};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator.
/// let stream = tokio_stream::iter(vec![
/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
/// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
/// ]);
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Read the next chunk.
/// assert_eq!(read.read(&mut buf).await?, 4);
/// assert_eq!(&buf[..4], [8, 9, 10, 11]);
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// If the stream produces errors which are not [`std::io::Error`],
/// the errors can be converted using [`StreamExt`] to map each
/// element.
///
/// ```
/// use bytes::Bytes;
/// use tokio::io::AsyncReadExt;
/// use tokio_util::io::StreamReader;
/// use tokio_stream::StreamExt;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream from an iterator, including an error.
/// let stream = tokio_stream::iter(vec![
/// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
/// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
/// Result::Err("Something bad happened!")
/// ]);
///
/// // Use StreamExt to map the stream and error to a std::io::Error
/// let stream = stream.map(|result| result.map_err(|err| {
/// std::io::Error::new(std::io::ErrorKind::Other, err)
/// }));
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Read five bytes from the stream.
/// let mut buf = [0; 5];
/// read.read_exact(&mut buf).await?;
/// assert_eq!(buf, [0, 1, 2, 3, 4]);
///
/// // Read the rest of the current chunk.
/// assert_eq!(read.read(&mut buf).await?, 3);
/// assert_eq!(&buf[..3], [5, 6, 7]);
///
/// // Reading the next chunk will produce an error
/// let error = read.read(&mut buf).await.unwrap_err();
/// assert_eq!(error.kind(), std::io::ErrorKind::Other);
/// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!");
///
/// // We have now reached the end.
/// assert_eq!(read.read(&mut buf).await?, 0);
///
/// # Ok(())
/// # }
/// ```
///
/// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks
/// line-by-line. Note that you will usually also need to convert the error
/// type when doing this. See the second example for an explanation of how
/// to do this.
///
/// ```
/// use tokio::io::{Result, AsyncBufReadExt};
/// use tokio_util::io::StreamReader;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a stream of byte chunks.
/// let stream = tokio_stream::iter(vec![
/// Result::Ok(b"The first line.\n".as_slice()),
/// Result::Ok(b"The second line.".as_slice()),
/// Result::Ok(b"\nThe third".as_slice()),
/// Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()),
/// ]);
///
/// // Convert it to an AsyncRead.
/// let mut read = StreamReader::new(stream);
///
/// // Loop through the lines from the `StreamReader`.
/// let mut line = String::new();
/// let mut lines = Vec::new();
/// loop {
/// line.clear();
/// let len = read.read_line(&mut line).await?;
/// if len == 0 { break; }
/// lines.push(line.clone());
/// }
///
/// // Verify that we got the lines we expected.
/// assert_eq!(
/// lines,
/// vec![
/// "The first line.\n",
/// "The second line.\n",
/// "The third line.\n",
/// "The fourth line.\n",
/// "The fifth line.\n",
/// ]
/// );
/// # Ok(())
/// # }
/// ```
///
/// [`AsyncRead`]: tokio::io::AsyncRead
/// [`AsyncBufRead`]: tokio::io::AsyncBufRead
/// [`Stream`]: futures_core::Stream
/// [`ReaderStream`]: crate::io::ReaderStream
/// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html
#[derive(Debug)]
pub struct StreamReader<S, B> {
    // This field is pinned.
    //
    // The underlying stream of byte chunks.
    inner: S,
    // This field is not pinned.
    //
    // The most recently yielded chunk that has not been fully consumed yet.
    // `None` (or an empty chunk) means the next read must poll `inner` again.
    chunk: Option<B>,
}
impl<S, B, E> StreamReader<S, B>
where
S: Stream<Item = Result<B, E>>,
B: Buf,
E: Into<std::io::Error>,
{
/// Convert a stream of byte chunks into an [`AsyncRead`].
///
/// The item should be a [`Result`] with the ok variant being something that
/// implements the [`Buf`] trait (e.g. `Cursor<Vec<u8>>` or `Bytes`). The error
/// should be convertible into an [io error].
///
/// [`Result`]: std::result::Result
/// [`Buf`]: bytes::Buf
/// [io error]: std::io::Error
pub fn new(stream: S) -> Self {
Self {
inner: stream,
chunk: None,
}
}
/// Do we have a chunk and is it non-empty?
fn has_chunk(&self) -> bool {
if let Some(ref chunk) = self.chunk {
chunk.remaining() > 0
} else {
false
}
}
/// Consumes this `StreamReader`, returning a Tuple consisting
/// of the underlying stream and an Option of the internal buffer,
/// which is Some in case the buffer contains elements.
pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
if self.has_chunk() {
(self.inner, self.chunk)
} else {
(self.inner, None)
}
}
}
impl<S, B> StreamReader<S, B> {
    /// Gets a reference to the underlying stream.
    ///
    /// It is inadvisable to directly read from the underlying stream.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }
    /// Gets a mutable reference to the underlying stream.
    ///
    /// It is inadvisable to directly read from the underlying stream.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }
    /// Gets a pinned mutable reference to the underlying stream.
    ///
    /// It is inadvisable to directly read from the underlying stream.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
        self.project().inner
    }
    /// Consumes this `StreamReader`, returning the underlying stream.
    ///
    /// Note that any leftover data in the internal buffer is lost.
    /// If you additionally want access to the internal buffer use
    /// [`into_inner_with_chunk`].
    ///
    /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
    pub fn into_inner(self) -> S {
        self.inner
    }
}
impl<S, B, E> AsyncRead for StreamReader<S, B>
where
    S: Stream<Item = Result<B, E>>,
    B: Buf,
    E: Into<std::io::Error>,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // A zero-capacity read always succeeds immediately; there is nowhere
        // to put data, so do not poll the underlying stream at all.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        // Reuse the `AsyncBufRead` implementation to obtain (or fetch) the
        // current chunk. An empty slice here signals end of stream.
        let inner_buf = match self.as_mut().poll_fill_buf(cx) {
            Poll::Ready(Ok(buf)) => buf,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Pending => return Poll::Pending,
        };
        // Copy as much as fits into the caller's buffer, then mark those
        // bytes as consumed so they are not returned again.
        let len = std::cmp::min(inner_buf.len(), buf.remaining());
        buf.put_slice(&inner_buf[..len]);
        self.consume(len);
        Poll::Ready(Ok(()))
    }
}
impl<S, B, E> AsyncBufRead for StreamReader<S, B>
where
    S: Stream<Item = Result<B, E>>,
    B: Buf,
    E: Into<std::io::Error>,
{
    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        loop {
            if self.as_mut().has_chunk() {
                // This unwrap is very sad, but it can't be avoided.
                // (`has_chunk` just verified the chunk is `Some` and
                // non-empty, but the borrow checker cannot carry that fact
                // across the re-projection.)
                let buf = self.project().chunk.as_ref().unwrap().chunk();
                return Poll::Ready(Ok(buf));
            } else {
                match self.as_mut().project().inner.poll_next(cx) {
                    Poll::Ready(Some(Ok(chunk))) => {
                        // Go around the loop in case the chunk is empty.
                        *self.as_mut().project().chunk = Some(chunk);
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
                    // Stream exhausted: an empty slice signals EOF to callers.
                    Poll::Ready(None) => return Poll::Ready(Ok(&[])),
                    Poll::Pending => return Poll::Pending,
                }
            }
        }
    }
    fn consume(self: Pin<&mut Self>, amt: usize) {
        // `consume(0)` must be a no-op even when no chunk is buffered, so
        // only touch the chunk when bytes were actually taken.
        if amt > 0 {
            self.project()
                .chunk
                .as_mut()
                .expect("No chunk present")
                .advance(amt);
        }
    }
}
// The code below is a manual expansion of the code that pin-project-lite would
// generate. This is done because pin-project-lite fails by hitting the recursion
// limit on this struct. (Every line of documentation is handled recursively by
// the macro.)
// `inner` is the only structurally-pinned field, so `StreamReader` is `Unpin`
// exactly when the stream is.
impl<S: Unpin, B> Unpin for StreamReader<S, B> {}
// Borrowed view of a pinned `StreamReader`: the stream stays pinned while the
// buffered chunk is handed out as an ordinary mutable reference.
struct StreamReaderProject<'a, S, B> {
    inner: Pin<&'a mut S>,
    chunk: &'a mut Option<B>,
}
impl<S, B> StreamReader<S, B> {
    #[inline]
    fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
        // SAFETY: We define that only `inner` should be pinned when `Self` is
        // and have an appropriate `impl Unpin` for this.
        let me = unsafe { Pin::into_inner_unchecked(self) };
        StreamReaderProject {
            // SAFETY: `me` was obtained from a pinned reference and `inner`
            // is never moved out of it, so re-pinning the field is sound.
            inner: unsafe { Pin::new_unchecked(&mut me.inner) },
            chunk: &mut me.chunk,
        }
    }
}
// When the wrapped stream is also a `Sink`, forward the sink half untouched so
// the write side of a duplex transport remains usable through the reader.
impl<S: Sink<T, Error = E>, B, E, T> Sink<T> for StreamReader<S, B> {
    type Error = E;
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_ready(cx)
    }
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        self.project().inner.start_send(item)
    }
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}

422
vendor/tokio-util/src/io/sync_bridge.rs vendored Normal file
View File

@@ -0,0 +1,422 @@
use std::io::{BufRead, Read, Seek, Write};
use tokio::io::{
AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite,
AsyncWriteExt,
};
/// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
/// a [`tokio::io::AsyncWrite`] synchronously as a [`std::io::Write`].
///
/// # Alternatives
///
/// In many cases, there are better alternatives to using `SyncIoBridge`, especially
/// if you want to avoid blocking the async runtime. Consider the following scenarios:
///
/// When hashing data, using `SyncIoBridge` can lead to suboptimal performance and
/// might not fully leverage the async capabilities of the system.
///
/// ### Why It Matters:
///
/// `SyncIoBridge` allows you to use asynchronous I/O operations in a synchronous
/// context by blocking the current thread. However, this can be inefficient because:
/// - **Inefficient Resource Usage**: `SyncIoBridge` takes up an entire OS thread,
/// which is inefficient compared to asynchronous code that can multiplex many
/// tasks on a single thread.
/// - **Thread Pool Saturation**: Excessive use of `SyncIoBridge` can exhaust the
/// async runtime's thread pool, reducing the number of threads available for
/// other tasks and impacting overall performance.
/// - **Missed Concurrency Benefits**: By using synchronous operations with
/// `SyncIoBridge`, you lose the ability to interleave tasks efficiently,
/// which is a key advantage of asynchronous programming.
///
/// ## Example 1: Hashing Data
///
/// The use of `SyncIoBridge` is unnecessary when hashing data. Instead, you can
/// process the data asynchronously by reading it into memory, which avoids blocking
/// the async runtime.
///
/// There are two strategies for avoiding `SyncIoBridge` when hashing data. When
/// the data fits into memory, the easiest is to read the data into a `Vec<u8>`
/// and hash it:
///
/// Explanation: This example demonstrates how to asynchronously read data from a
/// reader into memory and hash it using a synchronous hashing function. The
/// `SyncIoBridge` is avoided, ensuring that the async runtime is not blocked.
/// ```rust
/// use tokio::io::AsyncReadExt;
/// use tokio::io::AsyncRead;
/// use std::io::Cursor;
/// # mod blake3 { pub fn hash(_: &[u8]) {} }
///
/// async fn hash_contents(mut reader: impl AsyncRead + Unpin) -> Result<(), std::io::Error> {
/// // Read all data from the reader into a Vec<u8>.
/// let mut data = Vec::new();
/// reader.read_to_end(&mut data).await?;
///
/// // Hash the data using the blake3 hashing function.
/// let hash = blake3::hash(&data);
///
/// Ok(hash)
/// }
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> Result<(), std::io::Error> {
/// // Example: In-memory data.
/// let data = b"Hello, world!"; // A byte slice.
/// let reader = Cursor::new(data); // Create an in-memory AsyncRead.
/// hash_contents(reader).await
/// # }
/// ```
///
/// When the data doesn't fit into memory, the hashing library will usually
/// provide a `hasher` that you can repeatedly call `update` on to hash the data
/// one chunk at the time.
///
/// Explanation: This example demonstrates how to asynchronously stream data in
/// chunks for hashing. Each chunk is read asynchronously, and the hash is updated
/// incrementally. This avoids blocking and improves performance over using
/// `SyncIoBridge`.
///
/// ```rust
/// use tokio::io::AsyncReadExt;
/// use tokio::io::AsyncRead;
/// use std::io::Cursor;
/// # struct Hasher;
/// # impl Hasher { pub fn update(&mut self, _: &[u8]) {} pub fn finalize(&self) {} }
///
/// /// Asynchronously streams data from an async reader, processes it in chunks,
/// /// and hashes the data incrementally.
/// async fn hash_stream(mut reader: impl AsyncRead + Unpin, mut hasher: Hasher) -> Result<(), std::io::Error> {
/// // Create a buffer to read data into, sized for performance.
/// let mut data = vec![0; 16 * 1024];
/// loop {
/// // Read data from the reader into the buffer.
/// let len = reader.read(&mut data).await?;
/// if len == 0 { break; } // Exit loop if no more data.
///
/// // Update the hash with the data read.
/// hasher.update(&data[..len]);
/// }
///
/// // Finalize the hash after all data has been processed.
/// let hash = hasher.finalize();
///
/// Ok(hash)
/// }
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> Result<(), std::io::Error> {
/// // Example: In-memory data.
/// let data = b"Hello, world!"; // A byte slice.
/// let reader = Cursor::new(data); // Create an in-memory AsyncRead.
/// let hasher = Hasher;
/// hash_stream(reader, hasher).await
/// # }
/// ```
///
///
/// ## Example 2: Compressing Data
///
/// When compressing data, the use of `SyncIoBridge` is unnecessary as it introduces
/// blocking and inefficient code. Instead, you can utilize an async compression library
/// such as the [`async-compression`](https://docs.rs/async-compression/latest/async_compression/)
/// crate, which is built to handle asynchronous data streams efficiently.
///
/// Explanation: This example shows how to asynchronously compress data using an
/// async compression library. By reading and writing asynchronously, it avoids
/// blocking and is more efficient than using `SyncIoBridge` with a non-async
/// compression library.
///
/// ```ignore
/// use async_compression::tokio::write::GzipEncoder;
/// use std::io::Cursor;
/// use tokio::io::AsyncRead;
///
/// /// Asynchronously compresses data from an async reader using Gzip and an async encoder.
/// async fn compress_data(mut reader: impl AsyncRead + Unpin) -> Result<(), std::io::Error> {
/// let writer = tokio::io::sink();
///
/// // Create a Gzip encoder that wraps the writer.
/// let mut encoder = GzipEncoder::new(writer);
///
/// // Copy data from the reader to the encoder, compressing it.
/// tokio::io::copy(&mut reader, &mut encoder).await?;
///
/// Ok(())
///}
///
/// #[tokio::main]
/// async fn main() -> Result<(), std::io::Error> {
/// // Example: In-memory data.
/// let data = b"Hello, world!"; // A byte slice.
/// let reader = Cursor::new(data); // Create an in-memory AsyncRead.
/// compress_data(reader).await?;
///
/// Ok(())
/// }
/// ```
///
///
/// ## Example 3: Parsing Data Formats
///
///
/// `SyncIoBridge` is not ideal when parsing data formats such as `JSON`, as it
/// blocks async operations. A more efficient approach is to read data asynchronously
/// into memory and then `deserialize` it, avoiding unnecessary synchronization overhead.
///
/// Explanation: This example shows how to asynchronously read data into memory
/// and then parse it as `JSON`. By avoiding `SyncIoBridge`, the asynchronous runtime
/// remains unblocked, leading to better performance when working with asynchronous
/// I/O streams.
///
/// ```rust,no_run
/// use tokio::io::AsyncRead;
/// use tokio::io::AsyncReadExt;
/// use std::io::Cursor;
/// # mod serde {
/// # pub trait DeserializeOwned: 'static {}
/// # impl<T: 'static> DeserializeOwned for T {}
/// # }
/// # mod serde_json {
/// # use super::serde::DeserializeOwned;
/// # pub fn from_slice<T: DeserializeOwned>(_: &[u8]) -> Result<T, std::io::Error> {
/// # unimplemented!()
/// # }
/// # }
/// # #[derive(Debug)] struct MyStruct;
///
///
/// async fn parse_json(mut reader: impl AsyncRead + Unpin) -> Result<MyStruct, std::io::Error> {
/// // Read all data from the reader into a Vec<u8>.
/// let mut data = Vec::new();
/// reader.read_to_end(&mut data).await?;
///
/// // Deserialize the data from the Vec<u8> into a MyStruct instance.
/// let value: MyStruct = serde_json::from_slice(&data)?;
///
/// Ok(value)
///}
///
/// #[tokio::main]
/// async fn main() -> Result<(), std::io::Error> {
/// // Example: In-memory data.
/// let data = b"Hello, world!"; // A byte slice.
/// let reader = Cursor::new(data); // Create an in-memory AsyncRead.
/// parse_json(reader).await?;
/// Ok(())
/// }
/// ```
///
/// ## Correct Usage of `SyncIoBridge` inside `spawn_blocking`
///
/// `SyncIoBridge` is mainly useful when you need to interface with synchronous
/// libraries from an asynchronous context.
///
/// Explanation: This example shows how to use `SyncIoBridge` inside a `spawn_blocking`
/// task to safely perform synchronous I/O without blocking the async runtime. The
/// `spawn_blocking` ensures that the synchronous code is offloaded to a dedicated
/// thread pool, preventing it from interfering with the async tasks.
///
/// ```rust
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use tokio::task::spawn_blocking;
/// use tokio_util::io::SyncIoBridge;
/// use tokio::io::AsyncRead;
/// use std::marker::Unpin;
/// use std::io::Cursor;
///
/// /// Wraps an async reader with `SyncIoBridge` and performs synchronous I/O operations in a blocking task.
/// async fn process_sync_io(reader: impl AsyncRead + Unpin + Send + 'static) -> Result<Vec<u8>, std::io::Error> {
/// // Wrap the async reader with `SyncIoBridge` to allow synchronous reading.
/// let mut sync_reader = SyncIoBridge::new(reader);
///
/// // Spawn a blocking task to perform synchronous I/O operations.
/// let result = spawn_blocking(move || {
/// // Create an in-memory buffer to hold the copied data.
/// let mut buffer = Vec::new();
/// // Copy data from the sync_reader to the buffer.
/// std::io::copy(&mut sync_reader, &mut buffer)?;
/// // Return the buffer containing the copied data.
/// Ok::<_, std::io::Error>(buffer)
/// })
/// .await??;
///
/// // Return the result from the blocking task.
/// Ok(result)
///}
///
/// #[tokio::main]
/// async fn main() -> Result<(), std::io::Error> {
/// // Example: In-memory data.
/// let data = b"Hello, world!"; // A byte slice.
/// let reader = Cursor::new(data); // Create an in-memory AsyncRead.
/// let result = process_sync_io(reader).await?;
///
/// // You can use `result` here as needed.
///
/// Ok(())
/// }
/// # }
/// ```
///
#[derive(Debug)]
pub struct SyncIoBridge<T> {
    // The wrapped asynchronous reader/writer.
    src: T,
    // Runtime handle used to `block_on` the async operations of `src`.
    rt: tokio::runtime::Handle,
}
// Each method below blocks the current (non-async) thread on the
// corresponding async operation via the captured runtime handle.
impl<T: AsyncBufRead + Unpin> BufRead for SyncIoBridge<T> {
    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
        let src = &mut self.src;
        self.rt.block_on(AsyncBufReadExt::fill_buf(src))
    }
    fn consume(&mut self, amt: usize) {
        let src = &mut self.src;
        // Consuming already-buffered bytes is synchronous; no `block_on` needed.
        AsyncBufReadExt::consume(src, amt)
    }
    fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt
            .block_on(AsyncBufReadExt::read_until(src, byte, buf))
    }
    fn read_line(&mut self, buf: &mut String) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(AsyncBufReadExt::read_line(src, buf))
    }
}
impl<T: AsyncRead + Unpin> Read for SyncIoBridge<T> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(AsyncReadExt::read(src, buf))
    }
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.read_to_end(buf))
    }
    fn read_to_string(&mut self, buf: &mut String) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.read_to_string(buf))
    }
    fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        let src = &mut self.src;
        // The AsyncRead trait returns the count, synchronous doesn't.
        let _n = self.rt.block_on(src.read_exact(buf))?;
        Ok(())
    }
}
impl<T: AsyncWrite + Unpin> Write for SyncIoBridge<T> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.write(buf))
    }
    fn flush(&mut self) -> std::io::Result<()> {
        let src = &mut self.src;
        self.rt.block_on(src.flush())
    }
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let src = &mut self.src;
        self.rt.block_on(src.write_all(buf))
    }
    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
        let src = &mut self.src;
        self.rt.block_on(src.write_vectored(bufs))
    }
}
impl<T: AsyncSeek + Unpin> Seek for SyncIoBridge<T> {
    /// Blocks the current thread on the async `seek` of the wrapped value.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        let Self { src, rt } = self;
        rt.block_on(AsyncSeekExt::seek(src, pos))
    }
}
// Because https://doc.rust-lang.org/std/io/trait.Write.html#method.is_write_vectored is at the time
// of this writing still unstable, we expose this as part of a standalone method.
impl<T: AsyncWrite> SyncIoBridge<T> {
    /// Determines if the underlying [`tokio::io::AsyncWrite`] target supports efficient vectored writes.
    ///
    /// See [`tokio::io::AsyncWrite::is_write_vectored`].
    pub fn is_write_vectored(&self) -> bool {
        // Pure query on the inner writer; no runtime interaction required.
        self.src.is_write_vectored()
    }
}
impl<T: AsyncWrite + Unpin> SyncIoBridge<T> {
    /// Shutdown this writer. This method provides a way to call the [`AsyncWriteExt::shutdown`]
    /// function of the inner [`tokio::io::AsyncWrite`] instance.
    ///
    /// # Errors
    ///
    /// This method returns the same errors as [`AsyncWriteExt::shutdown`].
    ///
    /// [`AsyncWriteExt::shutdown`]: tokio::io::AsyncWriteExt::shutdown
    pub fn shutdown(&mut self) -> std::io::Result<()> {
        // `std::io::Write` has no shutdown notion, so expose it explicitly.
        let src = &mut self.src;
        self.rt.block_on(src.shutdown())
    }
}
impl<T: Unpin> SyncIoBridge<T> {
    /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
    /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
    ///
    /// When this struct is created, it captures a handle to the current thread's runtime with [`tokio::runtime::Handle::current`].
    /// It is hence OK to move this struct into a separate thread outside the runtime, as created
    /// by e.g. [`tokio::task::spawn_blocking`].
    ///
    /// Stated even more strongly: to make use of this bridge, you *must* move
    /// it into a separate thread outside the runtime. The synchronous I/O will use the
    /// underlying handle to block on the backing asynchronous source, via
    /// [`tokio::runtime::Handle::block_on`]. As noted in the documentation for that
    /// function, an attempt to `block_on` from an asynchronous execution context
    /// will panic.
    ///
    /// # Wrapping `!Unpin` types
    ///
    /// Use e.g. `SyncIoBridge::new(Box::pin(src))`.
    ///
    /// # Panics
    ///
    /// This will panic if called outside the context of a Tokio runtime.
    #[track_caller]
    pub fn new(src: T) -> Self {
        // `Handle::current()` panics outside a runtime; `#[track_caller]`
        // points the panic at the caller of `new`.
        Self::new_with_handle(src, tokio::runtime::Handle::current())
    }
    /// Use a [`tokio::io::AsyncRead`] synchronously as a [`std::io::Read`] or
    /// a [`tokio::io::AsyncWrite`] as a [`std::io::Write`].
    ///
    /// This is the same as [`SyncIoBridge::new`], but allows passing an arbitrary handle and hence may
    /// be initially invoked outside of an asynchronous context.
    pub fn new_with_handle(src: T, rt: tokio::runtime::Handle) -> Self {
        Self { src, rt }
    }
    /// Consume this bridge, returning the underlying stream.
    pub fn into_inner(self) -> T {
        self.src
    }
}
impl<T> AsMut<T> for SyncIoBridge<T> {
    /// Returns a mutable reference to the wrapped value.
    fn as_mut(&mut self) -> &mut T {
        &mut self.src
    }
}
impl<T> AsRef<T> for SyncIoBridge<T> {
    /// Returns a shared reference to the wrapped value.
    fn as_ref(&self) -> &T {
        &self.src
    }
}

65
vendor/tokio-util/src/lib.rs vendored Normal file
View File

@@ -0,0 +1,65 @@
#![allow(clippy::needless_doctest_main)]
#![warn(
missing_debug_implementations,
missing_docs,
rust_2018_idioms,
unreachable_pub
)]
#![doc(test(
no_crate_inject,
attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
))]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! Utilities for working with Tokio.
//!
//! This crate is not versioned in lockstep with the core
//! [`tokio`] crate. However, `tokio-util` _will_ respect Rust's
//! semantic versioning policy, especially with regard to breaking changes.
#[macro_use]
mod cfg;
mod loom;
// Framed encoding/decoding support; gated on the `codec` feature.
cfg_codec! {
    #[macro_use]
    mod tracing;
    pub mod codec;
}
// Network helpers require both `net` and `codec` (see `cfg_net!` in cfg.rs).
cfg_net! {
    #[cfg(not(target_arch = "wasm32"))]
    pub mod udp;
    pub mod net;
}
// futures <-> tokio I/O compatibility layer.
cfg_compat! {
    pub mod compat;
}
cfg_io! {
    pub mod io;
}
cfg_rt! {
    pub mod context;
}
#[cfg(feature = "rt")]
pub mod task;
cfg_time! {
    pub mod time;
}
// Always-available modules (no feature gate).
pub mod sync;
pub mod either;
pub use bytes;
mod util;
pub mod future;
9
vendor/tokio-util/src/loom.rs vendored Normal file
View File

@@ -0,0 +1,9 @@
//! This module abstracts over `loom` and `std::sync` types depending on whether we
//! are running loom tests or not.
pub(crate) mod sync {
    // Under loom model-checking test builds, use loom's checked primitives;
    // in all other builds, the real `std::sync` types with the same names.
    #[cfg(all(test, loom))]
    pub(crate) use loom::sync::{Arc, Mutex, MutexGuard};
    #[cfg(not(all(test, loom)))]
    pub(crate) use std::sync::{Arc, Mutex, MutexGuard};
}

99
vendor/tokio-util/src/net/mod.rs vendored Normal file
View File

@@ -0,0 +1,99 @@
#![cfg(not(loom))]
//! TCP/UDP/Unix helpers for tokio.
use crate::either::Either;
use std::future::Future;
use std::io::Result;
use std::pin::Pin;
use std::task::{Context, Poll};
#[cfg(unix)]
pub mod unix;
/// A trait for a listener: `TcpListener` and `UnixListener`.
pub trait Listener {
    /// The stream's type of this listener.
    type Io: tokio::io::AsyncRead + tokio::io::AsyncWrite;
    /// The socket address type of this listener.
    type Addr;
    /// Polls to accept a new incoming connection to this listener.
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>>;
    /// Accepts a new incoming connection from this listener.
    ///
    /// Provided method: wraps [`Listener::poll_accept`] in an awaitable future.
    fn accept(&mut self) -> ListenerAcceptFut<'_, Self>
    where
        Self: Sized,
    {
        ListenerAcceptFut { listener: self }
    }
    /// Returns the local address that this listener is bound to.
    fn local_addr(&self) -> Result<Self::Addr>;
}
impl Listener for tokio::net::TcpListener {
    type Io = tokio::net::TcpStream;
    type Addr = std::net::SocketAddr;
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
        // Forwards to the inherent `TcpListener::poll_accept`.
        Self::poll_accept(self, cx)
    }
    fn local_addr(&self) -> Result<Self::Addr> {
        // Resolves to the inherent `local_addr`, not this trait method.
        self.local_addr()
    }
}
/// Future for accepting a new connection from a listener.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ListenerAcceptFut<'a, L> {
    // Mutably borrows the listener until the accept completes.
    listener: &'a mut L,
}
impl<'a, L> Future for ListenerAcceptFut<'a, L>
where
    L: Listener,
{
    type Output = Result<(L::Io, L::Addr)>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `&mut L` is `Unpin`, so reaching through the pin here is fine.
        self.listener.poll_accept(cx)
    }
}
impl<L, R> Either<L, R>
where
    L: Listener,
    R: Listener,
{
    /// Accepts a new incoming connection from this listener.
    pub async fn accept(&mut self) -> Result<Either<(L::Io, L::Addr), (R::Io, R::Addr)>> {
        // Tag the accepted (stream, addr) pair with the side it came from.
        match self {
            Either::Left(listener) => listener.accept().await.map(Either::Left),
            Either::Right(listener) => listener.accept().await.map(Either::Right),
        }
    }
    /// Returns the local address that this listener is bound to.
    pub fn local_addr(&self) -> Result<Either<L::Addr, R::Addr>> {
        // Tag the address with the side it came from.
        match self {
            Either::Left(listener) => listener.local_addr().map(Either::Left),
            Either::Right(listener) => listener.local_addr().map(Either::Right),
        }
    }
}

18
vendor/tokio-util/src/net/unix/mod.rs vendored Normal file
View File

@@ -0,0 +1,18 @@
//! Unix domain socket helpers.
use super::Listener;
use std::io::Result;
use std::task::{Context, Poll};
impl Listener for tokio::net::UnixListener {
    type Io = tokio::net::UnixStream;
    type Addr = tokio::net::unix::SocketAddr;
    fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(Self::Io, Self::Addr)>> {
        // Forwards to the inherent `UnixListener::poll_accept`.
        Self::poll_accept(self, cx)
    }
    fn local_addr(&self) -> Result<Self::Addr> {
        // Resolves to the inherent `local_addr`, not this trait method.
        self.local_addr()
    }
}

View File

@@ -0,0 +1,487 @@
//! An asynchronously awaitable [`CancellationToken`].
//! The token allows signaling a cancellation request to one or more tasks.
pub(crate) mod guard;
pub(crate) mod guard_ref;
mod tree_node;
use crate::loom::sync::Arc;
use crate::util::MaybeDangling;
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
use guard::DropGuard;
use guard_ref::DropGuardRef;
use pin_project_lite::pin_project;
/// A token which can be used to signal a cancellation request to one or more
/// tasks.
///
/// Tasks can call [`CancellationToken::cancelled()`] in order to
/// obtain a Future which will be resolved when cancellation is requested.
///
/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
///
/// # Examples
///
/// ```no_run
/// use tokio::select;
/// use tokio_util::sync::CancellationToken;
///
/// #[tokio::main]
/// async fn main() {
/// let token = CancellationToken::new();
/// let cloned_token = token.clone();
///
/// let join_handle = tokio::spawn(async move {
/// // Wait for either cancellation or a very long time
/// select! {
/// _ = cloned_token.cancelled() => {
/// // The token was cancelled
/// 5
/// }
/// _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
/// 99
/// }
/// }
/// });
///
/// tokio::spawn(async move {
/// tokio::time::sleep(std::time::Duration::from_millis(10)).await;
/// token.cancel();
/// });
///
/// assert_eq!(5, join_handle.await.unwrap());
/// }
/// ```
pub struct CancellationToken {
    // Shared node in the cancellation tree; `Clone` and `Drop` maintain a
    // handle refcount on it (see the impls below).
    inner: Arc<tree_node::TreeNode>,
}
// NOTE(review): these marker impls assert that a panic observed across a token
// cannot leave it in a broken state — presumably upheld by `tree_node`'s
// internal synchronization; confirm against that module.
impl std::panic::UnwindSafe for CancellationToken {}
impl std::panic::RefUnwindSafe for CancellationToken {}
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled.
    #[must_use = "futures do nothing unless polled"]
    pub struct WaitForCancellationFuture<'a> {
        // Token whose cancellation this future is waiting for.
        cancellation_token: &'a CancellationToken,
        // Notification future obtained from the token's tree node.
        #[pin]
        future: tokio::sync::futures::Notified<'a>,
    }
}
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled.
    ///
    /// This is the counterpart to [`WaitForCancellationFuture`] that takes
    /// [`CancellationToken`] by value instead of using a reference.
    #[must_use = "futures do nothing unless polled"]
    pub struct WaitForCancellationFutureOwned {
        // This field internally has a reference to the cancellation token, but camouflages
        // the relationship with `'static`. To avoid Undefined Behavior, we must ensure
        // that the reference is only used while the cancellation token is still alive. To
        // do that, we ensure that the future is the first field, so that it is dropped
        // before the cancellation token.
        //
        // We use `MaybeDanglingFuture` here because without it, the compiler could assert
        // the reference inside `future` to be valid even after the destructor of that
        // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed
        // as an argument to a function, the reference can be asserted to be valid for the
        // rest of that function.) To avoid that, we use `MaybeDangling` which tells the
        // compiler that the reference stored inside it might not be valid.
        //
        // See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
        // for more info.
        #[pin]
        future: MaybeDangling<tokio::sync::futures::Notified<'static>>,
        // Owned token kept alive for as long as `future` may reference it;
        // declared second so it is dropped after `future` (see above).
        cancellation_token: CancellationToken,
    }
}
// ===== impl CancellationToken =====
impl core::fmt::Debug for CancellationToken {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Only the observable cancellation state is shown; the tree node's
        // internals are intentionally omitted.
        f.debug_struct("CancellationToken")
            .field("is_cancelled", &self.is_cancelled())
            .finish()
    }
}
impl Clone for CancellationToken {
    /// Creates a clone of the [`CancellationToken`] which will get cancelled
    /// whenever the current token gets cancelled, and vice versa.
    fn clone(&self) -> Self {
        // Every live token counts as one handle on the shared node.
        tree_node::increase_handle_refcount(&self.inner);
        CancellationToken {
            inner: self.inner.clone(),
        }
    }
}
impl Drop for CancellationToken {
    fn drop(&mut self) {
        // Mirror of `Clone`: release this token's handle on the shared node.
        tree_node::decrease_handle_refcount(&self.inner);
    }
}
impl Default for CancellationToken {
    /// Equivalent to [`CancellationToken::new`].
    fn default() -> CancellationToken {
        CancellationToken::new()
    }
}
impl CancellationToken {
    /// Creates a new [`CancellationToken`] in the non-cancelled state.
    pub fn new() -> CancellationToken {
        CancellationToken {
            // A fresh node is the root of its own cancellation tree.
            inner: Arc::new(tree_node::TreeNode::new()),
        }
    }
    /// Creates a [`CancellationToken`] which will get cancelled whenever the
    /// current token gets cancelled. Unlike a cloned [`CancellationToken`],
    /// cancelling a child token does not cancel the parent token.
    ///
    /// If the current token is already cancelled, the child token will get
    /// returned in cancelled state.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use tokio::select;
    /// use tokio_util::sync::CancellationToken;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let token = CancellationToken::new();
    ///     let child_token = token.child_token();
    ///
    ///     let join_handle = tokio::spawn(async move {
    ///         // Wait for either cancellation or a very long time
    ///         select! {
    ///             _ = child_token.cancelled() => {
    ///                 // The token was cancelled
    ///                 5
    ///             }
    ///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
    ///                 99
    ///             }
    ///         }
    ///     });
    ///
    ///     tokio::spawn(async move {
    ///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    ///         token.cancel();
    ///     });
    ///
    ///     assert_eq!(5, join_handle.await.unwrap());
    /// }
    /// ```
    pub fn child_token(&self) -> CancellationToken {
        CancellationToken {
            // Links a new node under this one in the cancellation tree.
            inner: tree_node::child_node(&self.inner),
        }
    }
    /// Cancel the [`CancellationToken`] and all child tokens which had been
    /// derived from it.
    ///
    /// This will wake up all tasks which are waiting for cancellation.
    ///
    /// Be aware that cancellation is not an atomic operation. It is possible
    /// for another thread running in parallel with a call to `cancel` to first
    /// receive `true` from `is_cancelled` on one child node, and then receive
    /// `false` from `is_cancelled` on another child node. However, once the
    /// call to `cancel` returns, all child nodes have been fully cancelled.
    pub fn cancel(&self) {
        tree_node::cancel(&self.inner);
    }
    /// Returns `true` if the `CancellationToken` is cancelled.
    pub fn is_cancelled(&self) -> bool {
        tree_node::is_cancelled(&self.inner)
    }
    /// Returns a [`Future`] that gets fulfilled when cancellation is requested.
    ///
    /// Equivalent to:
    ///
    /// ```ignore
    /// async fn cancelled(&self);
    /// ```
    ///
    /// The future will complete immediately if the token is already cancelled
    /// when this method is called.
    ///
    /// # Cancellation safety
    ///
    /// This method is cancel safe.
    pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
        WaitForCancellationFuture {
            cancellation_token: self,
            // Registers with the node's notification mechanism up front.
            future: self.inner.notified(),
        }
    }
/// Returns a [`Future`] that gets fulfilled when cancellation is requested.
///
/// Equivalent to:
///
/// ```ignore
/// async fn cancelled_owned(self);
/// ```
///
/// The future will complete immediately if the token is already cancelled
/// when this method is called.
///
/// The function takes self by value and returns a future that owns the
/// token.
///
/// # Cancellation safety
///
/// This method is cancel safe.
pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned {
WaitForCancellationFutureOwned::new(self)
}
/// Creates a [`DropGuard`] for this token.
///
/// Returned guard will cancel this token (and all its children) on drop
/// unless disarmed.
pub fn drop_guard(self) -> DropGuard {
DropGuard { inner: Some(self) }
}
/// Creates a [`DropGuardRef`] for this token.
///
/// Returned guard will cancel this token (and all its children) on drop
/// unless disarmed.
pub fn drop_guard_ref(&self) -> DropGuardRef<'_> {
DropGuardRef { inner: Some(self) }
}
/// Runs a future to completion and returns its result wrapped inside of an `Option`
/// unless the [`CancellationToken`] is cancelled. In that case the function returns
/// `None` and the future gets dropped.
///
/// # Fairness
///
/// Calling this on an already-cancelled token directly returns `None`.
/// For all subsequent polls, in case of concurrent completion and
/// cancellation, this is biased towards the future completion.
///
/// # Cancellation safety
///
/// This method is only cancel safe if `fut` is cancel safe.
pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
where
F: Future,
{
if self.is_cancelled() {
None
} else {
RunUntilCancelledFuture {
cancellation: self.cancelled(),
future: fut,
}
.await
}
}
/// Runs a future to completion and returns its result wrapped inside of an `Option`
/// unless the [`CancellationToken`] is cancelled. In that case the function returns
/// `None` and the future gets dropped.
///
/// The function takes self by value and returns a future that owns the token.
///
/// # Fairness
///
/// Calling this on an already-cancelled token directly returns `None`.
/// For all subsequent polls, in case of concurrent completion and
/// cancellation, this is biased towards the future completion.
///
/// # Cancellation safety
///
/// This method is only cancel safe if `fut` is cancel safe.
pub async fn run_until_cancelled_owned<F>(self, fut: F) -> Option<F::Output>
where
F: Future,
{
self.run_until_cancelled(fut).await
}
}
// ===== impl WaitForCancellationFuture =====

impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
    // The future carries no user-relevant state, so the struct is shown empty.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("WaitForCancellationFuture").finish()
    }
}

impl<'a> Future for WaitForCancellationFuture<'a> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut this = self.project();
        loop {
            // Check the flag before (re)polling the `Notified` future.
            if this.cancellation_token.is_cancelled() {
                return Poll::Ready(());
            }
            // No wakeups can be lost here because there is always a call to
            // `is_cancelled` between the creation of the future and the call to
            // `poll`, and the code that sets the cancelled flag does so before
            // waking the `Notified`.
            if this.future.as_mut().poll(cx).is_pending() {
                return Poll::Pending;
            }
            // The `Notified` completed but cancellation was not observed above:
            // arm a fresh `Notified` and loop to re-check the flag.
            this.future.set(this.cancellation_token.inner.notified());
        }
    }
}
// ===== impl WaitForCancellationFutureOwned =====

impl core::fmt::Debug for WaitForCancellationFutureOwned {
    // The future carries no user-relevant state, so the struct is shown empty.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("WaitForCancellationFutureOwned").finish()
    }
}

impl WaitForCancellationFutureOwned {
    /// Builds the owned future, setting up the self-referential `Notified`
    /// that borrows from the token stored in the same struct.
    fn new(cancellation_token: CancellationToken) -> Self {
        WaitForCancellationFutureOwned {
            // cancellation_token holds a heap allocation and is guaranteed to have a
            // stable deref, thus it would be ok to move the cancellation_token while
            // the future holds a reference to it.
            //
            // # Safety
            //
            // cancellation_token is dropped after future due to the field ordering.
            future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }),
            cancellation_token,
        }
    }

    /// # Safety
    /// The returned future must be destroyed before the cancellation token is
    /// destroyed.
    unsafe fn new_future(
        cancellation_token: &CancellationToken,
    ) -> tokio::sync::futures::Notified<'static> {
        let inner_ptr = Arc::as_ptr(&cancellation_token.inner);
        // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains
        // valid until the strong count of the Arc drops to zero, and the caller
        // guarantees that they will drop the future before that happens.
        (*inner_ptr).notified()
    }
}

impl Future for WaitForCancellationFutureOwned {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let mut this = self.project();
        loop {
            // Check the flag before (re)polling the `Notified` future.
            if this.cancellation_token.is_cancelled() {
                return Poll::Ready(());
            }
            // No wakeups can be lost here because there is always a call to
            // `is_cancelled` between the creation of the future and the call to
            // `poll`, and the code that sets the cancelled flag does so before
            // waking the `Notified`.
            if this.future.as_mut().poll(cx).is_pending() {
                return Poll::Pending;
            }
            // The `Notified` completed but cancellation was not observed above:
            // arm a fresh `Notified` and loop to re-check the flag.
            //
            // # Safety
            //
            // cancellation_token is dropped after future due to the field ordering.
            this.future.set(MaybeDangling::new(unsafe {
                Self::new_future(this.cancellation_token)
            }));
        }
    }
}
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled or a given Future gets resolved. It is biased towards the
    /// Future completion.
    #[must_use = "futures do nothing unless polled"]
    pub(crate) struct RunUntilCancelledFuture<'a, F: Future> {
        // Borrowed wait-for-cancellation future; if it wins, `poll` yields `None`.
        #[pin]
        cancellation: WaitForCancellationFuture<'a>,
        // The wrapped user future; if it wins, `poll` yields `Some(output)`.
        #[pin]
        future: F,
    }
}

impl<'a, F: Future> RunUntilCancelledFuture<'a, F> {
    /// Pairs `future` with the cancellation future of `cancellation_token`.
    pub(crate) fn new(cancellation_token: &'a CancellationToken, future: F) -> Self {
        Self {
            cancellation: cancellation_token.cancelled(),
            future,
        }
    }
}

impl<'a, F: Future> Future for RunUntilCancelledFuture<'a, F> {
    type Output = Option<F::Output>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // The wrapped future is polled first, so when both are ready in the
        // same poll, completion beats cancellation (the documented bias).
        if let Poll::Ready(res) = this.future.poll(cx) {
            Poll::Ready(Some(res))
        } else if this.cancellation.poll(cx).is_ready() {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }
}
pin_project! {
    /// A Future that is resolved once the corresponding [`CancellationToken`]
    /// is cancelled or a given Future gets resolved. It is biased towards the
    /// Future completion.
    #[must_use = "futures do nothing unless polled"]
    pub(crate) struct RunUntilCancelledFutureOwned<F: Future> {
        // Owned wait-for-cancellation future; if it wins, `poll` yields `None`.
        #[pin]
        cancellation: WaitForCancellationFutureOwned,
        // The wrapped user future; if it wins, `poll` yields `Some(output)`.
        #[pin]
        future: F,
    }
}

impl<F: Future> Future for RunUntilCancelledFutureOwned<F> {
    type Output = Option<F::Output>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // The wrapped future is polled first, so when both are ready in the
        // same poll, completion beats cancellation (the documented bias).
        if let Poll::Ready(res) = this.future.poll(cx) {
            Poll::Ready(Some(res))
        } else if this.cancellation.poll(cx).is_ready() {
            Poll::Ready(None)
        } else {
            Poll::Pending
        }
    }
}

impl<F: Future> RunUntilCancelledFutureOwned<F> {
    /// Pairs `future` with the owned cancellation future of `cancellation_token`.
    pub(crate) fn new(cancellation_token: CancellationToken, future: F) -> Self {
        Self {
            cancellation: cancellation_token.cancelled_owned(),
            future,
        }
    }
}

View File

@@ -0,0 +1,29 @@
use crate::sync::CancellationToken;
/// A wrapper for cancellation token which automatically cancels
/// it on drop. It is created using [`drop_guard`] method on the [`CancellationToken`].
///
/// [`drop_guard`]: CancellationToken::drop_guard
#[derive(Debug)]
pub struct DropGuard {
    // `Some` while armed; `disarm` takes the token out, leaving `None` so
    // the destructor skips the cancel.
    pub(super) inner: Option<CancellationToken>,
}

impl DropGuard {
    /// Returns stored cancellation token and removes this drop guard instance
    /// (i.e. it will no longer cancel token). Other guards for this token
    /// are not affected.
    pub fn disarm(mut self) -> CancellationToken {
        self.inner
            .take()
            .expect("`inner` can be only None in a destructor")
    }
}

impl Drop for DropGuard {
    fn drop(&mut self) {
        // Cancel unless `disarm` already took the token out.
        if let Some(inner) = &self.inner {
            inner.cancel();
        }
    }
}

View File

@@ -0,0 +1,32 @@
use crate::sync::CancellationToken;
/// A wrapper for cancellation token which automatically cancels
/// it on drop. It is created using [`drop_guard_ref`] method on the [`CancellationToken`].
///
/// This is a borrowed version of [`DropGuard`].
///
/// [`drop_guard_ref`]: CancellationToken::drop_guard_ref
/// [`DropGuard`]: super::DropGuard
#[derive(Debug)]
pub struct DropGuardRef<'a> {
    // `Some` while armed; `disarm` takes the reference out, leaving `None`
    // so the destructor skips the cancel.
    pub(super) inner: Option<&'a CancellationToken>,
}

impl<'a> DropGuardRef<'a> {
    /// Returns stored cancellation token and removes this drop guard instance
    /// (i.e. it will no longer cancel token). Other guards for this token
    /// are not affected.
    pub fn disarm(mut self) -> &'a CancellationToken {
        self.inner
            .take()
            .expect("`inner` can be only None in a destructor")
    }
}

impl Drop for DropGuardRef<'_> {
    fn drop(&mut self) {
        // Cancel unless `disarm` already took the reference out.
        if let Some(inner) = self.inner {
            inner.cancel();
        }
    }
}

View File

@@ -0,0 +1,367 @@
//! This mod provides the logic for the inner tree structure of the `CancellationToken`.
//!
//! `CancellationTokens` are only light handles with references to [`TreeNode`].
//! All the logic is actually implemented in the [`TreeNode`].
//!
//! A [`TreeNode`] is part of the cancellation tree and may have one parent and an arbitrary number of
//! children.
//!
//! A [`TreeNode`] can receive the request to perform a cancellation through a `CancellationToken`.
//! This cancellation request will cancel the node and all of its descendants.
//!
//! As soon as a node cannot get cancelled any more (because it was already cancelled or it has no
//! more `CancellationTokens` pointing to it), it gets removed from the tree, to keep the
//! tree as small as possible.
//!
//! # Invariants
//!
//! Those invariants shall be true at any time.
//!
//! 1. A node that has no parents and no handles can no longer be cancelled.
//! This is important during both cancellation and refcounting.
//!
//! 2. If node B *is* or *was* a child of node A, then node B was created *after* node A.
//! This is important for deadlock safety, as it is used for lock order.
//! Node B can only become the child of node A in two ways:
//! - being created with `child_node()`, in which case it is trivially true that
//! node A already existed when node B was created
//! - being moved A->C->B to A->B because node C was removed in `decrease_handle_refcount()`
//! or `cancel()`. In this case the invariant still holds, as B was younger than C, and C
//! was younger than A, therefore B is also younger than A.
//!
//! 3. If two nodes are both unlocked and node A is the parent of node B, then node B is a child of
//! node A. It is important to always restore that invariant before dropping the lock of a node.
//!
//! # Deadlock safety
//!
//! We always lock in the order of creation time. We can prove this through invariant #2.
//! Specifically, through invariant #2, we know that we always have to lock a parent
//! before its child.
//!
use crate::loom::sync::{Arc, Mutex, MutexGuard};
/// A node of the cancellation tree structure
///
/// The actual data it holds is wrapped inside a mutex for synchronization.
pub(crate) struct TreeNode {
    inner: Mutex<Inner>,
    // Wakes tasks blocked on `notified()`; `cancel` calls `notify_waiters`
    // on it after setting the cancelled flag.
    waker: tokio::sync::Notify,
}

impl TreeNode {
    /// Creates a detached, non-cancelled root node holding one handle.
    pub(crate) fn new() -> Self {
        Self {
            inner: Mutex::new(Inner {
                parent: None,
                parent_idx: 0,
                children: vec![],
                is_cancelled: false,
                num_handles: 1,
            }),
            waker: tokio::sync::Notify::new(),
        }
    }

    /// Returns a future that completes once this node's `Notify` is notified,
    /// i.e. once the node has been cancelled.
    pub(crate) fn notified(&self) -> tokio::sync::futures::Notified<'_> {
        self.waker.notified()
    }
}

/// The data contained inside a `TreeNode`.
///
/// This struct exists so that the data of the node can be wrapped
/// in a Mutex.
struct Inner {
    // Parent in the cancellation tree; `None` for roots and detached nodes.
    parent: Option<Arc<TreeNode>>,
    // Index of this node inside `parent.children`; 0 when there is no parent.
    parent_idx: usize,
    children: Vec<Arc<TreeNode>>,
    // Once set, this never reverts: cancelled trees cannot be uncancelled.
    is_cancelled: bool,
    // Number of `CancellationToken` handles pointing at this node.
    num_handles: usize,
}
/// Returns whether or not the node is cancelled
pub(crate) fn is_cancelled(node: &Arc<TreeNode>) -> bool {
    node.inner.lock().unwrap().is_cancelled
}

/// Creates a child node
pub(crate) fn child_node(parent: &Arc<TreeNode>) -> Arc<TreeNode> {
    let mut locked_parent = parent.inner.lock().unwrap();
    // Do not register as child if we are already cancelled.
    // Cancelled trees can never be uncancelled and therefore
    // need no connection to parents or children any more.
    if locked_parent.is_cancelled {
        return Arc::new(TreeNode {
            inner: Mutex::new(Inner {
                parent: None,
                parent_idx: 0,
                children: vec![],
                is_cancelled: true,
                num_handles: 1,
            }),
            waker: tokio::sync::Notify::new(),
        });
    }
    let child = Arc::new(TreeNode {
        inner: Mutex::new(Inner {
            parent: Some(parent.clone()),
            // Record the slot the child will occupy in the parent's list so
            // `remove_child` can later remove it in O(1).
            parent_idx: locked_parent.children.len(),
            children: vec![],
            is_cancelled: false,
            num_handles: 1,
        }),
        waker: tokio::sync::Notify::new(),
    });
    locked_parent.children.push(child.clone());
    child
}
/// Disconnects the given parent from all of its children.
///
/// Takes a reference to [Inner] to make sure the parent is already locked.
fn disconnect_children(node: &mut Inner) {
    // `take` empties `node.children`; each child's back-reference is cleared
    // while only that child's lock is held.
    for child in std::mem::take(&mut node.children) {
        let mut locked_child = child.inner.lock().unwrap();
        locked_child.parent_idx = 0;
        locked_child.parent = None;
    }
}
/// Figures out the parent of the node and locks the node and its parent atomically.
///
/// The basic principle of preventing deadlocks in the tree is
/// that we always lock the parent first, and then the child.
/// For more info look at *deadlock safety* and *invariant #2*.
///
/// Sadly, it's impossible to figure out the parent of a node without
/// locking it. To then achieve locking order consistency, the node
/// has to be unlocked before the parent gets locked.
/// This leaves a small window where we already assume that we know the parent,
/// but neither the parent nor the node is locked. Therefore, the parent could change.
///
/// To prevent that this problem leaks into the rest of the code, it is abstracted
/// in this function.
///
/// The locked child and optionally its locked parent, if a parent exists, get passed
/// to the `func` argument via (node, None) or (node, Some(parent)).
fn with_locked_node_and_parent<F, Ret>(node: &Arc<TreeNode>, func: F) -> Ret
where
    F: FnOnce(MutexGuard<'_, Inner>, Option<MutexGuard<'_, Inner>>) -> Ret,
{
    use std::sync::TryLockError;
    let mut locked_node = node.inner.lock().unwrap();
    // Every time this fails, the number of ancestors of the node decreases,
    // so the loop must succeed after a finite number of iterations.
    loop {
        // Look up the parent of the currently locked node.
        let potential_parent = match locked_node.parent.as_ref() {
            Some(potential_parent) => potential_parent.clone(),
            None => return func(locked_node, None),
        };
        // Lock the parent. This may require unlocking the child first.
        let locked_parent = match potential_parent.inner.try_lock() {
            Ok(locked_parent) => locked_parent,
            Err(TryLockError::WouldBlock) => {
                drop(locked_node);
                // Deadlock safety:
                //
                // Due to invariant #2, the potential parent must come before
                // the child in the creation order. Therefore, we can safely
                // lock the child while holding the parent lock.
                let locked_parent = potential_parent.inner.lock().unwrap();
                locked_node = node.inner.lock().unwrap();
                locked_parent
            }
            // https://github.com/tokio-rs/tokio/pull/6273#discussion_r1443752911
            #[allow(clippy::unnecessary_literal_unwrap)]
            Err(TryLockError::Poisoned(err)) => Err(err).unwrap(),
        };
        // If we unlocked the child, then the parent may have changed. Check
        // that we still have the right parent.
        if let Some(actual_parent) = locked_node.parent.as_ref() {
            if Arc::ptr_eq(actual_parent, &potential_parent) {
                return func(locked_node, Some(locked_parent));
            }
        }
        // The parent changed (or vanished) while the node was unlocked: the
        // stale parent lock is released at the end of this iteration and we
        // retry with the node's current parent.
    }
}
/// Moves all children from `node` to `parent`.
///
/// `parent` MUST have been a parent of the node when they both got locked,
/// otherwise there is a potential for a deadlock as invariant #2 would be violated.
///
/// To acquire the locks for node and parent, use [`with_locked_node_and_parent`].
fn move_children_to_parent(node: &mut Inner, parent: &mut Inner) {
    // Pre-allocate in the parent, for performance
    parent.children.reserve(node.children.len());
    for child in std::mem::take(&mut node.children) {
        {
            // Fix the child's back-pointer and slot index before pushing it
            // into the parent's list; the child lock is held only briefly.
            let mut child_locked = child.inner.lock().unwrap();
            child_locked.parent.clone_from(&node.parent);
            child_locked.parent_idx = parent.children.len();
        }
        parent.children.push(child);
    }
}

/// Removes a child from the parent.
///
/// `parent` MUST be the parent of `node`.
/// To acquire the locks for node and parent, use [`with_locked_node_and_parent`].
fn remove_child(parent: &mut Inner, mut node: MutexGuard<'_, Inner>) {
    // Query the position from where to remove a node
    let pos = node.parent_idx;
    node.parent = None;
    node.parent_idx = 0;
    // Unlock node, so that only one child at a time is locked.
    // Otherwise we would violate the lock order (see 'deadlock safety') as we
    // don't know the creation order of the child nodes
    drop(node);
    // If `node` is the last element in the list, we don't need any swapping
    if parent.children.len() == pos + 1 {
        parent.children.pop().unwrap();
    } else {
        // If `node` is not the last element in the list, we need to
        // replace it with the last element
        let replacement_child = parent.children.pop().unwrap();
        replacement_child.inner.lock().unwrap().parent_idx = pos;
        parent.children[pos] = replacement_child;
    }
    // Shrink the children list once it is at most a quarter full, keeping
    // twice the current length as headroom to avoid reallocation churn.
    let len = parent.children.len();
    if 4 * len <= parent.children.capacity() {
        parent.children.shrink_to(2 * len);
    }
}
/// Increases the reference count of handles.
pub(crate) fn increase_handle_refcount(node: &Arc<TreeNode>) {
    let mut locked_node = node.inner.lock().unwrap();
    // Once no handles are left over, the node gets detached from the tree.
    // There should never be a new handle once all handles are dropped.
    assert!(locked_node.num_handles > 0);
    locked_node.num_handles += 1;
}

/// Decreases the reference count of handles.
///
/// Once no handle is left, we can remove the node from the
/// tree and connect its parent directly to its children.
pub(crate) fn decrease_handle_refcount(node: &Arc<TreeNode>) {
    let num_handles = {
        // Scope the lock: `with_locked_node_and_parent` below needs to lock
        // this node again itself.
        let mut locked_node = node.inner.lock().unwrap();
        locked_node.num_handles -= 1;
        locked_node.num_handles
    };
    if num_handles == 0 {
        with_locked_node_and_parent(node, |mut node, parent| {
            // Remove the node from the tree
            match parent {
                Some(mut parent) => {
                    // As we want to remove ourselves from the tree,
                    // we have to move the children to the parent, so that
                    // they still receive the cancellation event without us.
                    // Moving them does not violate invariant #1.
                    move_children_to_parent(&mut node, &mut parent);
                    // Remove the node from the parent
                    remove_child(&mut parent, node);
                }
                None => {
                    // Due to invariant #1, we can assume that our
                    // children can no longer be cancelled through us.
                    // (as we now have neither a parent nor handles)
                    // Therefore we can disconnect them.
                    disconnect_children(&mut node);
                }
            }
        });
    }
}
/// Cancels a node and its children.
pub(crate) fn cancel(node: &Arc<TreeNode>) {
    let mut locked_node = node.inner.lock().unwrap();
    if locked_node.is_cancelled {
        return;
    }
    // One by one, adopt grandchildren and then cancel and detach the child
    while let Some(child) = locked_node.children.pop() {
        // This can't deadlock because the mutex we are already
        // holding is the parent of child.
        let mut locked_child = child.inner.lock().unwrap();
        // Detach the child from node
        // No need to modify node.children, as the child already got removed with `.pop`
        locked_child.parent = None;
        locked_child.parent_idx = 0;
        // If child is already cancelled, detaching is enough
        if locked_child.is_cancelled {
            continue;
        }
        // Cancel or adopt grandchildren
        while let Some(grandchild) = locked_child.children.pop() {
            // This can't deadlock because the two mutexes we are already
            // holding is the parent and grandparent of grandchild.
            let mut locked_grandchild = grandchild.inner.lock().unwrap();
            // Detach the grandchild
            locked_grandchild.parent = None;
            locked_grandchild.parent_idx = 0;
            // If grandchild is already cancelled, detaching is enough
            if locked_grandchild.is_cancelled {
                continue;
            }
            // For performance reasons, only adopt grandchildren that have children.
            // Otherwise, just cancel them right away, no need for another iteration.
            if locked_grandchild.children.is_empty() {
                // Cancel the grandchild
                locked_grandchild.is_cancelled = true;
                locked_grandchild.children = Vec::new();
                // The flag is set before dropping the lock and notifying, so
                // waiters re-checking `is_cancelled` cannot miss the event.
                drop(locked_grandchild);
                grandchild.waker.notify_waiters();
            } else {
                // Otherwise, adopt grandchild
                locked_grandchild.parent = Some(node.clone());
                locked_grandchild.parent_idx = locked_node.children.len();
                drop(locked_grandchild);
                locked_node.children.push(grandchild);
            }
        }
        // Cancel the child
        locked_child.is_cancelled = true;
        // `children` is already empty after the inner loop; the assignment
        // just releases the vector's spare capacity.
        locked_child.children = Vec::new();
        drop(locked_child);
        child.waker.notify_waiters();
        // Now the child is cancelled and detached and all its children are adopted.
        // Just continue until all (including adopted) children are cancelled and detached.
    }
    // Cancel the node itself.
    locked_node.is_cancelled = true;
    // Already empty after the loop above; release the spare capacity.
    locked_node.children = Vec::new();
    drop(locked_node);
    node.waker.notify_waiters();
}

20
vendor/tokio-util/src/sync/mod.rs vendored Normal file
View File

@@ -0,0 +1,20 @@
//! Synchronization primitives
mod cancellation_token;
pub use cancellation_token::{
guard::DropGuard, guard_ref::DropGuardRef, CancellationToken, WaitForCancellationFuture,
WaitForCancellationFutureOwned,
};
pub(crate) use cancellation_token::{RunUntilCancelledFuture, RunUntilCancelledFutureOwned};
mod mpsc;
pub use mpsc::{PollSendError, PollSender};
mod poll_semaphore;
pub use poll_semaphore::PollSemaphore;
mod reusable_box;
pub use reusable_box::ReusableBoxFuture;
#[cfg(test)]
mod tests;

325
vendor/tokio-util/src/sync/mpsc.rs vendored Normal file
View File

@@ -0,0 +1,325 @@
use futures_sink::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem};
use tokio::sync::mpsc::OwnedPermit;
use tokio::sync::mpsc::Sender;
use super::ReusableBoxFuture;
/// Error returned by the `PollSender` when the channel is closed.
#[derive(Debug)]
// The wrapped value is the unsent item; it is `Some` only for errors
// produced by `send_item` on a closed sender.
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Consumes the stored value, if any.
    ///
    /// If this error was encountered when calling `start_send`/`send_item`, this will be the item
    /// that the caller attempted to send. Otherwise, it will be `None`.
    pub fn into_inner(self) -> Option<T> {
        self.0
    }
}
impl<T> fmt::Display for PollSendError<T> {
    /// The message is static: a closed channel is the only failure mode.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("channel closed")
    }
}

impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
#[derive(Debug)]
enum State<T> {
    // No send in progress; the sender is parked here between sends.
    Idle(Sender<T>),
    // `poll_reserve` is waiting on the acquire future for an `OwnedPermit`.
    Acquiring,
    // A permit was reserved; `send_item` may now be called.
    ReadyToSend(OwnedPermit<T>),
    // Closed by `close()` or because the channel is gone; permanent.
    Closed,
}

/// A wrapper around [`mpsc::Sender`] that can be polled.
///
/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
#[derive(Debug)]
pub struct PollSender<T> {
    // `None` once `close` has been called; also consulted by `send_item` and
    // `abort_send` to detect a deferred close.
    sender: Option<Sender<T>>,
    // Current position in the send state machine.
    state: State<T>,
    // Reusable future driven while in `State::Acquiring`.
    acquire: PollSenderFuture<T>,
}
// Creates a future for acquiring a permit from the underlying channel. This is used to ensure
// there's capacity for a send to complete.
//
// By reusing the same async fn for both `Some` and `None`, we make sure every future passed to
// ReusableBoxFuture has the same underlying type, and hence the same size and alignment.
async fn make_acquire_future<T>(
    data: Option<Sender<T>>,
) -> Result<OwnedPermit<T>, PollSendError<T>> {
    match data {
        Some(sender) => sender
            .reserve_owned()
            .await
            .map_err(|_| PollSendError(None)),
        // The `None` variant only backs the reset future installed via
        // `set(None)`, which is never polled while installed.
        None => unreachable!("this future should not be pollable in this state"),
    }
}

type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>;

#[derive(Debug)]
// TODO: This should be replace with a type_alias_impl_trait to eliminate `'static` and all the transmutes
struct PollSenderFuture<T>(InnerFuture<'static, T>);

impl<T> PollSenderFuture<T> {
    /// Create with an empty inner future with no `Send` bound.
    fn empty() -> Self {
        // We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
        // compatible with the transitive bounds required by `Sender<T>`.
        Self(ReusableBoxFuture::new(async { unreachable!() }))
    }
}
impl<T: Send> PollSenderFuture<T> {
    /// Create with an empty inner future.
    fn new() -> Self {
        let v = InnerFuture::new(make_acquire_future(None));
        // This is safe because `make_acquire_future(None)` is actually `'static`
        Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) })
    }

    /// Poll the inner future.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> {
        self.0.poll(cx)
    }

    /// Replace the inner future.
    fn set(&mut self, sender: Option<Sender<T>>) {
        let inner: *mut InnerFuture<'static, T> = &mut self.0;
        let inner: *mut InnerFuture<'_, T> = inner.cast();
        // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T`
        // becomes invalid, and this casts away the type-level lifetime check for that. However, the
        // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not
        // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed
        // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so
        // this is ok.
        let inner = unsafe { &mut *inner };
        inner.set(make_acquire_future(sender));
    }
}
impl<T: Send> PollSender<T> {
    /// Creates a new `PollSender`.
    pub fn new(sender: Sender<T>) -> Self {
        Self {
            // Keep one copy of the sender aside: it marks "not closed" and is
            // used to recover an `Idle` state after an aborted acquire.
            sender: Some(sender.clone()),
            state: State::Idle(sender),
            acquire: PollSenderFuture::new(),
        }
    }

    // Moves the current state out, leaving `State::Closed` as a placeholder.
    // The caller is responsible for storing the correct next state back.
    fn take_state(&mut self) -> State<T> {
        mem::replace(&mut self.state, State::Closed)
    }

    /// Attempts to prepare the sender to receive a value.
    ///
    /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
    /// `send_item`.
    ///
    /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
    /// by reserving a slot in the channel for the item to be sent. If this method returns
    /// `Poll::Pending`, the current task is registered to be notified (via
    /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
    ///
    /// # Errors
    ///
    /// If the channel is closed, an error will be returned. This is a permanent state.
    pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        loop {
            let (result, next_state) = match self.take_state() {
                State::Idle(sender) => {
                    // Start trying to acquire a permit to reserve a slot for our send, and
                    // immediately loop back around to poll it the first time.
                    self.acquire.set(Some(sender));
                    (None, State::Acquiring)
                }
                State::Acquiring => match self.acquire.poll(cx) {
                    // Channel has capacity.
                    Poll::Ready(Ok(permit)) => {
                        (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
                    }
                    // Channel is closed.
                    Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
                    // Channel doesn't have capacity yet, so we need to wait.
                    Poll::Pending => (Some(Poll::Pending), State::Acquiring),
                },
                // We're closed, either by choice or because the underlying sender was closed.
                s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
                // We're already ready to send an item.
                s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
            };
            self.state = next_state;
            if let Some(result) = result {
                return result;
            }
        }
    }

    /// Sends an item to the channel.
    ///
    /// Before calling `send_item`, `poll_reserve` must be called with a successful return
    /// value of `Poll::Ready(Ok(()))`.
    ///
    /// # Errors
    ///
    /// If the channel is closed, an error will be returned. This is a permanent state.
    ///
    /// # Panics
    ///
    /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
    /// will panic.
    #[track_caller]
    pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
        let (result, next_state) = match self.take_state() {
            State::Idle(_) | State::Acquiring => {
                panic!("`send_item` called without first calling `poll_reserve`")
            }
            // We have a permit to send our item, so go ahead, which gets us our sender back.
            State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
            // We're closed, either by choice or because the underlying sender was closed.
            State::Closed => (Err(PollSendError(Some(value))), State::Closed),
        };
        // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
        self.state = if self.sender.is_some() {
            next_state
        } else {
            State::Closed
        };
        result
    }

    /// Checks whether this sender is closed.
    ///
    /// The underlying channel that this sender was wrapping may still be open.
    pub fn is_closed(&self) -> bool {
        matches!(self.state, State::Closed) || self.sender.is_none()
    }

    /// Gets a reference to the `Sender` of the underlying channel.
    ///
    /// If `PollSender` has been closed, `None` is returned. The underlying channel that this sender
    /// was wrapping may still be open.
    pub fn get_ref(&self) -> Option<&Sender<T>> {
        self.sender.as_ref()
    }

    /// Closes this sender.
    ///
    /// No more messages will be able to be sent from this sender, but the underlying channel will
    /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
    ///
    /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
    /// to `send_item` in order to consume the reserved slot. After that, no further sends will be
    /// possible. If you do not intend to send another item, you can release the reserved slot back
    /// to the underlying sender by calling [`abort_send`].
    ///
    /// [`abort_send`]: crate::sync::PollSender::abort_send
    /// [`Receiver`]: tokio::sync::mpsc::Receiver
    pub fn close(&mut self) {
        // Mark ourselves officially closed by dropping our main sender.
        self.sender = None;
        // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
        // transition to the closed state. Otherwise, leave the existing permit in place for the
        // caller if they want to complete the send.
        match self.state {
            State::Idle(_) => self.state = State::Closed,
            State::Acquiring => {
                self.acquire.set(None);
                self.state = State::Closed;
            }
            _ => {}
        }
    }

    /// Aborts the current in-progress send, if any.
    ///
    /// Returns `true` if a send was aborted. If the sender was closed prior to calling
    /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
    /// ready to attempt another send.
    pub fn abort_send(&mut self) -> bool {
        // We may have been closed in the meantime, after a call to `poll_reserve` already
        // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the
        // closed state when we actually abort a send, rather than resetting ourselves back to idle.
        let (result, next_state) = match self.take_state() {
            // We're currently trying to reserve a slot to send into.
            State::Acquiring => {
                // Replacing the future drops the in-flight one.
                self.acquire.set(None);
                // If we haven't closed yet, we have to clone our stored sender since we have no way
                // to get it back from the acquire future we just dropped.
                let state = match self.sender.clone() {
                    Some(sender) => State::Idle(sender),
                    None => State::Closed,
                };
                (true, state)
            }
            // We got the permit. If we haven't closed yet, get the sender back.
            State::ReadyToSend(permit) => {
                let state = if self.sender.is_some() {
                    State::Idle(permit.release())
                } else {
                    State::Closed
                };
                (true, state)
            }
            s => (false, s),
        };
        self.state = next_state;
        result
    }
}
impl<T> Clone for PollSender<T> {
    /// Clones this `PollSender`.
    ///
    /// The clone starts out exactly as if it had been built with
    /// `PollSender::new`: no reserved slot or in-flight acquire is carried
    /// over from `self`.
    fn clone(&self) -> PollSender<T> {
        match self.sender.clone() {
            Some(inner) => Self {
                sender: Some(inner.clone()),
                state: State::Idle(inner),
                acquire: PollSenderFuture::empty(),
            },
            None => Self {
                sender: None,
                state: State::Closed,
                acquire: PollSenderFuture::empty(),
            },
        }
    }
}
impl<T: Send> Sink<T> for PollSender<T> {
    type Error = PollSendError<T>;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // `PollSender` is `Unpin`, so we can freely get at the inner value.
        self.get_mut().poll_reserve(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        self.get_mut().send_item(item)
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Items are handed to the channel in `start_send`; there is no buffer to flush.
        Poll::Ready(Ok(()))
    }

    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.get_mut().close();
        Poll::Ready(Ok(()))
    }
}

View File

@@ -0,0 +1,171 @@
use futures_core::Stream;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use super::ReusableBoxFuture;
/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
///
/// [`Semaphore`]: tokio::sync::Semaphore
pub struct PollSemaphore {
    /// The shared semaphore that permits are acquired from.
    semaphore: Arc<Semaphore>,
    /// The in-flight acquire future, if any, kept boxed so it can be reused
    /// across calls without reallocating.
    permit_fut: Option<(
        u32, // The number of permits requested.
        ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>,
    )>,
}
impl PollSemaphore {
    /// Create a new `PollSemaphore`.
    pub fn new(semaphore: Arc<Semaphore>) -> Self {
        Self {
            semaphore,
            permit_fut: None,
        }
    }
    /// Closes the semaphore.
    ///
    /// This calls [`Semaphore::close`] on the shared semaphore.
    ///
    /// [`Semaphore::close`]: tokio::sync::Semaphore::close
    pub fn close(&self) {
        self.semaphore.close();
    }
    /// Obtain a clone of the inner semaphore.
    pub fn clone_inner(&self) -> Arc<Semaphore> {
        self.semaphore.clone()
    }
    /// Get back the inner semaphore.
    pub fn into_inner(self) -> Arc<Semaphore> {
        self.semaphore
    }
    /// Poll to acquire a permit from the semaphore.
    ///
    /// This can return the following values:
    ///
    /// - `Poll::Pending` if a permit is not currently available.
    /// - `Poll::Ready(Some(permit))` if a permit was acquired.
    /// - `Poll::Ready(None)` if the semaphore has been closed.
    ///
    /// When this method returns `Poll::Pending`, the current task is scheduled
    /// to receive a wakeup when a permit becomes available, or when the
    /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
    /// the `Waker` from the `Context` passed to the most recent call is
    /// scheduled to receive a wakeup.
    pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
        self.poll_acquire_many(cx, 1)
    }
    /// Poll to acquire many permits from the semaphore.
    ///
    /// This can return the following values:
    ///
    /// - `Poll::Pending` if a permit is not currently available.
    /// - `Poll::Ready(Some(permit))` if a permit was acquired.
    /// - `Poll::Ready(None)` if the semaphore has been closed.
    ///
    /// When this method returns `Poll::Pending`, the current task is scheduled
    /// to receive a wakeup when the permits become available, or when the
    /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
    /// the `Waker` from the `Context` passed to the most recent call is
    /// scheduled to receive a wakeup.
    pub fn poll_acquire_many(
        &mut self,
        cx: &mut Context<'_>,
        permits: u32,
    ) -> Poll<Option<OwnedSemaphorePermit>> {
        let permit_future = match self.permit_fut.as_mut() {
            // An acquire for the same number of permits is already in flight:
            // keep polling it below.
            Some((prev_permits, fut)) if *prev_permits == permits => fut,
            Some((old_permits, fut_box)) => {
                // We're requesting a different number of permits, so replace the future
                // and record the new amount.
                let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
                fut_box.set(fut);
                *old_permits = permits;
                fut_box
            }
            None => {
                // avoid allocations completely if we can grab a permit immediately
                match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) {
                    Ok(permit) => return Poll::Ready(Some(permit)),
                    Err(TryAcquireError::Closed) => return Poll::Ready(None),
                    Err(TryAcquireError::NoPermits) => {}
                }
                // No permits available right now: box up an acquire future and
                // store it for polling (and reuse on later calls).
                let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
                &mut self
                    .permit_fut
                    .get_or_insert((permits, ReusableBoxFuture::new(next_fut)))
                    .1
            }
        };
        let result = ready!(permit_future.poll(cx));
        // Assume we'll request the same amount of permits in a subsequent call.
        let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
        permit_future.set(next_fut);
        match result {
            Ok(permit) => Poll::Ready(Some(permit)),
            Err(_closed) => {
                // The semaphore was closed: drop the stored future so its
                // allocation is released.
                self.permit_fut = None;
                Poll::Ready(None)
            }
        }
    }
    /// Returns the current number of available permits.
    ///
    /// This is equivalent to the [`Semaphore::available_permits`] method on the
    /// `tokio::sync::Semaphore` type.
    ///
    /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }
    /// Adds `n` new permits to the semaphore.
    ///
    /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function
    /// will panic if the limit is exceeded.
    ///
    /// This is equivalent to the [`Semaphore::add_permits`] method on the
    /// `tokio::sync::Semaphore` type.
    ///
    /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
    pub fn add_permits(&self, n: usize) {
        self.semaphore.add_permits(n);
    }
}
impl Stream for PollSemaphore {
    type Item = OwnedSemaphorePermit;

    /// Yields permits one at a time; the stream ends when the semaphore is closed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
        // `PollSemaphore` is `Unpin`, so the pin can be unwrapped directly.
        self.get_mut().poll_acquire(cx)
    }
}
impl Clone for PollSemaphore {
fn clone(&self) -> PollSemaphore {
PollSemaphore::new(self.clone_inner())
}
}
impl fmt::Debug for PollSemaphore {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `permit_fut` is intentionally omitted: it has no useful `Debug` output.
        let mut dbg = f.debug_struct("PollSemaphore");
        dbg.field("semaphore", &self.semaphore);
        dbg.finish()
    }
}
impl AsRef<Semaphore> for PollSemaphore {
    /// Borrows the shared semaphore behind the `Arc`.
    fn as_ref(&self) -> &Semaphore {
        self.semaphore.as_ref()
    }
}

View File

@@ -0,0 +1,157 @@
use std::alloc::Layout;
use std::fmt;
use std::future::{self, Future};
use std::mem::{self, ManuallyDrop};
use std::pin::Pin;
use std::ptr;
use std::task::{Context, Poll};
/// A reusable `Pin<Box<dyn Future<Output = T> + Send + 'a>>`.
///
/// This type lets you replace the future stored in the box without
/// reallocating when the size and alignment permits this.
pub struct ReusableBoxFuture<'a, T> {
    /// The currently stored future. Always pinned; replaced in place by
    /// `set`/`try_set` when the new future's layout matches the old one.
    boxed: Pin<Box<dyn Future<Output = T> + Send + 'a>>,
}
impl<'a, T> ReusableBoxFuture<'a, T> {
    /// Create a new `ReusableBoxFuture<T>` containing the provided future.
    pub fn new<F>(future: F) -> Self
    where
        F: Future<Output = T> + Send + 'a,
    {
        Self {
            boxed: Box::pin(future),
        }
    }
    /// Replace the future currently stored in this box.
    ///
    /// This reallocates if and only if the layout of the provided future is
    /// different from the layout of the currently stored future.
    pub fn set<F>(&mut self, future: F)
    where
        F: Future<Output = T> + Send + 'a,
    {
        if let Err(future) = self.try_set(future) {
            // Layout mismatch: fall back to a fresh allocation.
            *self = Self::new(future);
        }
    }
    /// Replace the future currently stored in this box.
    ///
    /// This function never reallocates, but returns an error if the provided
    /// future has a different size or alignment from the currently stored
    /// future.
    pub fn try_set<F>(&mut self, future: F) -> Result<(), F>
    where
        F: Future<Output = T> + Send + 'a,
    {
        // If we try to inline the contents of this function, the type checker complains because
        // the bound `T: 'a` is not satisfied in the call to `pending()`. But by putting it in an
        // inner function that doesn't have `T` as a generic parameter, we implicitly get the bound
        // `F::Output: 'a` transitively through `F: 'a`, allowing us to call `pending()`.
        #[inline(always)]
        fn real_try_set<'a, F>(
            this: &mut ReusableBoxFuture<'a, F::Output>,
            future: F,
        ) -> Result<(), F>
        where
            F: Future + Send + 'a,
        {
            // future::Pending<T> is a ZST so this never allocates.
            let boxed = mem::replace(&mut this.boxed, Box::pin(future::pending()));
            // On layout mismatch `reuse_pin_box` hands `future` back in `Err`,
            // leaving `this.boxed` holding the pending placeholder.
            reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed))
        }
        real_try_set(self, future)
    }
    /// Get a pinned reference to the underlying future.
    pub fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> {
        self.boxed.as_mut()
    }
    /// Poll the future stored inside this box.
    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> {
        self.get_pin().poll(cx)
    }
}
impl<T> Future for ReusableBoxFuture<'_, T> {
    type Output = T;

    /// Polls the future currently stored inside this box.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        // The box itself is `Unpin` (only its contents are pinned), so the
        // outer pin can be unwrapped.
        self.get_mut().get_pin().poll(cx)
    }
}
// SAFETY: The only method called on self.boxed is poll, which takes &mut self, so this
// struct being Sync does not permit any invalid access to the Future, even if
// the future is not Sync.
unsafe impl<T> Sync for ReusableBoxFuture<'_, T> {}
impl<T> fmt::Debug for ReusableBoxFuture<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The boxed future is opaque, so only the type name is printed
        // (identical output to an empty `debug_struct`).
        f.debug_struct("ReusableBoxFuture").finish()
    }
}
/// Replaces the value inside `boxed` with `new_value`, reusing the heap
/// allocation when the two layouts match.
///
/// On success, the resulting `Box<U>` is passed to `callback` and the
/// callback's result is returned. On layout mismatch, `new_value` is handed
/// back unchanged in `Err` and `boxed` is simply dropped.
fn reuse_pin_box<T: ?Sized, U, O, F>(boxed: Pin<Box<T>>, new_value: U, callback: F) -> Result<O, U>
where
    F: FnOnce(Box<U>) -> O,
{
    let layout = Layout::for_value::<T>(&*boxed);
    if layout != Layout::new::<U>() {
        return Err(new_value);
    }
    // SAFETY: We don't ever construct a non-pinned reference to the old `T` from now on, and we
    // always drop the `T`.
    let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) });
    // When dropping the old value panics, we still want to call `callback` — so move the rest of
    // the code into a guard type.
    let guard = CallOnDrop::new(|| {
        let raw: *mut U = raw.cast::<U>();
        unsafe { raw.write(new_value) };
        // SAFETY:
        // - `T` and `U` have the same layout.
        // - `raw` comes from a `Box` that uses the same allocator as this one.
        // - `raw` points to a valid instance of `U` (we just wrote it in).
        let boxed = unsafe { Box::from_raw(raw) };
        callback(boxed)
    });
    // Drop the old value.
    unsafe { ptr::drop_in_place(raw) };
    // Run the rest of the code.
    Ok(guard.call())
}
/// Guard that runs its closure either explicitly via `call` or from `Drop`,
/// ensuring the closure executes exactly once even if intervening code panics.
struct CallOnDrop<O, F: FnOnce() -> O> {
    f: ManuallyDrop<F>,
}
impl<O, F: FnOnce() -> O> CallOnDrop<O, F> {
    /// Wraps `f` so it runs on drop (or via `call`).
    fn new(f: F) -> Self {
        let f = ManuallyDrop::new(f);
        Self { f }
    }
    /// Runs the closure now, returning its result and disarming the `Drop` impl.
    fn call(self) -> O {
        // Wrap `self` in `ManuallyDrop` so our own `Drop` impl does not also
        // run the closure.
        let mut this = ManuallyDrop::new(self);
        // SAFETY: `this` is never touched again, so the closure is taken exactly once.
        let f = unsafe { ManuallyDrop::take(&mut this.f) };
        f()
    }
}
impl<O, F: FnOnce() -> O> Drop for CallOnDrop<O, F> {
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once, and `call` disarms it via
        // `ManuallyDrop`, so the closure is only ever taken here once.
        let f = unsafe { ManuallyDrop::take(&mut self.f) };
        f();
    }
}

View File

@@ -0,0 +1,185 @@
use crate::sync::CancellationToken;
use loom::{future::block_on, thread};
use tokio_test::assert_ok;
#[test]
fn cancel_token() {
    // One thread awaits cancellation while another cancels; loom explores all
    // interleavings to check the waiter always observes the cancellation.
    loom::model(|| {
        let token = CancellationToken::new();
        let token1 = token.clone();
        let th1 = thread::spawn(move || {
            block_on(async {
                token1.cancelled().await;
            });
        });
        let th2 = thread::spawn(move || {
            token.cancel();
        });
        assert_ok!(th1.join());
        assert_ok!(th2.join());
    });
}
#[test]
fn cancel_token_owned() {
    // Same as `cancel_token`, but the waiter consumes its token via
    // `cancelled_owned` instead of borrowing it.
    loom::model(|| {
        let token = CancellationToken::new();
        let token1 = token.clone();
        let th1 = thread::spawn(move || {
            block_on(async {
                token1.cancelled_owned().await;
            });
        });
        let th2 = thread::spawn(move || {
            token.cancel();
        });
        assert_ok!(th1.join());
        assert_ok!(th2.join());
    });
}
#[test]
fn cancel_with_child() {
    // Cancelling a parent token must also wake waiters on a child token,
    // under every interleaving loom can produce.
    loom::model(|| {
        let token = CancellationToken::new();
        let token1 = token.clone();
        let token2 = token.clone();
        let child_token = token.child_token();
        let th1 = thread::spawn(move || {
            block_on(async {
                token1.cancelled().await;
            });
        });
        let th2 = thread::spawn(move || {
            token2.cancel();
        });
        let th3 = thread::spawn(move || {
            block_on(async {
                child_token.cancelled().await;
            });
        });
        assert_ok!(th1.join());
        assert_ok!(th2.join());
        assert_ok!(th3.join());
    });
}
#[test]
fn drop_token_no_child() {
    // Concurrently dropping all clones of a token (no children) must not
    // race on the shared state.
    loom::model(|| {
        let token = CancellationToken::new();
        let token1 = token.clone();
        let token2 = token.clone();
        let th1 = thread::spawn(move || {
            drop(token1);
        });
        let th2 = thread::spawn(move || {
            drop(token2);
        });
        let th3 = thread::spawn(move || {
            drop(token);
        });
        assert_ok!(th1.join());
        assert_ok!(th2.join());
        assert_ok!(th3.join());
    });
}
// Temporarily disabled due to a false positive in loom -
// see https://github.com/tokio-rs/tokio/pull/7644#issuecomment-3328381344
#[ignore]
#[test]
fn drop_token_with_children() {
    // A parent and its two child tokens are dropped from separate threads;
    // the parent/child bookkeeping must stay consistent in every interleaving.
    loom::model(|| {
        let token1 = CancellationToken::new();
        let child_token1 = token1.child_token();
        let child_token2 = token1.child_token();
        let th1 = thread::spawn(move || {
            drop(token1);
        });
        let th2 = thread::spawn(move || {
            drop(child_token1);
        });
        let th3 = thread::spawn(move || {
            drop(child_token2);
        });
        assert_ok!(th1.join());
        assert_ok!(th2.join());
        assert_ok!(th3.join());
    });
}
// Temporarily disabled due to a false positive in loom -
// see https://github.com/tokio-rs/tokio/pull/7644#issuecomment-3328381344
#[ignore]
#[test]
fn drop_and_cancel_token() {
    // Dropping one clone, cancelling another, and dropping a child token all
    // race against each other; no interleaving may corrupt shared state.
    loom::model(|| {
        let token1 = CancellationToken::new();
        let token2 = token1.clone();
        let child_token = token1.child_token();
        let th1 = thread::spawn(move || {
            drop(token1);
        });
        let th2 = thread::spawn(move || {
            token2.cancel();
        });
        let th3 = thread::spawn(move || {
            drop(child_token);
        });
        assert_ok!(th1.join());
        assert_ok!(th2.join());
        assert_ok!(th3.join());
    });
}
// Temporarily disabled due to a false positive in loom -
// see https://github.com/tokio-rs/tokio/pull/7644#issuecomment-3328381344
#[ignore]
#[test]
fn cancel_parent_and_child() {
    // Parent drop, parent cancel, and child cancel all race; the model checks
    // every interleaving completes without deadlock or invalid state.
    loom::model(|| {
        let token1 = CancellationToken::new();
        let token2 = token1.clone();
        let child_token = token1.child_token();
        let th1 = thread::spawn(move || {
            drop(token1);
        });
        let th2 = thread::spawn(move || {
            token2.cancel();
        });
        let th3 = thread::spawn(move || {
            child_token.cancel();
        });
        assert_ok!(th1.join());
        assert_ok!(th2.join());
        assert_ok!(th3.join());
    });
}

View File

@@ -0,0 +1,2 @@
#[cfg(loom)]
mod loom_cancellation_token;

View File

@@ -0,0 +1,95 @@
//! An [`AbortOnDropHandle`] is like a [`JoinHandle`], except that it
//! will abort the task as soon as it is dropped.
use tokio::task::{AbortHandle, JoinError, JoinHandle};
use std::{
future::Future,
mem::ManuallyDrop,
pin::Pin,
task::{Context, Poll},
};
/// A wrapper around a [`tokio::task::JoinHandle`],
/// which [aborts] the task when it is dropped.
///
/// Use [`AbortOnDropHandle::detach`] to recover the plain [`JoinHandle`]
/// and cancel the abort-on-drop behavior.
///
/// [aborts]: tokio::task::JoinHandle::abort
#[must_use = "Dropping the handle aborts the task immediately"]
pub struct AbortOnDropHandle<T>(JoinHandle<T>);
impl<T> Drop for AbortOnDropHandle<T> {
    fn drop(&mut self) {
        // Request cancellation of the wrapped task when the handle goes away.
        self.0.abort()
    }
}
impl<T> AbortOnDropHandle<T> {
    /// Create an [`AbortOnDropHandle`] from a [`JoinHandle`].
    pub fn new(handle: JoinHandle<T>) -> Self {
        Self(handle)
    }
    /// Abort the task associated with this handle,
    /// equivalent to [`JoinHandle::abort`].
    pub fn abort(&self) {
        self.0.abort()
    }
    /// Checks if the task associated with this handle is finished,
    /// equivalent to [`JoinHandle::is_finished`].
    pub fn is_finished(&self) -> bool {
        self.0.is_finished()
    }
    /// Returns a new [`AbortHandle`] that can be used to remotely abort this task,
    /// equivalent to [`JoinHandle::abort_handle`].
    pub fn abort_handle(&self) -> AbortHandle {
        self.0.abort_handle()
    }
    /// Cancels aborting on drop and returns the original [`JoinHandle`].
    ///
    /// Dropping the returned handle no longer aborts the task.
    pub fn detach(self) -> JoinHandle<T> {
        // Avoid invoking `AbortOnDropHandle`'s `Drop` impl
        let this = ManuallyDrop::new(self);
        // SAFETY: `&this.0` is a reference, so it is certainly initialized, and
        // it won't be double-dropped because it's in a `ManuallyDrop`
        unsafe { std::ptr::read(&this.0) }
    }
}
impl<T> std::fmt::Debug for AbortOnDropHandle<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `T` need not be `Debug`, so only the task id is shown.
        let mut dbg = f.debug_struct("AbortOnDropHandle");
        dbg.field("id", &self.0.id());
        dbg.finish()
    }
}
impl<T> Future for AbortOnDropHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Delegate to the wrapped handle; `AbortOnDropHandle` is `Unpin`.
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll(cx)
    }
}
impl<T> AsRef<JoinHandle<T>> for AbortOnDropHandle<T> {
fn as_ref(&self) -> &JoinHandle<T> {
&self.0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// A simple type that does not implement [`std::fmt::Debug`].
    struct NotDebug;
    fn is_debug<T: std::fmt::Debug>() {}
    // Compile-time check: `AbortOnDropHandle<T>` must be `Debug` even when
    // `T` is not, since its `Debug` impl only prints the task id.
    #[test]
    fn assert_debug() {
        is_debug::<AbortOnDropHandle<NotDebug>>();
    }
}

850
vendor/tokio-util/src/task/join_map.rs vendored Normal file
View File

@@ -0,0 +1,850 @@
use hashbrown::hash_table::Entry;
use hashbrown::{HashMap, HashTable};
use std::borrow::Borrow;
use std::collections::hash_map::RandomState;
use std::fmt;
use std::future::Future;
use std::hash::{BuildHasher, Hash};
use std::marker::PhantomData;
use tokio::runtime::Handle;
use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
/// A collection of tasks spawned on a Tokio runtime, associated with hash map
/// keys.
///
/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
/// addition of a set of keys associated with each task. These keys allow
/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
/// `JoinMap` based on their keys, or [test whether a task corresponding to a
/// given key exists][contains] in the `JoinMap`.
///
/// In addition, when tasks in the `JoinMap` complete, they will return the
/// associated key along with the value returned by the task, if any.
///
/// A `JoinMap` can be used to await the completion of some or all of the tasks
/// in the map. The map is not ordered, and the tasks will be returned in the
/// order they complete.
///
/// All of the tasks must have the same return type `V`.
///
/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
///
/// # Examples
///
/// Spawn multiple tasks and wait for them:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// let mut map = JoinMap::new();
///
/// for i in 0..10 {
/// // Spawn a task on the `JoinMap` with `i` as its key.
/// map.spawn(i, async move { /* ... */ });
/// }
///
/// let mut seen = [false; 10];
///
/// // When a task completes, `join_next` returns the task's key along
/// // with its output.
/// while let Some((key, res)) = map.join_next().await {
/// seen[key] = true;
/// assert!(res.is_ok(), "task {} completed successfully!", key);
/// }
///
/// for i in 0..10 {
/// assert!(seen[i]);
/// }
/// # }
/// ```
///
/// Cancel tasks based on their keys:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// let mut map = JoinMap::new();
///
/// map.spawn("hello world", std::future::ready(1));
/// map.spawn("goodbye world", std::future::pending());
///
/// // Look up the "goodbye world" task in the map and abort it.
/// let aborted = map.abort("goodbye world");
///
/// // `JoinMap::abort` returns `true` if a task existed for the
/// // provided key.
/// assert!(aborted);
///
/// while let Some((key, res)) = map.join_next().await {
/// if key == "goodbye world" {
/// // The aborted task should complete with a cancelled `JoinError`.
/// assert!(res.unwrap_err().is_cancelled());
/// } else {
/// // Other tasks should complete normally.
/// assert_eq!(res.unwrap(), 1);
/// }
/// }
/// # }
/// ```
///
/// [`JoinSet`]: tokio::task::JoinSet
/// [abort]: fn@Self::abort
/// [abort_matching]: fn@Self::abort_matching
/// [contains]: fn@Self::contains_key
pub struct JoinMap<K, V, S = RandomState> {
    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
    /// indexed by their keys.
    ///
    /// Each entry stores the key itself alongside the handle, and entries are
    /// addressed by the hash of the key.
    tasks_by_key: HashTable<(K, AbortHandle)>,
    /// A map from task IDs to the hash of the key associated with that task.
    ///
    /// This map is used to perform reverse lookups of tasks in the
    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
    /// of that task's key, and then remove it from the `tasks_by_key` map using
    /// the raw hash code, resolving collisions by comparing task IDs.
    hashes_by_task: HashMap<Id, u64, S>,
    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
    /// `JoinMap`.
    tasks: JoinSet<V>,
}
impl<K, V> JoinMap<K, V> {
    /// Creates a new empty `JoinMap`.
    ///
    /// The map starts with a capacity of 0, so nothing is allocated until the
    /// first task is spawned on it.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::task::JoinMap;
    /// let map: JoinMap<&str, i32> = JoinMap::new();
    /// ```
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        JoinMap::with_capacity_and_hasher(0, RandomState::new())
    }

    /// Creates an empty `JoinMap` able to hold at least `capacity` tasks
    /// without reallocating.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::task::JoinMap;
    /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
    /// ```
    #[inline]
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        JoinMap::with_capacity_and_hasher(capacity, RandomState::new())
    }
}
impl<K, V, S> JoinMap<K, V, S> {
    /// Creates an empty `JoinMap` which will use the given hash builder to hash
    /// keys.
    ///
    /// The created map has the default initial capacity.
    ///
    /// Warning: `hash_builder` is normally randomly generated, and
    /// is designed to allow `JoinMap` to be resistant to attacks that
    /// cause many collisions and very poor performance. Setting it
    /// manually using this function can expose a DoS attack vector.
    ///
    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
    /// the `JoinMap` to be useful, see its documentation for details.
    #[inline]
    #[must_use]
    pub fn with_hasher(hash_builder: S) -> Self {
        Self::with_capacity_and_hasher(0, hash_builder)
    }
    /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
    /// to hash the keys.
    ///
    /// The `JoinMap` will be able to hold at least `capacity` elements without
    /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
    ///
    /// Warning: `hash_builder` is normally randomly generated, and
    /// is designed to allow HashMaps to be resistant to attacks that
    /// cause many collisions and very poor performance. Setting it
    /// manually using this function can expose a DoS attack vector.
    ///
    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
    /// the `JoinMap` to be useful, see its documentation for details.
    ///
    /// # Examples
    ///
    /// ```
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    /// use tokio_util::task::JoinMap;
    /// use std::collections::hash_map::RandomState;
    ///
    /// let s = RandomState::new();
    /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
    /// map.spawn(1, async move { "hello world!" });
    /// # }
    /// ```
    #[inline]
    #[must_use]
    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
        Self {
            tasks_by_key: HashTable::with_capacity(capacity),
            hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
            tasks: JoinSet::new(),
        }
    }
    /// Returns the number of tasks currently in the `JoinMap`.
    pub fn len(&self) -> usize {
        let len = self.tasks_by_key.len();
        // The two internal maps must always hold one entry per task.
        debug_assert_eq!(len, self.hashes_by_task.len());
        len
    }
    /// Returns whether the `JoinMap` is empty.
    pub fn is_empty(&self) -> bool {
        let empty = self.tasks_by_key.is_empty();
        debug_assert_eq!(empty, self.hashes_by_task.is_empty());
        empty
    }
    /// Returns the number of tasks the map can hold without reallocating.
    ///
    /// This number is a lower bound; the `JoinMap` might be able to hold
    /// more, but is guaranteed to be able to hold at least this many.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::task::JoinMap;
    ///
    /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
    /// assert!(map.capacity() >= 100);
    /// ```
    #[inline]
    pub fn capacity(&self) -> usize {
        let capacity = self.tasks_by_key.capacity();
        debug_assert_eq!(capacity, self.hashes_by_task.capacity());
        capacity
    }
}
impl<K, V, S> JoinMap<K, V, S>
where
K: Hash + Eq,
V: 'static,
S: BuildHasher,
{
    /// Spawn the provided task and store it in this `JoinMap` with the provided
    /// key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Panics
    ///
    /// This method panics if called outside of a Tokio runtime.
    ///
    /// [`join_next`]: Self::join_next
    #[track_caller]
    pub fn spawn<F>(&mut self, key: K, task: F)
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        // Spawn on the inner `JoinSet`, then record the key/task-ID association.
        let task = self.tasks.spawn(task);
        self.insert(key, task)
    }
    /// Spawn the provided task on the provided runtime and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// [`join_next`]: Self::join_next
    #[track_caller]
    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        // Same as `spawn`, but the task runs on the given runtime handle.
        let task = self.tasks.spawn_on(task, handle);
        self.insert(key, task);
    }
    /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
    /// key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// Note that blocking tasks cannot be cancelled after execution starts.
    /// Replaced blocking tasks will still run to completion if the task has begun
    /// to execute when it is replaced. A blocking task which is replaced before
    /// it has been scheduled on a blocking worker thread will be cancelled.
    ///
    /// # Panics
    ///
    /// This method panics if called outside of a Tokio runtime.
    ///
    /// [`join_next`]: Self::join_next
    #[track_caller]
    pub fn spawn_blocking<F>(&mut self, key: K, f: F)
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        // Spawn the closure on the blocking pool, then record the association.
        let task = self.tasks.spawn_blocking(f);
        self.insert(key, task)
    }
    /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
    /// `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// Note that blocking tasks cannot be cancelled after execution starts.
    /// Replaced blocking tasks will still run to completion if the task has begun
    /// to execute when it is replaced. A blocking task which is replaced before
    /// it has been scheduled on a blocking worker thread will be cancelled.
    ///
    /// [`join_next`]: Self::join_next
    #[track_caller]
    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        // Same as `spawn_blocking`, but targets the given runtime's pool.
        let task = self.tasks.spawn_blocking_on(f, handle);
        self.insert(key, task);
    }
    /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
    /// and store it in this `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// # Panics
    ///
    /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
    /// [`join_next`]: Self::join_next
    #[track_caller]
    pub fn spawn_local<F>(&mut self, key: K, task: F)
    where
        F: Future<Output = V>,
        F: 'static,
    {
        // The future need not be `Send` since it runs on the local task set.
        let task = self.tasks.spawn_local(task);
        self.insert(key, task);
    }
    /// Spawn the provided task on the provided [`LocalSet`] and store it in
    /// this `JoinMap` with the provided key.
    ///
    /// If a task previously existed in the `JoinMap` for this key, that task
    /// will be cancelled and replaced with the new one. The previous task will
    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
    /// *not* return a cancelled [`JoinError`] for that task.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    /// [`join_next`]: Self::join_next
    #[track_caller]
    pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
    where
        F: Future<Output = V>,
        F: 'static,
    {
        // Same as `spawn_local`, but targets an explicit `LocalSet`.
        let task = self.tasks.spawn_local_on(task, local_set);
        self.insert(key, task)
    }
    /// Associates `key` with the newly spawned task's `AbortHandle`, aborting
    /// and removing any previous task stored under the same key.
    fn insert(&mut self, mut key: K, mut abort: AbortHandle) {
        let hash_builder = self.hashes_by_task.hasher();
        let hash = hash_builder.hash_one(&key);
        let id = abort.id();
        // Insert the new key into the map of tasks by keys.
        let entry =
            self.tasks_by_key
                .entry(hash, |(k, _)| *k == key, |(k, _)| hash_builder.hash_one(k));
        match entry {
            Entry::Occupied(occ) => {
                // There was a previous task spawned with the same key! Cancel
                // that task, and remove its ID from the map of hashes by task IDs.
                // After this swap, `key`/`abort` hold the *old* entry's values.
                (key, abort) = std::mem::replace(occ.into_mut(), (key, abort));
                // Remove the old task ID.
                let _prev_hash = self.hashes_by_task.remove(&abort.id());
                debug_assert_eq!(Some(hash), _prev_hash);
                // Associate the key's hash with the new task's ID, for looking up tasks by ID.
                let _prev = self.hashes_by_task.insert(id, hash);
                debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
                // Note: it's important to drop `key` and abort the task here.
                // This defends against any panics during drop handling for causing inconsistent state.
                abort.abort();
                drop(key);
            }
            Entry::Vacant(vac) => {
                vac.insert((key, abort));
                // Associate the key's hash with this task's ID, for looking up tasks by ID.
                let _prev = self.hashes_by_task.insert(id, hash);
                debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
            }
        };
    }
    /// Waits until one of the tasks in the map completes and returns its
    /// output, along with the key corresponding to that task.
    ///
    /// Returns `None` if the map is empty.
    ///
    /// # Cancel Safety
    ///
    /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
    /// statement and some other branch completes first, it is guaranteed that no tasks were
    /// removed from this `JoinMap`.
    ///
    /// # Returns
    ///
    /// This function returns:
    ///
    /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
    ///   completed. The `value` is the return value of that task, and `key` is
    ///   the key associated with the task.
    /// * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
    ///   panicked or been aborted. `key` is the key associated with the task
    ///   that panicked or was aborted.
    /// * `None` if the `JoinMap` is empty.
    ///
    /// [`tokio::select!`]: tokio::select
    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
        loop {
            let (res, id) = match self.tasks.join_next_with_id().await {
                Some(Ok((id, output))) => (Ok(output), id),
                Some(Err(e)) => {
                    let id = e.id();
                    (Err(e), id)
                }
                None => return None,
            };
            // If the ID is no longer tracked, the task was already removed
            // from the key maps (e.g. replaced by a later `spawn` on the same
            // key); skip its result and wait for the next completion.
            if let Some(key) = self.remove_by_id(id) {
                break Some((key, res));
            }
        }
    }
/// Aborts all tasks and waits for them to finish shutting down.
///
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
/// a loop until it returns `None`.
///
/// This method ignores any panics in the tasks shutting down. When this call returns, the
/// `JoinMap` will be empty.
///
/// [`abort_all`]: fn@Self::abort_all
/// [`join_next`]: fn@Self::join_next
pub async fn shutdown(&mut self) {
    self.abort_all();
    // Drain every (now cancelled) task, discarding outputs and errors.
    loop {
        if self.join_next().await.is_none() {
            break;
        }
    }
}
/// Abort the task corresponding to the provided `key`.
///
/// If this `JoinMap` contains a task corresponding to `key`, this method
/// will abort that task and return `true`. Otherwise, if no task exists for
/// `key`, this method returns `false`.
///
/// # Examples
///
/// Aborting a task by key:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// let mut map = JoinMap::new();
///
/// map.spawn("hello world", std::future::ready(1));
/// map.spawn("goodbye world", std::future::pending());
///
/// // Look up the "goodbye world" task in the map and abort it.
/// map.abort("goodbye world");
///
/// while let Some((key, res)) = map.join_next().await {
/// if key == "goodbye world" {
/// // The aborted task should complete with a cancelled `JoinError`.
/// assert!(res.unwrap_err().is_cancelled());
/// } else {
/// // Other tasks should complete normally.
/// assert_eq!(res.unwrap(), 1);
/// }
/// }
/// # }
/// ```
///
/// `abort` returns `true` if a task was aborted:
/// ```
/// use tokio_util::task::JoinMap;
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// let mut map = JoinMap::new();
///
/// map.spawn("hello world", async move { /* ... */ });
/// map.spawn("goodbye world", async move { /* ... */});
///
/// // A task for the key "goodbye world" should exist in the map:
/// assert!(map.abort("goodbye world"));
///
/// // Aborting a key that does not exist will return `false`:
/// assert!(!map.abort("goodbye universe"));
/// # }
/// ```
pub fn abort<Q>(&mut self, key: &Q) -> bool
where
    Q: ?Sized + Hash + Eq,
    K: Borrow<Q>,
{
    // Abort only if the key still maps to a tracked task.
    if let Some((_, handle)) = self.get_by_key(key) {
        handle.abort();
        true
    } else {
        false
    }
}
/// Aborts all tasks with keys matching `predicate`.
///
/// `predicate` is a function called with a reference to each key in the
/// map. If it returns `true` for a given key, the corresponding task will
/// be cancelled.
///
/// # Examples
/// ```
/// use tokio_util::task::JoinMap;
///
/// # // use the current thread rt so that spawned tasks don't
/// # // complete in the background before they can be aborted.
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// let mut map = JoinMap::new();
///
/// map.spawn("hello world", async move {
/// // ...
/// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
/// });
/// map.spawn("goodbye world", async move {
/// // ...
/// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
/// });
/// map.spawn("hello san francisco", async move {
/// // ...
/// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
/// });
/// map.spawn("goodbye universe", async move {
/// // ...
/// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
/// });
///
/// // Abort all tasks whose keys begin with "goodbye"
/// map.abort_matching(|key| key.starts_with("goodbye"));
///
/// let mut seen = 0;
/// while let Some((key, res)) = map.join_next().await {
/// seen += 1;
/// if key.starts_with("goodbye") {
/// // The aborted task should complete with a cancelled `JoinError`.
/// assert!(res.unwrap_err().is_cancelled());
/// } else {
/// // Other tasks should complete normally.
/// assert!(key.starts_with("hello"));
/// assert!(res.is_ok());
/// }
/// }
///
/// // All spawned tasks should have completed.
/// assert_eq!(seen, 4);
/// # }
/// ```
pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
    // Iterate without removing entries so that the keys of aborted tasks
    // are still returned by later calls to `join_next`.
    self.tasks_by_key
        .iter()
        .filter(|(key, _)| predicate(key))
        .for_each(|(_, task)| task.abort());
}
/// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
///
/// If a task has completed, but its output hasn't yet been consumed by a
/// call to [`join_next`], this method will still return its key.
///
/// [`join_next`]: fn@Self::join_next
pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
    // Borrowing iterator over the `(key, abort handle)` table; the value
    // type only rides along as a phantom parameter.
    JoinMapKeys {
        _value: PhantomData,
        iter: self.tasks_by_key.iter(),
    }
}
/// Returns `true` if this `JoinMap` contains a task for the provided key.
///
/// If the task has completed, but its output hasn't yet been consumed by a
/// call to [`join_next`], this method will still return `true`.
///
/// [`join_next`]: fn@Self::join_next
pub fn contains_key<Q>(&self, key: &Q) -> bool
where
    Q: ?Sized + Hash + Eq,
    K: Borrow<Q>,
{
    // Presence of an entry means a task was spawned for this key and its
    // output has not yet been consumed by `join_next`.
    self.get_by_key(key).is_some()
}
/// Returns `true` if this `JoinMap` contains a task with the provided
/// [task ID].
///
/// If the task has completed, but its output hasn't yet been consumed by a
/// call to [`join_next`], this method will still return `true`.
///
/// [`join_next`]: fn@Self::join_next
/// [task ID]: tokio::task::Id
pub fn contains_task(&self, task: &Id) -> bool {
    // Every tracked task has its ID registered in `hashes_by_task`.
    self.hashes_by_task.contains_key(task)
}
/// Reserves capacity for at least `additional` more tasks to be spawned
/// on this `JoinMap` without reallocating for the map of task keys. The
/// collection may reserve more space to avoid frequent reallocations.
///
/// Note that spawning a task will still cause an allocation for the task
/// itself.
///
/// # Panics
///
/// Panics if the new allocation size overflows [`usize`].
///
/// # Examples
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// let mut map: JoinMap<&str, i32> = JoinMap::new();
/// map.reserve(10);
/// ```
#[inline]
pub fn reserve(&mut self, additional: usize) {
    // Borrow the ID map separately so the rehash closure can use its
    // hasher while the task table is borrowed mutably.
    let hashes_by_task = &self.hashes_by_task;
    self.tasks_by_key
        .reserve(additional, |(key, _)| hashes_by_task.hasher().hash_one(key));
    self.hashes_by_task.reserve(additional);
}
/// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
/// down as much as possible while maintaining the internal rules
/// and possibly leaving some space in accordance with the resize policy.
///
/// # Examples
///
/// ```
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// use tokio_util::task::JoinMap;
///
/// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
/// map.spawn(1, async move { 2 });
/// map.spawn(3, async move { 4 });
/// assert!(map.capacity() >= 100);
/// map.shrink_to_fit();
/// assert!(map.capacity() >= 2);
/// # }
/// ```
#[inline]
pub fn shrink_to_fit(&mut self) {
    // Shrink the ID map first, then rehash/shrink the task table using the
    // same hasher that placed the entries.
    self.hashes_by_task.shrink_to_fit();
    let hashes_by_task = &self.hashes_by_task;
    self.tasks_by_key
        .shrink_to_fit(|(key, _)| hashes_by_task.hasher().hash_one(key));
}
/// Shrinks the capacity of the map with a lower limit. It will drop
/// down no lower than the supplied limit while maintaining the internal rules
/// and possibly leaving some space in accordance with the resize policy.
///
/// If the current capacity is less than the lower limit, this is a no-op.
///
/// # Examples
///
/// ```
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// use tokio_util::task::JoinMap;
///
/// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
/// map.spawn(1, async move { 2 });
/// map.spawn(3, async move { 4 });
/// assert!(map.capacity() >= 100);
/// map.shrink_to(10);
/// assert!(map.capacity() >= 10);
/// map.shrink_to(0);
/// assert!(map.capacity() >= 2);
/// # }
/// ```
#[inline]
pub fn shrink_to(&mut self, min_capacity: usize) {
    // Same pattern as `shrink_to_fit`, but bounded below by `min_capacity`.
    self.hashes_by_task.shrink_to(min_capacity);
    let hashes_by_task = &self.hashes_by_task;
    self.tasks_by_key
        .shrink_to(min_capacity, |(key, _)| hashes_by_task.hasher().hash_one(key));
}
/// Looks up a task by key, returning its `(key, abort handle)` entry if present.
fn get_by_key<'map, Q>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)>
where
    Q: ?Sized + Hash + Eq,
    K: Borrow<Q>,
{
    // Hash with the same hasher that was used when the entry was inserted.
    let hash = self.hashes_by_task.hasher().hash_one(key);
    self.tasks_by_key
        .find(hash, |(stored, _)| stored.borrow() == key)
}
/// Removes the task with the given ID, returning its key if it was tracked.
fn remove_by_id(&mut self, id: Id) -> Option<K> {
    // Forget the ID; this also recovers the hash of the task's key.
    let hash = self.hashes_by_task.remove(&id)?;
    // Remove the matching entry from the task table, keeping only the key.
    let entry = self
        .tasks_by_key
        .find_entry(hash, |(_, abort)| abort.id() == id)
        .ok()?;
    let ((key, _abort), _vacant) = entry.remove();
    Some(key)
}
}
impl<K, V, S> JoinMap<K, V, S>
where
    V: 'static,
{
    /// Aborts all tasks on this `JoinMap`.
    ///
    /// The tasks are not removed from the `JoinMap`; to wait for them to
    /// finish cancelling, call `join_next` in a loop until it returns `None`.
    pub fn abort_all(&mut self) {
        self.tasks.abort_all();
    }

    /// Removes all tasks from this `JoinMap` without aborting them.
    ///
    /// The removed tasks keep running in the background even if the `JoinMap`
    /// is dropped. Because the key and ID maps are cleared here, they can no
    /// longer be aborted by key through this `JoinMap`.
    pub fn detach_all(&mut self) {
        self.tasks.detach_all();
        self.tasks_by_key.clear();
        self.hashes_by_task.clear();
    }
}
// Manual `fmt::Debug` impl so that `V: Debug` is not required; the map never
// actually stores a value of type `V`.
impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Helper that renders the task table as `key => task-id` pairs,
        // which reads better than dumping the tuple struct or the
        // `AbortHandle` (whose Debug would repeat the ID anyway).
        struct Entries<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>);
        impl<K: fmt::Debug> fmt::Debug for Entries<'_, K> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let pairs = self.0.iter().map(|(key, abort)| (key, abort.id()));
                f.debug_map().entries(pairs).finish()
            }
        }
        // Only `tasks_by_key` carries information worth showing (keys and
        // task IDs); the other fields are implementation details.
        f.debug_struct("JoinMap")
            .field("tasks", &Entries(&self.tasks_by_key))
            .finish()
    }
}
impl<K, V> Default for JoinMap<K, V> {
fn default() -> Self {
Self::new()
}
}
/// An iterator over the keys of a [`JoinMap`].
#[derive(Debug, Clone)]
pub struct JoinMapKeys<'a, K, V> {
    // Borrowing iterator over the underlying `(key, abort handle)` table.
    iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>,
    /// To make it easier to change `JoinMap` in the future, keep V as a generic
    /// parameter.
    _value: PhantomData<&'a V>,
}
impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
    type Item = &'a K;

    fn next(&mut self) -> Option<&'a K> {
        // Project each `(key, abort handle)` entry down to its key.
        let (key, _) = self.iter.next()?;
        Some(key)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The underlying table iterator knows its exact remaining length.
        self.iter.size_hint()
    }
}
impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
    fn len(&self) -> usize {
        // Exactly as many keys remain as entries in the table iterator.
        self.iter.len()
    }
}
// The underlying table iterator stays exhausted once it yields `None`, so the
// key iterator is fused as well.
impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}

425
vendor/tokio-util/src/task/join_queue.rs vendored Normal file
View File

@@ -0,0 +1,425 @@
use super::AbortOnDropHandle;
use std::{
collections::VecDeque,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
runtime::Handle,
task::{AbortHandle, Id, JoinError, JoinHandle},
};
/// A FIFO queue of tasks spawned on a Tokio runtime.
///
/// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO
/// order. That is, if tasks are spawned in the order A, B, C, then
/// awaiting the next completed task will always return A first, then B,
/// then C, regardless of the order in which the tasks actually complete.
///
/// All of the tasks must have the same return type `T`.
///
/// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are
/// immediately aborted.
// Invariant: handles are kept in spawn order, so the front of the deque is
// always the oldest task that has not yet been joined.
pub struct JoinQueue<T>(VecDeque<AbortOnDropHandle<T>>);
impl<T> JoinQueue<T> {
/// Create a new empty [`JoinQueue`].
pub const fn new() -> Self {
    // `VecDeque::new` is const, which makes this constructor const too.
    Self(VecDeque::new())
}
/// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks.
pub fn with_capacity(capacity: usize) -> Self {
    // Pre-size the deque to avoid regrowth while tasks are enqueued.
    Self(VecDeque::with_capacity(capacity))
}
/// Returns the number of tasks currently in the [`JoinQueue`].
///
/// This includes both tasks that are currently running and tasks that have
/// completed but not yet been removed from the queue because outputting of
/// them waits for FIFO order.
pub fn len(&self) -> usize {
    // Completed-but-unjoined tasks are still stored, so they count too.
    self.0.len()
}
/// Returns whether the [`JoinQueue`] is empty.
pub fn is_empty(&self) -> bool {
    // Empty means no running tasks and no pending outputs.
    self.0.is_empty()
}
/// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`]
/// that can be used to remotely cancel the task.
///
/// The provided future will start running in the background immediately
/// when this method is called, even if you don't await anything on this
/// [`JoinQueue`].
///
/// # Panics
///
/// This method panics if called outside of a Tokio runtime.
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    // Spawn on the current runtime and enqueue the handle at the back.
    self.push_back(tokio::spawn(task))
}
/// Spawn the provided task on the provided runtime and store it in this
/// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely
/// cancel the task.
///
/// The provided future will start running in the background immediately
/// when this method is called, even if you don't await anything on this
/// [`JoinQueue`].
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    // Spawn on the explicitly provided runtime handle, then enqueue.
    self.push_back(handle.spawn(task))
}
/// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
/// and store it in this [`JoinQueue`], returning an [`AbortHandle`] that
/// can be used to remotely cancel the task.
///
/// The provided future will start running in the background immediately
/// when this method is called, even if you don't await anything on this
/// [`JoinQueue`].
///
/// # Panics
///
/// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
///
/// [`LocalSet`]: tokio::task::LocalSet
/// [`LocalRuntime`]: tokio::runtime::LocalRuntime
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
where
    F: Future<Output = T> + 'static,
    T: 'static,
{
    // No `Send` bound: the future stays on the current `LocalSet`/`LocalRuntime`.
    self.push_back(tokio::task::spawn_local(task))
}
/// Spawn the blocking code on the blocking threadpool and store
/// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be
/// used to remotely cancel the task.
///
/// # Panics
///
/// This method panics if called outside of a Tokio runtime.
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // Run `f` on the current runtime's blocking thread pool, then enqueue.
    self.push_back(tokio::task::spawn_blocking(f))
}
/// Spawn the blocking code on the blocking threadpool of the
/// provided runtime and store it in this [`JoinQueue`], returning an
/// [`AbortHandle`] that can be used to remotely cancel the task.
///
/// [`AbortHandle`]: tokio::task::AbortHandle
#[track_caller]
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // Run `f` on the provided runtime's blocking thread pool, then enqueue.
    self.push_back(handle.spawn_blocking(f))
}
/// Wraps `jh` so it is aborted on drop, enqueues it at the back, and returns
/// a handle that can abort the task remotely.
fn push_back(&mut self, jh: JoinHandle<T>) -> AbortHandle {
    let handle = AbortOnDropHandle::new(jh);
    let abort = handle.abort_handle();
    self.0.push_back(handle);
    abort
}
/// Waits until the next task in FIFO order completes and returns its output.
///
/// Returns `None` if the queue is empty.
///
/// # Cancel Safety
///
/// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
/// statement and some other branch completes first, it is guaranteed that no tasks were
/// removed from this [`JoinQueue`].
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
    // Adapt the poll-based API into a future.
    std::future::poll_fn(|cx| self.poll_join_next(cx)).await
}
/// Waits until the next task in FIFO order completes and returns its output,
/// along with the [task ID] of the completed task.
///
/// Returns `None` if the queue is empty.
///
/// When this method returns an error, then the id of the task that failed can be accessed
/// using the [`JoinError::id`] method.
///
/// # Cancel Safety
///
/// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
/// statement and some other branch completes first, it is guaranteed that no tasks were
/// removed from this [`JoinQueue`].
///
/// [task ID]: tokio::task::Id
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
    // Adapt the poll-based API into a future.
    std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
}
/// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
///
/// Note that on success the handle will panic on subsequent polls
/// since it becomes consumed.
fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
    // Poll once with a no-op waker: we only want the current readiness and
    // never a wakeup.
    let waker = futures_util::task::noop_waker();
    let mut cx = Context::from_waker(&waker);
    // This function is synchronous and can never yield, so disable
    // cooperative budgeting while checking the handle's readiness.
    let pinned = std::pin::pin!(tokio::task::coop::unconstrained(jh));
    match pinned.poll(&mut cx) {
        Poll::Ready(res) => Some(res),
        Poll::Pending => None,
    }
}
/// Tries to join the next task in FIFO order if it has completed.
///
/// Returns `None` if the queue is empty or if the next task is not yet ready.
pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
    let front = self.0.front_mut()?;
    let res = Self::try_poll_handle(front)?;
    // The task already finished: `detach` the handle before dropping it so
    // that `AbortOnDropHandle`'s drop does not issue a useless abort.
    let _ = self.0.pop_front().unwrap().detach();
    Some(res)
}
/// Tries to join the next task in FIFO order if it has completed and return its output,
/// along with its [task ID].
///
/// Returns `None` if the queue is empty or if the next task is not yet ready.
///
/// When this method returns an error, then the id of the task that failed can be accessed
/// using the [`JoinError::id`] method.
///
/// [task ID]: tokio::task::Id
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
    let front = self.0.front_mut()?;
    let res = Self::try_poll_handle(front)?;
    // Detach before dropping so the completed task is not pointlessly
    // aborted; read its ID first for the success case.
    let detached = self.0.pop_front().unwrap().detach();
    let id = detached.id();
    drop(detached);
    Some(res.map(|output| (id, output)))
}
/// Aborts all tasks and waits for them to finish shutting down.
///
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
/// a loop until it returns `None`.
///
/// This method ignores any panics in the tasks shutting down. When this call returns, the
/// [`JoinQueue`] will be empty.
///
/// [`abort_all`]: fn@Self::abort_all
/// [`join_next`]: fn@Self::join_next
pub async fn shutdown(&mut self) {
    self.abort_all();
    // Drain the queue, ignoring outputs and cancellation errors.
    loop {
        if self.join_next().await.is_none() {
            break;
        }
    }
}
/// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results.
///
/// The results will be stored in the order they were spawned, not the order they completed.
/// This is a convenience method that is equivalent to calling [`join_next`] in
/// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call
/// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are
/// cancelled. To handle errors in any other way, manually call [`join_next`]
/// in a loop.
///
/// # Cancel Safety
///
/// This method is not cancel safe as it calls `join_next` in a loop. If you need
/// cancel safety, manually call `join_next` in a loop with `Vec` accumulator.
///
/// [`join_next`]: fn@Self::join_next
/// [`JoinError::id`]: fn@tokio::task::JoinError::id
pub async fn join_all(mut self) -> Vec<T> {
    let mut results = Vec::with_capacity(self.len());
    while let Some(next) = self.join_next().await {
        match next {
            Ok(value) => results.push(value),
            // Re-raise task panics on the caller's thread.
            Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
            // Any other failure (e.g. cancellation) is surfaced as a panic.
            Err(err) => panic!("{err}"),
        }
    }
    results
}
/// Aborts all tasks on this [`JoinQueue`].
///
/// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete
/// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty.
pub fn abort_all(&mut self) {
    // Request cancellation of every queued task; the entries stay in the
    // queue so their outcomes can still be observed via `join_next`.
    for jh in &self.0 {
        jh.abort();
    }
}
/// Removes all tasks from this [`JoinQueue`] without aborting them.
///
/// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`]
/// is dropped.
pub fn detach_all(&mut self) {
    // Consume every handle, detaching so the tasks keep running after the
    // queue forgets them.
    for jh in self.0.drain(..) {
        let _ = jh.detach();
    }
}
/// Polls for the next task in [`JoinQueue`] to complete.
///
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
/// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
///
/// # Returns
///
/// This function returns:
///
/// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
/// available right now.
/// * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed.
/// The `value` is the return value of that task.
/// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
/// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
    // Only the front task may be joined: outputs are yielded in FIFO order.
    let Some(jh) = self.0.front_mut() else {
        return Poll::Ready(None);
    };
    match Pin::new(jh).poll(cx) {
        Poll::Ready(res) => {
            // Detach before dropping so that a task which has already
            // finished is not pointlessly aborted by `AbortOnDropHandle`.
            let _ = self.0.pop_front().unwrap().detach();
            Poll::Ready(Some(res))
        }
        Poll::Pending => Poll::Pending,
    }
}
/// Polls for the next task in [`JoinQueue`] to complete.
///
/// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
/// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
/// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
///
/// # Returns
///
/// This function returns:
///
/// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
/// available right now.
/// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed.
/// The `value` is the return value of that task, and `id` is its [task ID].
/// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
/// aborted. The `err` is the `JoinError` from the panicked/aborted task.
/// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
///
/// [task ID]: tokio::task::Id
pub fn poll_join_next_with_id(
    &mut self,
    cx: &mut Context<'_>,
) -> Poll<Option<Result<(Id, T), JoinError>>> {
    // Only the front task may be joined: outputs are yielded in FIFO order.
    let Some(jh) = self.0.front_mut() else {
        return Poll::Ready(None);
    };
    match Pin::new(jh).poll(cx) {
        Poll::Ready(res) => {
            // Detach so the finished task is not aborted on drop, and read
            // its ID for the success case; on failure the `JoinError`
            // already carries the task's ID.
            let detached = self.0.pop_front().unwrap().detach();
            let id = detached.id();
            drop(detached);
            Poll::Ready(Some(res.map(|output| (id, output))))
        }
        Poll::Pending => Poll::Pending,
    }
}
}
impl<T> std::fmt::Debug for JoinQueue<T> {
    // Render only the queued task IDs so that `T: Debug` is not required.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ids = self.0.iter().map(|jh| JoinHandle::id(jh.as_ref()));
        f.debug_list().entries(ids).finish()
    }
}
impl<T> Default for JoinQueue<T> {
fn default() -> Self {
Self::new()
}
}
/// Collect an iterator of futures into a [`JoinQueue`].
///
/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator.
impl<T, F> std::iter::FromIterator<F> for JoinQueue<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
        let mut queue = Self::new();
        // Each future starts running as soon as it is spawned.
        for task in iter {
            queue.spawn(task);
        }
        queue
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// A simple type that does not implement [`std::fmt::Debug`].
    struct NotDebug;
    // Compile-time probe: this call only type-checks when `T: Debug`.
    fn is_debug<T: std::fmt::Debug>() {}
    #[test]
    fn assert_debug() {
        // `JoinQueue<T>` must be `Debug` even when `T` is not.
        is_debug::<JoinQueue<NotDebug>>();
    }
}

25
vendor/tokio-util/src/task/mod.rs vendored Normal file
View File

@@ -0,0 +1,25 @@
//! Extra utilities for spawning tasks
//!
//! This module is only available when the `rt` feature is enabled. Note that enabling the
//! `join-map` feature will automatically also enable the `rt` feature.
cfg_rt! {
mod spawn_pinned;
pub use spawn_pinned::LocalPoolHandle;
pub mod task_tracker;
#[doc(inline)]
pub use task_tracker::TaskTracker;
mod abort_on_drop;
pub use abort_on_drop::AbortOnDropHandle;
mod join_queue;
pub use join_queue::JoinQueue;
}
#[cfg(feature = "join-map")]
mod join_map;
#[cfg(feature = "join-map")]
#[cfg_attr(docsrs, doc(cfg(feature = "join-map")))]
pub use join_map::{JoinMap, JoinMapKeys};

View File

@@ -0,0 +1,445 @@
use futures_util::future::{AbortHandle, Abortable};
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::Builder;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};
/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use std::rc::Rc;
/// use tokio::task;
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
/// let pool = LocalPoolHandle::new(5);
///
/// let output = pool.spawn_pinned(|| {
/// // `data` is !Send + !Sync
/// let data = Rc::new("local data");
/// let data_clone = data.clone();
///
/// async move {
/// task::spawn_local(async move {
/// println!("{}", data_clone);
/// });
///
/// data.to_string()
/// }
/// }).await.unwrap();
/// println!("output: {}", output);
/// }
/// # }
/// ```
///
#[derive(Clone)]
pub struct LocalPoolHandle {
    // Shared worker pool; cloning the handle only bumps the `Arc` refcount.
    pool: Arc<LocalPool>,
}
impl LocalPoolHandle {
/// Create a new pool of threads to handle `!Send` tasks. Spawn tasks onto this
/// pool via [`LocalPoolHandle::spawn_pinned`].
///
/// # Panics
///
/// Panics if the pool size is less than one.
#[track_caller]
pub fn new(pool_size: usize) -> LocalPoolHandle {
assert!(pool_size > 0);
let workers = (0..pool_size)
.map(|_| LocalWorkerHandle::new_worker())
.collect();
let pool = Arc::new(LocalPool { workers });
LocalPoolHandle { pool }
}
/// Returns the number of threads of the Pool.
#[inline]
pub fn num_threads(&self) -> usize {
    // One worker per pool slot, fixed at construction time.
    self.pool.workers.len()
}
/// Returns the number of tasks scheduled on each worker. The indices of the
/// worker threads correspond to the indices of the returned `Vec`.
pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
    // Snapshot each worker's scheduled-task counter, in worker order.
    self.pool
        .workers
        .iter()
        .map(|worker| worker.task_count.load(Ordering::SeqCst))
        .collect()
}
/// Spawn a task onto a worker thread and pin it there so it can't be moved
/// off of the thread. Note that the future is not [`Send`], but the
/// [`FnOnce`] which creates it is.
///
/// # Examples
/// ```
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use std::rc::Rc;
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main]
/// async fn main() {
/// // Create the local pool
/// let pool = LocalPoolHandle::new(1);
///
/// // Spawn a !Send future onto the pool and await it
/// let output = pool
/// .spawn_pinned(|| {
/// // Rc is !Send + !Sync
/// let local_data = Rc::new("test");
///
/// // This future holds an Rc, so it is !Send
/// async move { local_data.to_string() }
/// })
/// .await
/// .unwrap();
///
/// assert_eq!(output, "test");
/// }
/// # }
/// ```
pub fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
where
    F: FnOnce() -> Fut,
    F: Send + 'static,
    Fut: Future + 'static,
    Fut::Output: Send + 'static,
{
    // Pick the worker with the fewest scheduled tasks.
    self.pool
        .spawn_pinned(create_task, WorkerChoice::LeastBurdened)
}
/// Differs from `spawn_pinned` only in that you can choose a specific worker thread
/// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
/// number of tasks scheduled.
///
/// A worker thread is chosen by index. Indices are 0 based and the largest index
/// is given by `num_threads() - 1`
///
/// # Panics
///
/// This method panics if the index is out of bounds.
///
/// # Examples
///
/// This method can be used to spawn a task on all worker threads of the pool:
///
/// ```
/// # #[cfg(not(target_family = "wasm"))]
/// # {
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main]
/// async fn main() {
/// const NUM_WORKERS: usize = 3;
/// let pool = LocalPoolHandle::new(NUM_WORKERS);
/// let handles = (0..pool.num_threads())
/// .map(|worker_idx| {
/// pool.spawn_pinned_by_idx(
/// || {
/// async {
/// "test"
/// }
/// },
/// worker_idx,
/// )
/// })
/// .collect::<Vec<_>>();
///
/// for handle in handles {
/// handle.await.unwrap();
/// }
/// }
/// # }
/// ```
///
#[track_caller]
pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
where
    F: FnOnce() -> Fut,
    F: Send + 'static,
    Fut: Future + 'static,
    Fut::Output: Send + 'static,
{
    // Target the worker selected by the caller-supplied index.
    self.pool
        .spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
}
}
impl Debug for LocalPoolHandle {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // The pool internals are not useful to show; print only the type name.
        write!(f, "LocalPoolHandle")
    }
}
// How `LocalPool::spawn_pinned` selects the worker thread for a task.
enum WorkerChoice {
    // Pick the worker with the fewest scheduled tasks.
    LeastBurdened,
    // Pick the worker at this index (0-based); out-of-bounds panics.
    ByIdx(usize),
}
struct LocalPool {
    // Fixed set of workers, sized once at construction.
    workers: Box<[LocalWorkerHandle]>,
}
impl LocalPool {
    /// Spawn a `?Send` future onto a worker
    ///
    /// The returned `JoinHandle` resolves to the task's output. Panics and
    /// aborts of the inner task are forwarded through the handle.
    #[track_caller]
    fn spawn_pinned<F, Fut>(
        &self,
        create_task: F,
        worker_choice: WorkerChoice,
    ) -> JoinHandle<Fut::Output>
    where
        F: FnOnce() -> Fut,
        F: Send + 'static,
        Fut: Future + 'static,
        Fut::Output: Send + 'static,
    {
        // Channel used by the worker's LocalSet task to hand back the
        // JoinHandle of the actual pinned task.
        let (sender, receiver) = oneshot::channel();
        let (worker, job_guard) = match worker_choice {
            WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
            WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
        };
        let worker_spawner = worker.spawner.clone();
        // Spawn a future onto the worker's runtime so we can immediately return
        // a join handle.
        worker.runtime_handle.spawn(async move {
            // Move the job guard into the task
            let _job_guard = job_guard;
            // Propagate aborts via Abortable/AbortHandle
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let _abort_guard = AbortGuard(abort_handle);
            // Inside the future we can't run spawn_local yet because we're not
            // in the context of a LocalSet. We need to send create_task to the
            // LocalSet task for spawning.
            let spawn_task = Box::new(move || {
                // Once we're in the LocalSet context we can call spawn_local
                let join_handle =
                    spawn_local(
                        async move { Abortable::new(create_task(), abort_registration).await },
                    );
                // Send the join handle back to the spawner. If sending fails,
                // we assume the parent task was canceled, so cancel this task
                // as well.
                if let Err(join_handle) = sender.send(join_handle) {
                    join_handle.abort()
                }
            });
            // Send the callback to the LocalSet task
            if let Err(e) = worker_spawner.send(spawn_task) {
                // Propagate the error as a panic in the join handle.
                panic!("Failed to send job to worker: {e}");
            }
            // Wait for the task's join handle
            let join_handle = match receiver.await {
                Ok(handle) => handle,
                Err(e) => {
                    // We sent the task successfully, but failed to get its
                    // join handle... We assume something happened to the worker
                    // and the task was not spawned. Propagate the error as a
                    // panic in the join handle.
                    panic!("Worker failed to send join handle: {e}");
                }
            };
            // Wait for the task to complete
            let join_result = join_handle.await;
            match join_result {
                Ok(Ok(output)) => output,
                Ok(Err(_)) => {
                    // Pinned task was aborted. But that only happens if this
                    // task is aborted. So this is an impossible branch.
                    unreachable!(
                        "Reaching this branch means this task was previously \
                         aborted but it continued running anyways"
                    )
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else if e.is_cancelled() {
                        // No one else should have the join handle, so this is
                        // unexpected. Forward this error as a panic in the join
                        // handle.
                        panic!("spawn_pinned task was canceled: {e}");
                    } else {
                        // Something unknown happened (not a panic or
                        // cancellation). Forward this error as a panic in the
                        // join handle.
                        panic!("spawn_pinned task failed: {e}");
                    }
                }
            }
        })
    }
    /// Find the worker with the least number of tasks, increment its task
    /// count, and return its handle. Make sure to actually spawn a task on
    /// the worker so the task count is kept consistent with load.
    ///
    /// A job count guard is also returned to ensure the task count gets
    /// decremented when the job is done.
    fn find_and_incr_least_burdened_worker(&self) -> (&LocalWorkerHandle, JobCountGuard) {
        // CAS loop: pick the minimum, then try to reserve a slot on it; if
        // another thread raced us and changed that worker's count, retry.
        loop {
            let (worker, task_count) = self
                .workers
                .iter()
                .map(|worker| (worker, worker.task_count.load(Ordering::SeqCst)))
                .min_by_key(|&(_, count)| count)
                // Requires a non-empty worker list, which `new` guarantees.
                .expect("There must be more than one worker")​;
            // Make sure the task count hasn't changed since when we choose this
            // worker. Otherwise, restart the search.
            if worker
                .task_count
                .compare_exchange(
                    task_count,
                    task_count + 1,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
            }
        }
    }
    /// Reserve a task slot on the worker at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds (via the slice index).
    #[track_caller]
    fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
        let worker = &self.workers[idx];
        worker.task_count.fetch_add(1, Ordering::SeqCst);
        (worker, JobCountGuard(Arc::clone(&worker.task_count)))
    }
}
/// Automatically decrements a worker's job count when a job finishes (when
/// this gets dropped).
struct JobCountGuard(Arc<AtomicUsize>);
impl Drop for JobCountGuard {
    fn drop(&mut self) {
        // Decrement the job count
        let previous_value = self.0.fetch_sub(1, Ordering::SeqCst);
        // The count must have been at least 1, since this guard represents a
        // reserved slot; going below zero would wrap the unsigned counter.
        debug_assert!(previous_value >= 1);
    }
}
/// Calls abort on the handle when dropped.
///
/// Ties the lifetime of the inner pinned task to the outer supervising task:
/// if the outer task is dropped/aborted, the pinned task is aborted too.
struct AbortGuard(AbortHandle);
impl Drop for AbortGuard {
    fn drop(&mut self) {
        self.0.abort();
    }
}
/// A boxed callback that, once inside the worker's `LocalSet`, performs the
/// actual `spawn_local` call.
type PinnedFutureSpawner = Box<dyn FnOnce() + Send + 'static>;
/// Per-worker-thread state held by the pool.
struct LocalWorkerHandle {
    // Handle to the worker's single-threaded runtime; used to spawn the
    // supervising task from any thread.
    runtime_handle: tokio::runtime::Handle,
    // Channel into the worker's LocalSet loop; see `LocalWorkerHandle::run`.
    spawner: UnboundedSender<PinnedFutureSpawner>,
    // Number of tasks currently assigned to this worker (load metric).
    task_count: Arc<AtomicUsize>,
}
impl LocalWorkerHandle {
    /// Create a new worker for executing pinned tasks
    ///
    /// Builds a dedicated current-thread runtime, then moves it (together with
    /// the receiving end of the spawn channel) onto a fresh OS thread.
    fn new_worker() -> LocalWorkerHandle {
        let (sender, receiver) = unbounded_channel();
        let runtime = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to start a pinned worker thread runtime");
        let runtime_handle = runtime.handle().clone();
        let task_count = Arc::new(AtomicUsize::new(0));
        let task_count_clone = Arc::clone(&task_count);
        std::thread::spawn(|| Self::run(runtime, receiver, task_count_clone));
        LocalWorkerHandle {
            runtime_handle,
            spawner: sender,
            task_count,
        }
    }
    /// Worker thread main loop: drives a `LocalSet`, executing each received
    /// spawn callback, until all senders are dropped; then drains remaining
    /// runtime tasks before shutting down.
    fn run(
        runtime: tokio::runtime::Runtime,
        mut task_receiver: UnboundedReceiver<PinnedFutureSpawner>,
        task_count: Arc<AtomicUsize>,
    ) {
        let local_set = LocalSet::new();
        local_set.block_on(&runtime, async {
            while let Some(spawn_task) = task_receiver.recv().await {
                // Calls spawn_local(future)
                (spawn_task)();
            }
        });
        // If there are any tasks on the runtime associated with a LocalSet task
        // that has already completed, but whose output has not yet been
        // reported, let that task complete.
        //
        // Since the task_count is decremented when the runtime task exits,
        // reading that counter lets us know if any such tasks completed during
        // the call to `block_on`.
        //
        // Tasks on the LocalSet can't complete during this loop since they're
        // stored on the LocalSet and we aren't accessing it.
        let mut previous_task_count = task_count.load(Ordering::SeqCst);
        loop {
            // This call will also run tasks spawned on the runtime.
            runtime.block_on(tokio::task::yield_now());
            let new_task_count = task_count.load(Ordering::SeqCst);
            // Fixed point reached: no task finished during this iteration.
            if new_task_count == previous_task_count {
                break;
            } else {
                previous_task_count = new_task_count;
            }
        }
        // It's now no longer possible for a task on the runtime to be
        // associated with a LocalSet task that has completed. Drop both the
        // LocalSet and runtime to let tasks on the runtime be cancelled if and
        // only if they are still on the LocalSet.
        //
        // Drop the LocalSet task first so that anyone awaiting the runtime
        // JoinHandle will see the cancelled error after the LocalSet task
        // destructor has completed.
        drop(local_set);
        drop(runtime);
    }
}

View File

@@ -0,0 +1,730 @@
//! Types related to the [`TaskTracker`] collection.
//!
//! See the documentation of [`TaskTracker`] for more information.
use pin_project_lite::pin_project;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{futures::Notified, Notify};
#[cfg(feature = "rt")]
use tokio::{
runtime::Handle,
task::{JoinHandle, LocalSet},
};
/// A task tracker used for waiting until tasks exit.
///
/// This is usually used together with [`CancellationToken`] to implement [graceful shutdown]. The
/// `CancellationToken` is used to signal to tasks that they should shut down, and the
/// `TaskTracker` is used to wait for them to finish shutting down.
///
/// The `TaskTracker` will also keep track of a `closed` boolean. This is used to handle the case
/// where the `TaskTracker` is empty, but we don't want to shut down yet. This means that the
/// [`wait`] method will wait until *both* of the following happen at the same time:
///
/// * The `TaskTracker` must be closed using the [`close`] method.
/// * The `TaskTracker` must be empty, that is, all tasks that it is tracking must have exited.
///
/// When a call to [`wait`] returns, it is guaranteed that all tracked tasks have exited and that
/// the destructor of the future has finished running. However, there might be a short amount of
/// time where [`JoinHandle::is_finished`] returns false.
///
/// # Comparison to `JoinSet`
///
/// The main Tokio crate has a similar collection known as [`JoinSet`]. The `JoinSet` type has a
/// lot more features than `TaskTracker`, so `TaskTracker` should only be used when one of its
/// unique features is required:
///
/// 1. When tasks exit, a `TaskTracker` will allow the task to immediately free its memory.
/// 2. By not closing the `TaskTracker`, [`wait`] will be prevented from returning even if
/// the `TaskTracker` is empty.
/// 3. A `TaskTracker` does not require mutable access to insert tasks.
/// 4. A `TaskTracker` can be cloned to share it with many tasks.
///
/// The first point is the most important one. A [`JoinSet`] keeps track of the return value of
/// every inserted task. This means that if the caller keeps inserting tasks and never calls
/// [`join_next`], then their return values will keep building up and consuming memory, _even if_
/// most of the tasks have already exited. This can cause the process to run out of memory. With a
/// `TaskTracker`, this does not happen. Once tasks exit, they are immediately removed from the
/// `TaskTracker`.
///
/// Note that unlike [`JoinSet`], dropping a `TaskTracker` does not abort the tasks.
///
/// # Examples
///
/// For more examples, please see the topic page on [graceful shutdown].
///
/// ## Spawn tasks and wait for them to exit
///
/// This is a simple example. For this case, [`JoinSet`] should probably be used instead.
///
/// ```
/// use tokio_util::task::TaskTracker;
///
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// let tracker = TaskTracker::new();
///
/// for i in 0..10 {
/// tracker.spawn(async move {
/// println!("Task {} is running!", i);
/// });
/// }
/// // Once we spawned everything, we close the tracker.
/// tracker.close();
///
/// // Wait for everything to finish.
/// tracker.wait().await;
///
/// println!("This is printed after all of the tasks.");
/// # }
/// ```
///
/// ## Wait for tasks to exit
///
/// This example shows the intended use-case of `TaskTracker`. It is used together with
/// [`CancellationToken`] to implement graceful shutdown.
/// ```
/// use tokio_util::sync::CancellationToken;
/// use tokio_util::task::TaskTracker;
/// use tokio_util::time::FutureExt;
///
/// use tokio::time::{self, Duration};
///
/// async fn background_task(num: u64) {
/// for i in 0..10 {
/// time::sleep(Duration::from_millis(100*num)).await;
/// println!("Background task {} in iteration {}.", num, i);
/// }
/// }
///
/// #[tokio::main]
/// # async fn _hidden() {}
/// # #[tokio::main(flavor = "current_thread", start_paused = true)]
/// async fn main() {
/// let tracker = TaskTracker::new();
/// let token = CancellationToken::new();
///
/// for i in 0..10 {
/// let token = token.clone();
/// tracker.spawn(async move {
/// // Use a `with_cancellation_token_owned` to kill the background task
/// // if the token is cancelled.
/// match background_task(i)
/// .with_cancellation_token_owned(token)
/// .await
/// {
/// Some(()) => println!("Task {} exiting normally.", i),
/// None => {
/// // Do some cleanup before we really exit.
/// time::sleep(Duration::from_millis(50)).await;
/// println!("Task {} finished cleanup.", i);
/// }
/// }
/// });
/// }
///
/// // Spawn a background task that will send the shutdown signal.
/// {
/// let tracker = tracker.clone();
/// tokio::spawn(async move {
/// // Normally you would use something like ctrl-c instead of
/// // sleeping.
/// time::sleep(Duration::from_secs(2)).await;
/// tracker.close();
/// token.cancel();
/// });
/// }
///
/// // Wait for all tasks to exit.
/// tracker.wait().await;
///
/// println!("All tasks have exited now.");
/// }
/// ```
///
/// [`CancellationToken`]: crate::sync::CancellationToken
/// [`JoinHandle::is_finished`]: tokio::task::JoinHandle::is_finished
/// [`JoinSet`]: tokio::task::JoinSet
/// [`close`]: Self::close
/// [`join_next`]: tokio::task::JoinSet::join_next
/// [`wait`]: Self::wait
/// [graceful shutdown]: https://tokio.rs/tokio/topics/shutdown
pub struct TaskTracker {
    /// Shared state; every clone of this tracker points at the same `Arc`.
    inner: Arc<TaskTrackerInner>,
}
/// Represents a task tracked by a [`TaskTracker`].
#[must_use]
#[derive(Debug)]
pub struct TaskTrackerToken {
    /// The tracker this token counts toward. Creating the token incremented
    /// the task count; dropping it decrements the count again.
    task_tracker: TaskTracker,
}
struct TaskTrackerInner {
    /// Keeps track of the state.
    ///
    /// The lowest bit is whether the task tracker is closed.
    ///
    /// The rest of the bits count the number of tracked tasks.
    state: AtomicUsize,
    /// Used to notify when the last task exits.
    on_last_exit: Notify,
}
pin_project! {
    /// A future that is tracked as a task by a [`TaskTracker`].
    ///
    /// The associated [`TaskTracker`] cannot complete until this future is dropped.
    ///
    /// This future is returned by [`TaskTracker::track_future`].
    #[must_use = "futures do nothing unless polled"]
    pub struct TrackedFuture<F> {
        // The wrapped future; polling is forwarded to it unchanged.
        #[pin]
        future: F,
        // Dropped together with this future, which is what unregisters the
        // task from the tracker.
        token: TaskTrackerToken,
    }
}
pin_project! {
    /// A future that completes when the [`TaskTracker`] is empty and closed.
    ///
    /// This future is returned by [`TaskTracker::wait`].
    #[must_use = "futures do nothing unless polled"]
    pub struct TaskTrackerWaitFuture<'a> {
        // Notification for "last task exited"; created eagerly in `wait` so
        // the future is resistant against ABA races.
        #[pin]
        future: Notified<'a>,
        // `None` once the wait has completed (or was already satisfied).
        inner: Option<&'a TaskTrackerInner>,
    }
}
impl TaskTrackerInner {
    /// Creates the initial state: zero tasks, open (closed bit unset).
    #[inline]
    fn new() -> Self {
        Self {
            state: AtomicUsize::new(0),
            on_last_exit: Notify::new(),
        }
    }
    #[inline]
    fn is_closed_and_empty(&self) -> bool {
        // If empty and closed bit set, then we are done.
        //
        // The acquire load will synchronize with the release store of any previous call to
        // `set_closed` and `drop_task`.
        self.state.load(Ordering::Acquire) == 1
    }
    /// Sets the closed bit; returns `true` if this call closed the tracker.
    #[inline]
    fn set_closed(&self) -> bool {
        // The AcqRel ordering makes the closed bit behave like a `Mutex<bool>` for synchronization
        // purposes. We do this because it makes the return value of `TaskTracker::{close,reopen}`
        // more meaningful for the user. Without these orderings, this assert could fail:
        // ```
        // // thread 1
        // some_other_atomic.store(true, Relaxed);
        // tracker.close();
        //
        // // thread 2
        // if tracker.reopen() {
        //     assert!(some_other_atomic.load(Relaxed));
        // }
        // ```
        // However, with the AcqRel ordering, we establish a happens-before relationship from the
        // call to `close` and the later call to `reopen` that returned true.
        let state = self.state.fetch_or(1, Ordering::AcqRel);
        // If there are no tasks, and if it was not already closed:
        if state == 0 {
            self.notify_now();
        }
        (state & 1) == 0
    }
    /// Clears the closed bit; returns `true` if this call reopened the tracker.
    #[inline]
    fn set_open(&self) -> bool {
        // See `set_closed` regarding the AcqRel ordering.
        let state = self.state.fetch_and(!1, Ordering::AcqRel);
        (state & 1) == 1
    }
    /// Increments the task count (stored in the bits above the closed bit).
    #[inline]
    fn add_task(&self) {
        self.state.fetch_add(2, Ordering::Relaxed);
    }
    /// Decrements the task count, waking waiters if this was the last task of
    /// a closed tracker.
    #[inline]
    fn drop_task(&self) {
        let state = self.state.fetch_sub(2, Ordering::Release);
        // If this was the last task and we are closed:
        // (state == 3 means: one task remaining, closed bit set.)
        if state == 3 {
            self.notify_now();
        }
    }
    #[cold]
    fn notify_now(&self) {
        // Insert an acquire fence. This matters for `drop_task` but doesn't matter for
        // `set_closed` since it already uses AcqRel.
        //
        // This synchronizes with the release store of any other call to `drop_task`, and with the
        // release store in the call to `set_closed`. That ensures that everything that happened
        // before those other calls to `drop_task` or `set_closed` will be visible after this load,
        // and those things will also be visible to anything woken by the call to `notify_waiters`.
        self.state.load(Ordering::Acquire);
        self.on_last_exit.notify_waiters();
    }
}
impl TaskTracker {
    /// Creates a new `TaskTracker`.
    ///
    /// The `TaskTracker` will start out as open.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(TaskTrackerInner::new()),
        }
    }
    /// Waits until this `TaskTracker` is both closed and empty.
    ///
    /// If the `TaskTracker` is already closed and empty when this method is called, then it
    /// returns immediately.
    ///
    /// The `wait` future is resistant against [ABA problems][aba]. That is, if the `TaskTracker`
    /// becomes both closed and empty for a short amount of time, then it is guaranteed that all
    /// `wait` futures that were created before the short time interval will trigger, even if they
    /// are not polled during that short time interval.
    ///
    /// # Cancel safety
    ///
    /// This method is cancel safe.
    ///
    /// However, the resistance against [ABA problems][aba] is lost when using `wait` as the
    /// condition in a `tokio::select!` loop.
    ///
    /// [aba]: https://en.wikipedia.org/wiki/ABA_problem
    #[inline]
    pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
        // The `Notified` future is created eagerly, before checking the state,
        // so that a notification between this check and the first poll is not
        // lost (this is what provides the ABA resistance).
        TaskTrackerWaitFuture {
            future: self.inner.on_last_exit.notified(),
            inner: if self.inner.is_closed_and_empty() {
                None
            } else {
                Some(&self.inner)
            },
        }
    }
    /// Close this `TaskTracker`.
    ///
    /// This allows [`wait`] futures to complete. It does not prevent you from spawning new tasks.
    ///
    /// Returns `true` if this closed the `TaskTracker`, or `false` if it was already closed.
    ///
    /// [`wait`]: Self::wait
    #[inline]
    pub fn close(&self) -> bool {
        self.inner.set_closed()
    }
    /// Reopen this `TaskTracker`.
    ///
    /// This prevents [`wait`] futures from completing even if the `TaskTracker` is empty.
    ///
    /// Returns `true` if this reopened the `TaskTracker`, or `false` if it was already open.
    ///
    /// [`wait`]: Self::wait
    #[inline]
    pub fn reopen(&self) -> bool {
        self.inner.set_open()
    }
    /// Returns `true` if this `TaskTracker` is [closed](Self::close).
    #[inline]
    #[must_use]
    pub fn is_closed(&self) -> bool {
        // The lowest bit of the state is the closed flag.
        (self.inner.state.load(Ordering::Acquire) & 1) != 0
    }
    /// Returns the number of tasks tracked by this `TaskTracker`.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        // The task count lives in the bits above the closed flag.
        self.inner.state.load(Ordering::Acquire) >> 1
    }
    /// Returns `true` if there are no tasks in this `TaskTracker`.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        // state <= 1 means the task count (state >> 1) is zero, regardless of
        // whether the closed bit is set.
        self.inner.state.load(Ordering::Acquire) <= 1
    }
    /// Spawn the provided future on the current Tokio runtime, and track it in this `TaskTracker`.
    ///
    /// This is equivalent to `tokio::spawn(tracker.track_future(task))`.
    #[inline]
    #[track_caller]
    #[cfg(feature = "rt")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
    pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        tokio::task::spawn(self.track_future(task))
    }
    /// Spawn the provided future on the provided Tokio runtime, and track it in this `TaskTracker`.
    ///
    /// This is equivalent to `handle.spawn(tracker.track_future(task))`.
    #[inline]
    #[track_caller]
    #[cfg(feature = "rt")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
    pub fn spawn_on<F>(&self, task: F, handle: &Handle) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        handle.spawn(self.track_future(task))
    }
    /// Spawn the provided future on the current [`LocalSet`] or [`LocalRuntime`]
    /// and track it in this `TaskTracker`.
    ///
    /// This is equivalent to `tokio::task::spawn_local(tracker.track_future(task))`.
    ///
    /// # Panics
    ///
    /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
    #[inline]
    #[track_caller]
    #[cfg(feature = "rt")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
    pub fn spawn_local<F>(&self, task: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
        F::Output: 'static,
    {
        tokio::task::spawn_local(self.track_future(task))
    }
    /// Spawn the provided future on the provided [`LocalSet`], and track it in this `TaskTracker`.
    ///
    /// This is equivalent to `local_set.spawn_local(tracker.track_future(task))`.
    ///
    /// [`LocalSet`]: tokio::task::LocalSet
    #[inline]
    #[track_caller]
    #[cfg(feature = "rt")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
    pub fn spawn_local_on<F>(&self, task: F, local_set: &LocalSet) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
        F::Output: 'static,
    {
        local_set.spawn_local(self.track_future(task))
    }
    /// Spawn the provided blocking task on the current Tokio runtime, and track it in this `TaskTracker`.
    ///
    /// The task is tracked by holding a [`TaskTrackerToken`] until the closure
    /// returns, analogous to `tokio::spawn(tracker.track_future(task))` for futures.
    #[inline]
    #[track_caller]
    #[cfg(feature = "rt")]
    #[cfg(not(target_family = "wasm"))]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
    pub fn spawn_blocking<F, T>(&self, task: F) -> JoinHandle<T>
    where
        F: FnOnce() -> T,
        F: Send + 'static,
        T: Send + 'static,
    {
        // Acquire the token before spawning so the task is counted even if it
        // has not started running yet; drop it only after the closure returns.
        let token = self.token();
        tokio::task::spawn_blocking(move || {
            let res = task();
            drop(token);
            res
        })
    }
    /// Spawn the provided blocking task on the provided Tokio runtime, and track it in this `TaskTracker`.
    ///
    /// The task is tracked by holding a [`TaskTrackerToken`] until the closure
    /// returns, analogous to `handle.spawn(tracker.track_future(task))` for futures.
    #[inline]
    #[track_caller]
    #[cfg(feature = "rt")]
    #[cfg(not(target_family = "wasm"))]
    #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
    pub fn spawn_blocking_on<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T>
    where
        F: FnOnce() -> T,
        F: Send + 'static,
        T: Send + 'static,
    {
        let token = self.token();
        handle.spawn_blocking(move || {
            let res = task();
            drop(token);
            res
        })
    }
    /// Track the provided future.
    ///
    /// The returned [`TrackedFuture`] will count as a task tracked by this collection, and will
    /// prevent calls to [`wait`] from returning until the task is dropped.
    ///
    /// The task is removed from the collection when it is dropped, not when [`poll`] returns
    /// [`Poll::Ready`].
    ///
    /// # Examples
    ///
    /// Track a future spawned with [`tokio::spawn`].
    ///
    /// ```
    /// # async fn my_async_fn() {}
    /// use tokio_util::task::TaskTracker;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    /// let tracker = TaskTracker::new();
    ///
    /// tokio::spawn(tracker.track_future(my_async_fn()));
    /// # }
    /// ```
    ///
    /// Track a future spawned on a [`JoinSet`].
    /// ```
    /// # async fn my_async_fn() {}
    /// use tokio::task::JoinSet;
    /// use tokio_util::task::TaskTracker;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() {
    /// let tracker = TaskTracker::new();
    /// let mut join_set = JoinSet::new();
    ///
    /// join_set.spawn(tracker.track_future(my_async_fn()));
    /// # }
    /// ```
    ///
    /// [`JoinSet`]: tokio::task::JoinSet
    /// [`Poll::Pending`]: std::task::Poll::Pending
    /// [`poll`]: std::future::Future::poll
    /// [`wait`]: Self::wait
    #[inline]
    pub fn track_future<F: Future>(&self, future: F) -> TrackedFuture<F> {
        TrackedFuture {
            future,
            token: self.token(),
        }
    }
    /// Creates a [`TaskTrackerToken`] representing a task tracked by this `TaskTracker`.
    ///
    /// This token is a lower-level utility than the spawn methods. Each token is considered to
    /// correspond to a task. As long as the token exists, the `TaskTracker` cannot complete.
    /// Furthermore, the count returned by the [`len`] method will include the tokens in the count.
    ///
    /// Dropping the token indicates to the `TaskTracker` that the task has exited.
    ///
    /// [`len`]: TaskTracker::len
    #[inline]
    pub fn token(&self) -> TaskTrackerToken {
        self.inner.add_task();
        TaskTrackerToken {
            task_tracker: self.clone(),
        }
    }
    /// Returns `true` if both task trackers correspond to the same set of tasks.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::task::TaskTracker;
    ///
    /// let tracker_1 = TaskTracker::new();
    /// let tracker_2 = TaskTracker::new();
    /// let tracker_1_clone = tracker_1.clone();
    ///
    /// assert!(TaskTracker::ptr_eq(&tracker_1, &tracker_1_clone));
    /// assert!(!TaskTracker::ptr_eq(&tracker_1, &tracker_2));
    /// ```
    #[inline]
    #[must_use]
    pub fn ptr_eq(left: &TaskTracker, right: &TaskTracker) -> bool {
        Arc::ptr_eq(&left.inner, &right.inner)
    }
}
impl Default for TaskTracker {
    /// Creates a new `TaskTracker`.
    ///
    /// The `TaskTracker` will start out as open.
    #[inline]
    fn default() -> TaskTracker {
        TaskTracker::new()
    }
}
impl Clone for TaskTracker {
    /// Returns a new `TaskTracker` that tracks the same set of tasks.
    ///
    /// Since the new `TaskTracker` shares the same set of tasks, changes to one set are visible in
    /// all other clones.
    ///
    /// # Examples
    ///
    /// ```
    /// use tokio_util::task::TaskTracker;
    ///
    /// #[tokio::main]
    /// # async fn _hidden() {}
    /// # #[tokio::main(flavor = "current_thread")]
    /// async fn main() {
    ///     let tracker = TaskTracker::new();
    ///     let cloned = tracker.clone();
    ///
    ///     // Spawns on `tracker` are visible in `cloned`.
    ///     tracker.spawn(std::future::pending::<()>());
    ///     assert_eq!(cloned.len(), 1);
    ///
    ///     // Spawns on `cloned` are visible in `tracker`.
    ///     cloned.spawn(std::future::pending::<()>());
    ///     assert_eq!(tracker.len(), 2);
    ///
    ///     // Calling `close` is visible to `cloned`.
    ///     tracker.close();
    ///     assert!(cloned.is_closed());
    ///
    ///     // Calling `reopen` is visible to `tracker`.
    ///     cloned.reopen();
    ///     assert!(!tracker.is_closed());
    /// }
    /// ```
    #[inline]
    fn clone(&self) -> TaskTracker {
        // Only the Arc is cloned; the shared state is not duplicated.
        Self {
            inner: self.inner.clone(),
        }
    }
}
/// Shared `Debug` rendering for `TaskTracker` and `TaskTrackerWaitFuture`:
/// decodes the packed state word into `len` and `is_closed`.
fn debug_inner(inner: &TaskTrackerInner, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    let state = inner.state.load(Ordering::Acquire);
    // Lowest bit: closed flag; remaining bits: task count.
    let is_closed = (state & 1) != 0;
    let len = state >> 1;
    f.debug_struct("TaskTracker")
        .field("len", &len)
        .field("is_closed", &is_closed)
        .field("inner", &(inner as *const TaskTrackerInner))
        .finish()
}
impl fmt::Debug for TaskTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        debug_inner(&self.inner, f)
    }
}
impl TaskTrackerToken {
    /// Returns the [`TaskTracker`] that this token is associated with.
    #[inline]
    #[must_use]
    pub fn task_tracker(&self) -> &TaskTracker {
        &self.task_tracker
    }
}
impl Clone for TaskTrackerToken {
    /// Returns a new `TaskTrackerToken` associated with the same [`TaskTracker`].
    ///
    /// This is equivalent to `token.task_tracker().token()`.
    #[inline]
    fn clone(&self) -> TaskTrackerToken {
        // `token` also increments the tracked-task count for the new token.
        self.task_tracker.token()
    }
}
impl Drop for TaskTrackerToken {
    /// Dropping the token indicates to the [`TaskTracker`] that the task has exited.
    #[inline]
    fn drop(&mut self) {
        self.task_tracker.inner.drop_task();
    }
}
impl<F: Future> Future for TrackedFuture<F> {
    type Output = F::Output;
    // Transparent: polling is forwarded to the wrapped future. Tracking is
    // purely a function of this struct's lifetime (the token drops with it).
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        self.project().future.poll(cx)
    }
}
impl<F: fmt::Debug> fmt::Debug for TrackedFuture<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TrackedFuture")
            .field("future", &self.future)
            .field("task_tracker", self.token.task_tracker())
            .finish()
    }
}
impl<'a> Future for TaskTrackerWaitFuture<'a> {
    type Output = ();
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let me = self.project();
        // `inner` is cleared once the wait completes, making later polls
        // return Ready immediately (fused behavior).
        let inner = match me.inner.as_ref() {
            None => return Poll::Ready(()),
            Some(inner) => inner,
        };
        // Check the state directly in addition to the notification so a
        // notification raced between `wait()` and the first poll is not lost.
        let ready = inner.is_closed_and_empty() || me.future.poll(cx).is_ready();
        if ready {
            *me.inner = None;
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
impl<'a> fmt::Debug for TaskTrackerWaitFuture<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Adapter so the tracker state renders via `debug_inner` rather than
        // as a raw reference.
        struct Helper<'a>(&'a TaskTrackerInner);
        impl fmt::Debug for Helper<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                debug_inner(self.0, f)
            }
        }
        f.debug_struct("TaskTrackerWaitFuture")
            .field("future", &self.future)
            .field("task_tracker", &self.inner.map(Helper))
            .finish()
    }
}

1354
vendor/tokio-util/src/time/delay_queue.rs vendored Normal file

File diff suppressed because it is too large Load Diff

51
vendor/tokio-util/src/time/mod.rs vendored Normal file
View File

@@ -0,0 +1,51 @@
//! Additional utilities for tracking time.
//!
//! This module provides additional utilities for executing code after a set period
//! of time. Currently there is only one:
//!
//! * `DelayQueue`: A queue where items are returned once the requested delay
//! has expired.
//!
//! This type must be used from within the context of the `Runtime`.
use std::time::Duration;
mod wheel;
pub mod delay_queue;
// re-export `FutureExt` to avoid breaking change
#[doc(inline)]
pub use crate::future::FutureExt;
#[doc(inline)]
pub use delay_queue::DelayQueue;
// ===== Internal utils =====
/// Rounding mode applied to the sub-second part of a duration by [`ms`].
enum Round {
    Up,
    Down,
}
/// Convert a `Duration` to milliseconds, rounding up and saturating at
/// `u64::MAX`.
///
/// The saturating is fine because `u64::MAX` milliseconds are still many
/// million years.
#[inline]
fn ms(duration: Duration, round: Round) -> u64 {
    const NANOS_PER_MILLI: u32 = 1_000_000;
    const MILLIS_PER_SEC: u64 = 1_000;
    // Convert the fractional (sub-second) part, honoring the rounding mode.
    let subsec_ms = match round {
        // Any partial millisecond counts as a whole one.
        Round::Up => (duration.subsec_nanos() + NANOS_PER_MILLI - 1) / NANOS_PER_MILLI,
        // Truncate partial milliseconds toward zero.
        Round::Down => duration.subsec_millis(),
    };
    // Combine whole seconds with the fractional part; both steps saturate so
    // absurdly long durations clamp to `u64::MAX` instead of overflowing.
    let whole_ms = duration.as_secs().saturating_mul(MILLIS_PER_SEC);
    whole_ms.saturating_add(u64::from(subsec_ms))
}

View File

@@ -0,0 +1,198 @@
use crate::time::wheel::Stack;
use std::fmt;
/// Wheel for a single level in the timer. This wheel contains 64 slots.
pub(crate) struct Level<T> {
    /// Index of this level within the wheel hierarchy (0 = finest granularity).
    level: usize,
    /// Bit field tracking which slots currently contain entries.
    ///
    /// Using a bit field to track slots that contain entries allows avoiding a
    /// scan to find entries. This field is updated when entries are added or
    /// removed from a slot.
    ///
    /// The least-significant bit represents slot zero.
    occupied: u64,
    /// Slots
    slot: [T; LEVEL_MULT],
}
/// Indicates when a slot must be processed next.
#[derive(Debug)]
pub(crate) struct Expiration {
    /// The level containing the slot.
    pub(crate) level: usize,
    /// The slot index.
    pub(crate) slot: usize,
    /// The instant at which the slot needs to be processed.
    pub(crate) deadline: u64,
}
/// Level multiplier.
///
/// Being a power of 2 is very important.
// The slot math in `slot_for` shifts by `level * 6` (6 = log2(64)); changing
// this constant requires changing that shift as well.
const LEVEL_MULT: usize = 64;
impl<T: Stack> Level<T> {
pub(crate) fn new(level: usize) -> Level<T> {
Level {
level,
occupied: 0,
slot: std::array::from_fn(|_| T::default()),
}
}
/// Finds the slot that needs to be processed next and returns the slot and
/// `Instant` at which this slot must be processed.
pub(crate) fn next_expiration(&self, now: u64) -> Option<Expiration> {
// Use the `occupied` bit field to get the index of the next slot that
// needs to be processed.
let slot = self.next_occupied_slot(now)?;
// From the slot index, calculate the `Instant` at which it needs to be
// processed. This value *must* be in the future with respect to `now`.
let level_range = level_range(self.level);
let slot_range = slot_range(self.level);
// TODO: This can probably be simplified w/ power of 2 math
let level_start = now - (now % level_range);
let mut deadline = level_start + slot as u64 * slot_range;
if deadline < now {
// A timer is in a slot "prior" to the current time. This can occur
// because we do not have an infinite hierarchy of timer levels, and
// eventually a timer scheduled for a very distant time might end up
// being placed in a slot that is beyond the end of all of the
// arrays.
//
// To deal with this, we first limit timers to being scheduled no
// more than MAX_DURATION ticks in the future; that is, they're at
// most one rotation of the top level away. Then, we force timers
// that logically would go into the top+1 level, to instead go into
// the top level's slots.
//
// What this means is that the top level's slots act as a
// pseudo-ring buffer, and we rotate around them indefinitely. If we
// compute a deadline before now, and it's the top level, it
// therefore means we're actually looking at a slot in the future.
debug_assert_eq!(self.level, super::NUM_LEVELS - 1);
deadline += level_range;
}
debug_assert!(
deadline >= now,
"deadline={:016X}; now={:016X}; level={}; slot={}; occupied={:b}",
deadline,
now,
self.level,
slot,
self.occupied
);
Some(Expiration {
level: self.level,
slot,
deadline,
})
}
fn next_occupied_slot(&self, now: u64) -> Option<usize> {
if self.occupied == 0 {
return None;
}
// Get the slot for now using Maths
let now_slot = (now / slot_range(self.level)) as usize;
let occupied = self.occupied.rotate_right(now_slot as u32);
let zeros = occupied.trailing_zeros() as usize;
let slot = (zeros + now_slot) % 64;
Some(slot)
}
pub(crate) fn add_entry(&mut self, when: u64, item: T::Owned, store: &mut T::Store) {
let slot = slot_for(when, self.level);
self.slot[slot].push(item, store);
self.occupied |= occupied_bit(slot);
}
pub(crate) fn remove_entry(&mut self, when: u64, item: &T::Borrowed, store: &mut T::Store) {
let slot = slot_for(when, self.level);
self.slot[slot].remove(item, store);
if self.slot[slot].is_empty() {
// The bit is currently set
debug_assert!(self.occupied & occupied_bit(slot) != 0);
// Unset the bit
self.occupied ^= occupied_bit(slot);
}
}
pub(crate) fn pop_entry_slot(&mut self, slot: usize, store: &mut T::Store) -> Option<T::Owned> {
let ret = self.slot[slot].pop(store);
if ret.is_some() && self.slot[slot].is_empty() {
// The bit is currently set
debug_assert!(self.occupied & occupied_bit(slot) != 0);
self.occupied ^= occupied_bit(slot);
}
ret
}
    /// Returns (without removing) the next entry in the given slot, if any.
    pub(crate) fn peek_entry_slot(&self, slot: usize) -> Option<T::Owned> {
        self.slot[slot].peek()
    }
}
impl<T> fmt::Debug for Level<T> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the occupancy bitmap is shown; slot contents are omitted since
        // this impl places no `Debug` bound on `T`.
        fmt.debug_struct("Level")
            .field("occupied", &self.occupied)
            .finish()
    }
}
/// Returns a `u64` mask with only the bit corresponding to `slot` set.
fn occupied_bit(slot: usize) -> u64 {
    1u64 << slot
}
/// Number of ticks spanned by a single slot at `level` (`LEVEL_MULT^level`).
fn slot_range(level: usize) -> u64 {
    LEVEL_MULT.pow(level as u32) as u64
}
/// Total number of ticks spanned by an entire level (`LEVEL_MULT` slots wide).
fn level_range(level: usize) -> u64 {
    LEVEL_MULT as u64 * slot_range(level)
}
/// Converts a duration (in milliseconds) and a level to the slot position the
/// duration falls into at that level.
fn slot_for(duration: u64, level: usize) -> usize {
    // Discard the 6 bits consumed by each lower level, then keep only the
    // low bits that index into this level's slots.
    let shifted = duration >> (level * 6);
    (shifted % LEVEL_MULT as u64) as usize
}
#[cfg(all(test, not(loom)))]
mod test {
    use super::*;
    #[test]
    fn test_slot_for() {
        // Level 0: the slot is just the tick count itself (low 6 bits).
        for pos in 0..64 {
            assert_eq!(pos as usize, slot_for(pos, 0));
        }
        // Higher levels: a deadline of `pos * 64^level` lands in slot `pos`.
        for level in 1..5 {
            for pos in level..64 {
                let a = pos * 64_usize.pow(level as u32);
                assert_eq!(pos, slot_for(a as u64, level));
            }
        }
    }
}

317
vendor/tokio-util/src/time/wheel/mod.rs vendored Normal file
View File

@@ -0,0 +1,317 @@
mod level;
pub(crate) use self::level::Expiration;
use self::level::Level;
mod stack;
pub(crate) use self::stack::Stack;
use std::borrow::Borrow;
use std::fmt::Debug;
/// Timing wheel implementation.
///
/// This type provides the hashed timing wheel implementation that backs
/// [`DelayQueue`].
///
/// The structure is generic over `T: Stack`. This allows handling timeout data
/// being stored on the heap or in a slab. In order to support the latter case,
/// the slab must be passed into each function allowing the implementation to
/// lookup timer entries.
///
/// See `Driver` documentation for some implementation notes.
///
/// [`DelayQueue`]: crate::time::DelayQueue
#[derive(Debug)]
pub(crate) struct Wheel<T> {
    /// The number of milliseconds elapsed since the wheel started.
    elapsed: u64,
    /// Timer wheel.
    ///
    /// Levels:
    ///
    /// * 1 ms slots / 64 ms range
    /// * 64 ms slots / ~ 4 sec range
    /// * ~ 4 sec slots / ~ 4 min range
    /// * ~ 4 min slots / ~ 4 hr range
    /// * ~ 4 hr slots / ~ 12 day range
    /// * ~ 12 day slots / ~ 2 yr range
    ///
    /// Indexed by level number: `levels[0]` is the finest-grained level.
    levels: Box<[Level<T>]>,
}
/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots
/// each, the timer is able to track time up to 2 years into the future with a
/// precision of 1 millisecond.
const NUM_LEVELS: usize = 6;
/// The maximum duration of a delay
///
/// Equal to one full rotation of the top level minus one tick
/// (`2^(6 * NUM_LEVELS) - 1` milliseconds).
const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1;
/// Reasons an entry could not be inserted into the wheel.
#[derive(Debug)]
pub(crate) enum InsertError {
    /// The deadline has already passed; the caller should fire it immediately.
    Elapsed,
    /// The deadline is more than `MAX_DURATION` ticks in the future.
    Invalid,
}
impl<T> Wheel<T>
where
    T: Stack,
{
    /// Create a new timing wheel
    pub(crate) fn new() -> Wheel<T> {
        // One `Level` per tier; `Level::new` receives its level index.
        let levels = (0..NUM_LEVELS).map(Level::new).collect();
        Wheel { elapsed: 0, levels }
    }
    /// Return the number of milliseconds that have elapsed since the timing
    /// wheel's creation.
    pub(crate) fn elapsed(&self) -> u64 {
        self.elapsed
    }
    /// Insert an entry into the timing wheel.
    ///
    /// # Arguments
    ///
    /// * `when`: is the instant at which the entry should be fired. It is
    /// represented as the number of milliseconds since the creation
    /// of the timing wheel.
    ///
    /// * `item`: The item to insert into the wheel.
    ///
    /// * `store`: The slab or `()` when using heap storage.
    ///
    /// # Return
    ///
    /// Returns `Ok` when the item is successfully inserted, `Err` otherwise.
    ///
    /// `Err(Elapsed)` indicates that `when` represents an instant that has
    /// already passed. In this case, the caller should fire the timeout
    /// immediately.
    ///
    /// `Err(Invalid)` indicates an invalid `when` argument as been supplied.
    pub(crate) fn insert(
        &mut self,
        when: u64,
        item: T::Owned,
        store: &mut T::Store,
    ) -> Result<(), (T::Owned, InsertError)> {
        if when <= self.elapsed {
            return Err((item, InsertError::Elapsed));
        } else if when - self.elapsed > MAX_DURATION {
            return Err((item, InsertError::Invalid));
        }
        // Get the level at which the entry should be stored
        let level = self.level_for(when);
        self.levels[level].add_entry(when, item, store);
        // Sanity check: inserting must never create an expiration in the past.
        debug_assert!({
            self.levels[level]
                .next_expiration(self.elapsed)
                .map(|e| e.deadline >= self.elapsed)
                .unwrap_or(true)
        });
        Ok(())
    }
    /// Remove `item` from the timing wheel.
    ///
    /// # Panics
    ///
    /// Panics if the item's deadline (per `T::when`) is before the wheel's
    /// current elapsed time.
    #[track_caller]
    pub(crate) fn remove(&mut self, item: &T::Borrowed, store: &mut T::Store) {
        let when = T::when(item, store);
        assert!(
            self.elapsed <= when,
            "elapsed={}; when={}",
            self.elapsed,
            when
        );
        let level = self.level_for(when);
        self.levels[level].remove_entry(when, item, store);
    }
    /// Instant at which to poll
    pub(crate) fn poll_at(&self) -> Option<u64> {
        self.next_expiration().map(|expiration| expiration.deadline)
    }
    /// Next key that will expire
    pub(crate) fn peek(&self) -> Option<T::Owned> {
        self.next_expiration()
            .and_then(|expiration| self.peek_entry(&expiration))
    }
    /// Advances the timer up to the instant represented by `now`.
    ///
    /// Returns the next entry due at or before `now`, or `None` once all such
    /// entries have been drained (in which case `elapsed` advances to `now`).
    pub(crate) fn poll(&mut self, now: u64, store: &mut T::Store) -> Option<T::Owned> {
        loop {
            // Only consider expirations whose deadline is at or before `now`.
            let expiration = self.next_expiration().and_then(|expiration| {
                if expiration.deadline > now {
                    None
                } else {
                    Some(expiration)
                }
            });
            match expiration {
                Some(ref expiration) => {
                    if let Some(item) = self.poll_expiration(expiration, store) {
                        return Some(item);
                    }
                    // The slot's entries were all re-filed into a lower level;
                    // record the progress and keep looking.
                    self.set_elapsed(expiration.deadline);
                }
                None => {
                    // in this case the poll did not indicate an expiration
                    // _and_ we were not able to find a next expiration in
                    // the current list of timers. advance to the poll's
                    // current time and do nothing else.
                    self.set_elapsed(now);
                    return None;
                }
            }
        }
    }
    /// Returns the instant at which the next timeout expires.
    fn next_expiration(&self) -> Option<Expiration> {
        // Check all levels
        for level in 0..NUM_LEVELS {
            if let Some(expiration) = self.levels[level].next_expiration(self.elapsed) {
                // There cannot be any expirations at a higher level that happen
                // before this one.
                debug_assert!(self.no_expirations_before(level + 1, expiration.deadline));
                return Some(expiration);
            }
        }
        None
    }
    /// Used for debug assertions
    fn no_expirations_before(&self, start_level: usize, before: u64) -> bool {
        let mut res = true;
        for l2 in start_level..NUM_LEVELS {
            if let Some(e2) = self.levels[l2].next_expiration(self.elapsed) {
                if e2.deadline < before {
                    res = false;
                }
            }
        }
        res
    }
    /// iteratively find entries that are between the wheel's current
    /// time and the expiration time. for each in that population either
    /// return it for notification (in the case of the last level) or tier
    /// it down to the next level (in all other cases).
    pub(crate) fn poll_expiration(
        &mut self,
        expiration: &Expiration,
        store: &mut T::Store,
    ) -> Option<T::Owned> {
        while let Some(item) = self.pop_entry(expiration, store) {
            if expiration.level == 0 {
                // Level 0 resolves deadlines exactly: the entry is due now.
                debug_assert_eq!(T::when(item.borrow(), store), expiration.deadline);
                return Some(item);
            } else {
                // Coarser level: re-file the entry one level down, where its
                // deadline can be resolved more precisely.
                let when = T::when(item.borrow(), store);
                let next_level = expiration.level - 1;
                self.levels[next_level].add_entry(when, item, store);
            }
        }
        None
    }
    /// Advances `elapsed` to `when`, asserting that time never moves backward.
    fn set_elapsed(&mut self, when: u64) {
        assert!(
            self.elapsed <= when,
            "elapsed={:?}; when={:?}",
            self.elapsed,
            when
        );
        if when > self.elapsed {
            self.elapsed = when;
        }
    }
    /// Pops one entry from the slot identified by `expiration`.
    fn pop_entry(&mut self, expiration: &Expiration, store: &mut T::Store) -> Option<T::Owned> {
        self.levels[expiration.level].pop_entry_slot(expiration.slot, store)
    }
    /// Peeks at one entry in the slot identified by `expiration`.
    fn peek_entry(&self, expiration: &Expiration) -> Option<T::Owned> {
        self.levels[expiration.level].peek_entry_slot(expiration.slot)
    }
    /// Level in which an entry with deadline `when` should be stored, given
    /// the wheel's current elapsed time.
    fn level_for(&self, when: u64) -> usize {
        level_for(self.elapsed, when)
    }
}
/// Computes the level at which a timer with deadline `when` belongs, given
/// that `elapsed` ticks have already passed.
///
/// The level is determined by the most significant bit in which `elapsed` and
/// `when` differ; each level covers 6 bits of the tick count.
fn level_for(elapsed: u64, when: u64) -> usize {
    const SLOT_MASK: u64 = (1 << 6) - 1;
    // Mask in the trailing bits ignored by the level calculation in order to cap
    // the possible leading zeros
    let mut masked = elapsed ^ when | SLOT_MASK;
    if masked >= MAX_DURATION {
        // Fudge the timer into the top level
        masked = MAX_DURATION - 1;
    }
    let leading_zeros = masked.leading_zeros() as usize;
    let significant = 63 - leading_zeros;
    significant / 6
}
#[cfg(all(test, not(loom)))]
mod test {
    use super::*;
    #[test]
    fn test_level_for() {
        // Deadlines within the first 64 ticks always map to level 0.
        for pos in 0..64 {
            assert_eq!(0, level_for(0, pos), "level_for({pos}) -- binary = {pos:b}");
        }
        // `pos * 64^level` maps to `level`, as do its immediate neighbors
        // when they do not cross a level boundary.
        for level in 1..5 {
            for pos in level..64 {
                let a = pos * 64_usize.pow(level as u32);
                assert_eq!(
                    level,
                    level_for(0, a as u64),
                    "level_for({a}) -- binary = {a:b}"
                );
                if pos > level {
                    let a = a - 1;
                    assert_eq!(
                        level,
                        level_for(0, a as u64),
                        "level_for({a}) -- binary = {a:b}"
                    );
                }
                if pos < 64 {
                    let a = a + 1;
                    assert_eq!(
                        level,
                        level_for(0, a as u64),
                        "level_for({a}) -- binary = {a:b}"
                    );
                }
            }
        }
    }
}

View File

@@ -0,0 +1,31 @@
use std::borrow::Borrow;
use std::cmp::Eq;
use std::hash::Hash;
/// Abstracts the stack operations needed to track timeouts.
pub(crate) trait Stack: Default {
    /// Type of the item stored in the stack
    type Owned: Borrow<Self::Borrowed>;
    /// Borrowed item
    type Borrowed: Eq + Hash;
    /// Item storage, this allows a slab to be used instead of just the heap
    type Store;
    /// Returns `true` if the stack is empty
    fn is_empty(&self) -> bool;
    /// Push an item onto the stack
    fn push(&mut self, item: Self::Owned, store: &mut Self::Store);
    /// Pop an item from the stack
    fn pop(&mut self, store: &mut Self::Store) -> Option<Self::Owned>;
    /// Peek into the stack.
    fn peek(&self) -> Option<Self::Owned>;
    /// Remove `item` from the stack.
    fn remove(&mut self, item: &Self::Borrowed, store: &mut Self::Store);
    /// Returns the instant at which `item` fires, in the timing wheel's tick
    /// units.
    fn when(item: &Self::Borrowed, store: &Self::Store) -> u64;
}

6
vendor/tokio-util/src/tracing.rs vendored Normal file
View File

@@ -0,0 +1,6 @@
// Forwards to `tracing::trace!` when the "tracing" feature is enabled;
// expands to nothing (the arguments are not evaluated) otherwise.
macro_rules! trace {
    ($($arg:tt)*) => {
        #[cfg(feature = "tracing")]
        tracing::trace!($($arg)*);
    };
}

249
vendor/tokio-util/src/udp/frame.rs vendored Normal file
View File

@@ -0,0 +1,249 @@
use crate::codec::{Decoder, Encoder};
use futures_core::Stream;
use tokio::{io::ReadBuf, net::UdpSocket};
use bytes::{BufMut, BytesMut};
use futures_sink::Sink;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::{
borrow::Borrow,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
};
/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
///
/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
/// batch these into meaningful chunks, called "frames". This method layers
/// framing on top of this socket by using the `Encoder` and `Decoder` traits to
/// handle encoding and decoding of messages frames. Note that the incoming and
/// outgoing frame types may be distinct.
///
/// This function returns a *single* object that is both [`Stream`] and [`Sink`];
/// grouping this into a single object is often useful for layering things which
/// require both read and write access to the underlying object.
///
/// If you want to work more directly with the streams and sink, consider
/// calling [`split`] on the `UdpFramed` returned by this method, which will break
/// them into separate objects, allowing them to interact more easily.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
#[must_use = "sinks do nothing unless polled"]
#[derive(Debug)]
pub struct UdpFramed<C, T = UdpSocket> {
    /// The underlying socket (anything that can borrow a `UdpSocket`).
    socket: T,
    /// Codec used to encode outgoing and decode incoming frames.
    codec: C,
    /// Read buffer: holds the most recently received datagram.
    rd: BytesMut,
    /// Write buffer: holds the encoded, not-yet-sent datagram.
    wr: BytesMut,
    /// Destination address for the buffered outgoing datagram.
    out_addr: SocketAddr,
    /// `true` once `wr` has been fully sent (or nothing was ever buffered).
    flushed: bool,
    /// `true` while `rd` may still contain decodable bytes.
    is_readable: bool,
    /// Source address of the datagram currently in `rd`.
    current_addr: Option<SocketAddr>,
}
/// Initial read-buffer capacity: large enough for a maximum-size UDP datagram.
const INITIAL_RD_CAPACITY: usize = 64 * 1024;
/// Initial write-buffer capacity for encoded outgoing frames.
const INITIAL_WR_CAPACITY: usize = 8 * 1024;
// Unconditionally `Unpin`: the impls below only ever access fields through
// `get_mut`, never via pin projection.
impl<C, T> Unpin for UdpFramed<C, T> {}
impl<C, T> Stream for UdpFramed<C, T>
where
    T: Borrow<UdpSocket>,
    C: Decoder,
{
    type Item = Result<(C::Item, SocketAddr), C::Error>;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pin = self.get_mut();
        // Ensure there is room for a full datagram before receiving.
        pin.rd.reserve(INITIAL_RD_CAPACITY);
        loop {
            // Are there still bytes left in the read buffer to decode?
            if pin.is_readable {
                if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
                    let current_addr = pin
                        .current_addr
                        .expect("will always be set before this line is called");
                    return Poll::Ready(Some(Ok((frame, current_addr))));
                }
                // if this line has been reached then decode has returned `None`.
                pin.is_readable = false;
                pin.rd.clear();
            }
            // We're out of data. Try and fetch more data to decode
            let addr = {
                // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
                // transparent wrapper around `[MaybeUninit<u8>]`.
                let buf = unsafe { pin.rd.chunk_mut().as_uninit_slice_mut() };
                let mut read = ReadBuf::uninit(buf);
                let ptr = read.filled().as_ptr();
                let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
                // The receive must only have appended to the buffer, never
                // swapped it out from under us.
                assert_eq!(ptr, read.filled().as_ptr());
                let addr = res?;
                let filled = read.filled().len();
                // Safety: This is guaranteed to be the number of initialized (and read) bytes due
                // to the invariants provided by `ReadBuf::filled`.
                unsafe { pin.rd.advance_mut(filled) };
                addr
            };
            pin.current_addr = Some(addr);
            pin.is_readable = true;
        }
    }
}
impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
where
    T: Borrow<UdpSocket>,
    C: Encoder<I>,
{
    type Error = C::Error;
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Ready only once any previously buffered datagram has been sent.
        if !self.flushed {
            match self.poll_flush(cx)? {
                Poll::Ready(()) => {}
                Poll::Pending => return Poll::Pending,
            }
        }
        Poll::Ready(Ok(()))
    }
    fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
        let (frame, out_addr) = item;
        let pin = self.get_mut();
        // Encode into the write buffer; the datagram is actually sent on flush.
        pin.codec.encode(frame, &mut pin.wr)?;
        pin.out_addr = out_addr;
        pin.flushed = false;
        Ok(())
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.flushed {
            return Poll::Ready(Ok(()));
        }
        let Self {
            ref socket,
            ref mut out_addr,
            ref mut wr,
            ..
        } = *self;
        let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
        // A datagram must be sent whole: a short send is reported as an error.
        let wrote_all = n == self.wr.len();
        self.wr.clear();
        self.flushed = true;
        let res = if wrote_all {
            Ok(())
        } else {
            Err(io::Error::new(
                io::ErrorKind::Other,
                "failed to write entire datagram to socket",
            )
            .into())
        };
        Poll::Ready(res)
    }
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Nothing to tear down for UDP; closing just flushes any pending data.
        ready!(self.poll_flush(cx))?;
        Poll::Ready(Ok(()))
    }
}
impl<C, T> UdpFramed<C, T>
where
    T: Borrow<UdpSocket>,
{
    /// Create a new `UdpFramed` backed by the given socket and codec.
    ///
    /// See struct level documentation for more details.
    pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
        Self {
            socket,
            codec,
            // Placeholder destination; overwritten by `start_send`.
            out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
            rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
            wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
            flushed: true,
            is_readable: false,
            current_addr: None,
        }
    }
    /// Returns a reference to the underlying I/O stream wrapped by `UdpFramed`.
    ///
    /// # Note
    ///
    /// Care should be taken to not tamper with the underlying stream of data
    /// coming in as it may corrupt the stream of frames otherwise being worked
    /// with.
    pub fn get_ref(&self) -> &T {
        &self.socket
    }
    /// Returns a mutable reference to the underlying I/O stream wrapped by `UdpFramed`.
    ///
    /// # Note
    ///
    /// Care should be taken to not tamper with the underlying stream of data
    /// coming in as it may corrupt the stream of frames otherwise being worked
    /// with.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.socket
    }
    /// Returns a reference to the underlying codec wrapped by
    /// `UdpFramed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn codec(&self) -> &C {
        &self.codec
    }
    /// Returns a mutable reference to the underlying codec wrapped by
    /// `UdpFramed`.
    ///
    /// Note that care should be taken to not tamper with the underlying codec
    /// as it may corrupt the stream of frames otherwise being worked with.
    pub fn codec_mut(&mut self) -> &mut C {
        &mut self.codec
    }
    /// Returns a reference to the read buffer.
    pub fn read_buffer(&self) -> &BytesMut {
        &self.rd
    }
    /// Returns a mutable reference to the read buffer.
    pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
        &mut self.rd
    }
    /// Consumes the `UdpFramed`, returning its underlying I/O stream.
    pub fn into_inner(self) -> T {
        self.socket
    }
}

6
vendor/tokio-util/src/udp/mod.rs vendored Normal file
View File

@@ -0,0 +1,6 @@
#![cfg(not(loom))]
//! UDP framing
mod frame;
pub use frame::UdpFramed;

View File

@@ -0,0 +1,67 @@
use core::future::Future;
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::task::{Context, Poll};
/// A wrapper type that tells the compiler that the contents might not be valid.
///
/// This is necessary mainly when `T` contains a reference. In that case, the
/// compiler will sometimes assume that the reference is always valid; in some
/// cases it will assume this even after the destructor of `T` runs. For
/// example, when a reference is used as a function argument, then the compiler
/// will assume that the reference is valid until the function returns, even if
/// the reference is destroyed during the function. When the reference is used
/// as part of a self-referential struct, that assumption can be false. Wrapping
/// the reference in this type prevents the compiler from making that
/// assumption.
///
/// # Invariants
///
/// The `MaybeUninit` will always contain a valid value until the destructor runs.
//
// Reference
// See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
//
// TODO: replace this with an official solution once RFC #3336 or similar is available.
// <https://github.com/rust-lang/rfcs/pull/3336>
#[repr(transparent)]
// Invariant: the `MaybeUninit` is initialized from construction until `drop`.
pub(crate) struct MaybeDangling<T>(MaybeUninit<T>);
// Run the inner value's destructor manually: `MaybeUninit` never drops its
// contents on its own.
impl<T> Drop for MaybeDangling<T> {
    fn drop(&mut self) {
        // Safety: `0` is always initialized.
        unsafe { core::ptr::drop_in_place(self.0.as_mut_ptr()) };
    }
}
impl<T> MaybeDangling<T> {
    /// Wraps `inner`, shielding it from the compiler's reference-validity
    /// assumptions (see the type-level documentation).
    pub(crate) fn new(inner: T) -> Self {
        Self(MaybeUninit::new(inner))
    }
}
impl<F: Future> Future for MaybeDangling<F> {
    type Output = F::Output;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Safety: `0` is always initialized.
        // `map_unchecked_mut` pin-projects to the inner future without moving
        // it out of the `MaybeUninit`.
        let fut = unsafe { self.map_unchecked_mut(|this| this.0.assume_init_mut()) };
        fut.poll(cx)
    }
}
#[test]
fn maybedangling_runs_drop() {
    // Flags a shared bool when its destructor runs.
    struct DropFlag<'a>(&'a mut bool);
    impl Drop for DropFlag<'_> {
        fn drop(&mut self) {
            *self.0 = true;
        }
    }
    let mut dropped = false;
    let wrapper = MaybeDangling::new(DropFlag(&mut dropped));
    drop(wrapper);
    // `MaybeDangling`'s `Drop` impl must invoke the inner value's destructor.
    assert!(dropped);
}

31
vendor/tokio-util/src/util/mod.rs vendored Normal file
View File

@@ -0,0 +1,31 @@
mod maybe_dangling;
#[cfg(any(feature = "io", feature = "codec"))]
mod poll_buf;
pub(crate) use maybe_dangling::MaybeDangling;
#[cfg(any(feature = "io", feature = "codec"))]
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
pub use poll_buf::{poll_read_buf, poll_write_buf};
cfg_rt! {
    // With the "rt" feature, reuse tokio's cooperative-scheduling budget check.
    #[cfg_attr(not(feature = "io"), allow(unused))]
    pub(crate) use tokio::task::coop::poll_proceed;
}
cfg_not_rt! {
    #[cfg_attr(not(feature = "io"), allow(unused))]
    use std::task::{Context, Poll};
    // Without "rt" there is no coop budget; provide no-op stand-ins with the
    // same shape so calling code needs no feature-gated branches.
    #[cfg_attr(not(feature = "io"), allow(unused))]
    pub(crate) struct RestoreOnPending;
    #[cfg_attr(not(feature = "io"), allow(unused))]
    impl RestoreOnPending {
        // No-op: there is no budget to restore without the runtime feature.
        pub(crate) fn made_progress(&self) {}
    }
    // Always-ready substitute for `tokio::task::coop::poll_proceed`.
    #[cfg_attr(not(feature = "io"), allow(unused))]
    pub(crate) fn poll_proceed(_cx: &mut Context<'_>) -> Poll<RestoreOnPending> {
        Poll::Ready(RestoreOnPending)
    }
}

143
vendor/tokio-util/src/util/poll_buf.rs vendored Normal file
View File

@@ -0,0 +1,143 @@
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use bytes::{Buf, BufMut};
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{ready, Context, Poll};
/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
///
/// [`BufMut`]: bytes::BufMut
///
/// # Example
///
/// ```
/// use bytes::{Bytes, BytesMut};
/// use tokio_stream as stream;
/// use tokio::io::Result;
/// use tokio_util::io::{StreamReader, poll_read_buf};
/// use std::future::poll_fn;
/// use std::pin::Pin;
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() -> std::io::Result<()> {
///
/// // Create a reader from an iterator. This particular reader will always be
/// // ready.
/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
///
/// let mut buf = BytesMut::new();
/// let mut reads = 0;
///
/// loop {
/// reads += 1;
/// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
///
/// if n == 0 {
/// break;
/// }
/// }
///
/// // one or more reads might be necessary.
/// assert!(reads >= 1);
/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
/// # Ok(())
/// # }
/// ```
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
pub fn poll_read_buf<T: AsyncRead + ?Sized, B: BufMut>(
    io: Pin<&mut T>,
    cx: &mut Context<'_>,
    buf: &mut B,
) -> Poll<io::Result<usize>> {
    // A full buffer reads zero bytes; note this is distinct from EOF.
    if !buf.has_remaining_mut() {
        return Poll::Ready(Ok(0));
    }
    let n = {
        let dst = buf.chunk_mut();
        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
        // transparent wrapper around `[MaybeUninit<u8>]`.
        let dst = unsafe { dst.as_uninit_slice_mut() };
        let mut buf = ReadBuf::uninit(dst);
        let ptr = buf.filled().as_ptr();
        ready!(io.poll_read(cx, &mut buf)?);
        // Ensure the pointer does not change from under us
        assert_eq!(ptr, buf.filled().as_ptr());
        buf.filled().len()
    };
    // Safety: This is guaranteed to be the number of initialized (and read)
    // bytes due to the invariants provided by `ReadBuf::filled`.
    unsafe {
        buf.advance_mut(n);
    }
    Poll::Ready(Ok(n))
}
/// Try to write data from an implementer of the [`Buf`] trait to an
/// [`AsyncWrite`], advancing the buffer's internal cursor.
///
/// This function will use [vectored writes] when the [`AsyncWrite`] supports
/// vectored writes.
///
/// # Examples
///
/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
/// [`Buf`]:
///
/// ```no_run
/// use tokio_util::io::poll_write_buf;
/// use tokio::io;
/// use tokio::fs::File;
///
/// use bytes::Buf;
/// use std::future::poll_fn;
/// use std::io::Cursor;
/// use std::pin::Pin;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut file = File::create("foo.txt").await?;
/// let mut buf = Cursor::new(b"data to write");
///
/// // Loop until the entire contents of the buffer are written to
/// // the file.
/// while buf.has_remaining() {
/// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
/// }
///
/// Ok(())
/// }
/// ```
///
/// [`Buf`]: bytes::Buf
/// [`AsyncWrite`]: tokio::io::AsyncWrite
/// [`File`]: tokio::fs::File
/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
pub fn poll_write_buf<T: AsyncWrite + ?Sized, B: Buf>(
    io: Pin<&mut T>,
    cx: &mut Context<'_>,
    buf: &mut B,
) -> Poll<io::Result<usize>> {
    // Upper bound on the number of `IoSlice`s gathered for one vectored write.
    const MAX_BUFS: usize = 64;
    // Nothing to write; report zero without touching the writer.
    if !buf.has_remaining() {
        return Poll::Ready(Ok(0));
    }
    let n = if io.is_write_vectored() {
        // Gather up to MAX_BUFS non-contiguous chunks into one vectored write.
        let mut slices = [IoSlice::new(&[]); MAX_BUFS];
        let cnt = buf.chunks_vectored(&mut slices);
        ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
    } else {
        ready!(io.poll_write(cx, buf.chunk()))?
    };
    // Advance past however many bytes the writer accepted (may be partial).
    buf.advance(n);
    Poll::Ready(Ok(n))
}