312 lines
12 KiB
Rust
312 lines
12 KiB
Rust
use crate::codec::decoder::Decoder;
|
|
use crate::codec::encoder::Encoder;
|
|
|
|
use futures_core::Stream;
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
|
|
use bytes::BytesMut;
|
|
use futures_sink::Sink;
|
|
use pin_project_lite::pin_project;
|
|
use std::borrow::{Borrow, BorrowMut};
|
|
use std::io;
|
|
use std::pin::Pin;
|
|
use std::task::{ready, Context, Poll};
|
|
|
|
pin_project! {
    /// Shared core of the framed transports: an I/O object, a codec, and the
    /// buffering state that sits between the two.
    ///
    /// The `Stream` impl in this file is available when `T: AsyncRead`,
    /// `U: Decoder` and `State: BorrowMut<ReadFrame>`; the `Sink` impl when
    /// `T: AsyncWrite`, `U: Encoder<I>` and `State: BorrowMut<WriteFrame>`.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // Underlying I/O object. Marked `#[pin]` so the projection hands out a
        // `Pin<&mut T>` that can be polled in place.
        #[pin]
        pub(crate) inner: T,
        // Buffering state (`ReadFrame`, `WriteFrame` or `RWFrames` in this file).
        pub(crate) state: State,
        // Codec that decodes incoming and/or encodes outgoing frames.
        pub(crate) codec: U,
    }
}
|
|
|
|
/// Initial capacity, in bytes (8 KiB), of the read and write buffers; also used
/// as the default `WriteFrame::backpressure_boundary`.
const INITIAL_CAPACITY: usize = 8 * 1024;
|
|
|
|
#[derive(Debug)]
|
|
pub(crate) struct ReadFrame {
|
|
pub(crate) eof: bool,
|
|
pub(crate) is_readable: bool,
|
|
pub(crate) buffer: BytesMut,
|
|
pub(crate) has_errored: bool,
|
|
}
|
|
|
|
pub(crate) struct WriteFrame {
|
|
pub(crate) buffer: BytesMut,
|
|
pub(crate) backpressure_boundary: usize,
|
|
}
|
|
|
|
#[derive(Default)]
|
|
pub(crate) struct RWFrames {
|
|
pub(crate) read: ReadFrame,
|
|
pub(crate) write: WriteFrame,
|
|
}
|
|
|
|
impl Default for ReadFrame {
|
|
fn default() -> Self {
|
|
Self {
|
|
eof: false,
|
|
is_readable: false,
|
|
buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
|
|
has_errored: false,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for WriteFrame {
|
|
fn default() -> Self {
|
|
Self {
|
|
buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
|
|
backpressure_boundary: INITIAL_CAPACITY,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<BytesMut> for ReadFrame {
|
|
fn from(mut buffer: BytesMut) -> Self {
|
|
let size = buffer.capacity();
|
|
if size < INITIAL_CAPACITY {
|
|
buffer.reserve(INITIAL_CAPACITY - size);
|
|
}
|
|
|
|
Self {
|
|
buffer,
|
|
is_readable: size > 0,
|
|
eof: false,
|
|
has_errored: false,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<BytesMut> for WriteFrame {
|
|
fn from(mut buffer: BytesMut) -> Self {
|
|
let size = buffer.capacity();
|
|
if size < INITIAL_CAPACITY {
|
|
buffer.reserve(INITIAL_CAPACITY - size);
|
|
}
|
|
|
|
Self {
|
|
buffer,
|
|
backpressure_boundary: INITIAL_CAPACITY,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Borrow<ReadFrame> for RWFrames {
    /// Returns a shared reference to the read half.
    fn borrow(&self) -> &ReadFrame {
        let Self { read, .. } = self;
        read
    }
}
|
|
impl BorrowMut<ReadFrame> for RWFrames {
    /// Returns an exclusive reference to the read half.
    fn borrow_mut(&mut self) -> &mut ReadFrame {
        let Self { read, .. } = self;
        read
    }
}
|
|
impl Borrow<WriteFrame> for RWFrames {
    /// Returns a shared reference to the write half.
    fn borrow(&self) -> &WriteFrame {
        let Self { write, .. } = self;
        write
    }
}
|
|
impl BorrowMut<WriteFrame> for RWFrames {
    /// Returns an exclusive reference to the write half.
    fn borrow_mut(&mut self) -> &mut WriteFrame {
        let Self { write, .. } = self;
        write
    }
}
|
|
impl<T, U, R> Stream for FramedImpl<T, U, R>
|
|
where
|
|
T: AsyncRead,
|
|
U: Decoder,
|
|
R: BorrowMut<ReadFrame>,
|
|
{
|
|
type Item = Result<U::Item, U::Error>;
|
|
|
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
use crate::util::poll_read_buf;
|
|
|
|
let mut pinned = self.project();
|
|
let state: &mut ReadFrame = pinned.state.borrow_mut();
|
|
// The following loops implements a state machine with each state corresponding
|
|
// to a combination of the `is_readable` and `eof` flags. States persist across
|
|
// loop entries and most state transitions occur with a return.
|
|
//
|
|
// The initial state is `reading`.
|
|
//
|
|
// | state | eof | is_readable | has_errored |
|
|
// |---------|-------|-------------|-------------|
|
|
// | reading | false | false | false |
|
|
// | framing | false | true | false |
|
|
// | pausing | true | true | false |
|
|
// | paused | true | false | false |
|
|
// | errored | <any> | <any> | true |
|
|
// `decode_eof` returns Err
|
|
// ┌────────────────────────────────────────────────────────┐
|
|
// `decode_eof` returns │ │
|
|
// `Ok(Some)` │ │
|
|
// ┌─────┐ │ `decode_eof` returns After returning │
|
|
// Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐
|
|
// ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │
|
|
// │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘
|
|
// Pending read │ │ │ │ │ │
|
|
// ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │
|
|
// │ │ │ ┌──────┐ │ Pending │ │
|
|
// │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │
|
|
// └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │
|
|
// └──┬─▲────┘ └─────┬──┬┘ │ │
|
|
// │ │ │ │ `decode` returns Err │ │
|
|
// │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │
|
|
// │ read returns Err │
|
|
// └────────────────────────────────────────────────────────────────────────────────────────────┘
|
|
loop {
|
|
// Return `None` if we have encountered an error from the underlying decoder
|
|
// See: https://github.com/tokio-rs/tokio/issues/3976
|
|
if state.has_errored {
|
|
// preparing has_errored -> paused
|
|
trace!("Returning None and setting paused");
|
|
state.is_readable = false;
|
|
state.has_errored = false;
|
|
return Poll::Ready(None);
|
|
}
|
|
|
|
// Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
|
|
// i.e. it _might_ contain data consumable as a frame or closing frame.
|
|
// Both signal that there is no such data by returning `None`.
|
|
//
|
|
// If `decode` couldn't read a frame and the upstream source has returned eof,
|
|
// `decode_eof` will attempt to decode the remaining bytes as closing frames.
|
|
//
|
|
// If the underlying AsyncRead is resumable, we may continue after an EOF,
|
|
// but must finish emitting all of it's associated `decode_eof` frames.
|
|
// Furthermore, we don't want to emit any `decode_eof` frames on retried
|
|
// reads after an EOF unless we've actually read more data.
|
|
if state.is_readable {
|
|
// pausing or framing
|
|
if state.eof {
|
|
// pausing
|
|
let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
|
|
trace!("Got an error, going to errored state");
|
|
state.has_errored = true;
|
|
err
|
|
})?;
|
|
if frame.is_none() {
|
|
state.is_readable = false; // prepare pausing -> paused
|
|
}
|
|
// implicit pausing -> pausing or pausing -> paused
|
|
return Poll::Ready(frame.map(Ok));
|
|
}
|
|
|
|
// framing
|
|
trace!("attempting to decode a frame");
|
|
|
|
if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
|
|
trace!("Got an error, going to errored state");
|
|
state.has_errored = true;
|
|
op
|
|
})? {
|
|
trace!("frame decoded from buffer");
|
|
// implicit framing -> framing
|
|
return Poll::Ready(Some(Ok(frame)));
|
|
}
|
|
|
|
// framing -> reading
|
|
state.is_readable = false;
|
|
}
|
|
// reading or paused
|
|
// If we can't build a frame yet, try to read more data and try again.
|
|
// Make sure we've got room for at least one byte to read to ensure
|
|
// that we don't get a spurious 0 that looks like EOF.
|
|
state.buffer.reserve(1);
|
|
#[allow(clippy::blocks_in_conditions)]
|
|
let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
|
|
|err| {
|
|
trace!("Got an error, going to errored state");
|
|
state.has_errored = true;
|
|
err
|
|
},
|
|
)? {
|
|
Poll::Ready(ct) => ct,
|
|
// implicit reading -> reading or implicit paused -> paused
|
|
Poll::Pending => return Poll::Pending,
|
|
};
|
|
if bytect == 0 {
|
|
if state.eof {
|
|
// We're already at an EOF, and since we've reached this path
|
|
// we're also not readable. This implies that we've already finished
|
|
// our `decode_eof` handling, so we can simply return `None`.
|
|
// implicit paused -> paused
|
|
return Poll::Ready(None);
|
|
}
|
|
// prepare reading -> paused
|
|
state.eof = true;
|
|
} else {
|
|
// prepare paused -> framing or noop reading -> framing
|
|
state.eof = false;
|
|
}
|
|
|
|
// paused -> framing or reading -> framing or reading -> pausing
|
|
state.is_readable = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
|
|
where
|
|
T: AsyncWrite,
|
|
U: Encoder<I>,
|
|
U::Error: From<io::Error>,
|
|
W: BorrowMut<WriteFrame>,
|
|
{
|
|
type Error = U::Error;
|
|
|
|
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
|
|
self.as_mut().poll_flush(cx)
|
|
} else {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
|
|
let pinned = self.project();
|
|
pinned
|
|
.codec
|
|
.encode(item, &mut pinned.state.borrow_mut().buffer)?;
|
|
Ok(())
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
use crate::util::poll_write_buf;
|
|
trace!("flushing framed transport");
|
|
let mut pinned = self.project();
|
|
|
|
while !pinned.state.borrow_mut().buffer.is_empty() {
|
|
let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
|
|
trace!(remaining = buffer.len(), "writing;");
|
|
|
|
let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
|
|
|
|
if n == 0 {
|
|
return Poll::Ready(Err(io::Error::new(
|
|
io::ErrorKind::WriteZero,
|
|
"failed to \
|
|
write frame to transport",
|
|
)
|
|
.into()));
|
|
}
|
|
}
|
|
|
|
// Try flushing the underlying IO
|
|
ready!(pinned.inner.poll_flush(cx))?;
|
|
|
|
trace!("framed transport flushed");
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
ready!(self.as_mut().poll_flush(cx))?;
|
|
ready!(self.project().inner.poll_shutdown(cx))?;
|
|
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
}
|