use std::future::Future; use std::io::Cursor; use std::pin::Pin; use std::task::{Context, Poll}; use bytes::{Buf, Bytes}; use futures_channel::{mpsc, oneshot}; use futures_core::{ready, Stream}; use h2::{Reason, RecvStream, SendStream}; use pin_project_lite::pin_project; use super::ping::Recorder; use super::SendBuf; use crate::rt::{Read, ReadBufCursor, Write}; pub(super) fn pair( send_stream: SendStream>, recv_stream: RecvStream, ping: Recorder, ) -> (H2Upgraded, UpgradedSendStreamTask) { let (tx, rx) = mpsc::channel(1); let (error_tx, error_rx) = oneshot::channel(); ( H2Upgraded { send_stream: UpgradedSendStreamBridge { tx, error_rx }, recv_stream, ping, buf: Bytes::new(), }, UpgradedSendStreamTask { h2_tx: send_stream, rx, error_tx: Some(error_tx), }, ) } pub(super) struct H2Upgraded { ping: Recorder, send_stream: UpgradedSendStreamBridge, recv_stream: RecvStream, buf: Bytes, } struct UpgradedSendStreamBridge { tx: mpsc::Sender>>, error_rx: oneshot::Receiver, } pin_project! { #[must_use = "futures do nothing unless polled"] pub struct UpgradedSendStreamTask { #[pin] h2_tx: SendStream>, #[pin] rx: mpsc::Receiver>>, error_tx: Option>, } } // ===== impl UpgradedSendStreamTask ===== impl UpgradedSendStreamTask where B: Buf, { fn tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut me = self.project(); // this is a manual `select()` over 3 "futures", so we always need // to be sure they are ready and/or we are waiting notification of // one of the sides hanging up, so the task doesn't live around // longer than it's meant to. loop { // we don't have the next chunk of data yet, so just reserve 1 byte to make // sure there's some capacity available. h2 will handle the capacity management // for the actual body chunk. me.h2_tx.reserve_capacity(1); if me.h2_tx.capacity() == 0 { // poll_capacity oddly needs a loop 'capacity: loop { match me.h2_tx.poll_capacity(cx) { Poll::Ready(Some(Ok(0))) => {} Poll::Ready(Some(Ok(_))) => break, Poll::Ready(Some(Err(e))) => { return Poll::Ready(Err(crate::Error::new_body_write(e))) } Poll::Ready(None) => { // None means the stream is no longer in a // streaming state, we either finished it // somehow, or the remote reset us. return Poll::Ready(Err(crate::Error::new_body_write( "send stream capacity unexpectedly closed", ))); } Poll::Pending => break 'capacity, } } } match me.h2_tx.poll_reset(cx) { Poll::Ready(Ok(reason)) => { trace!("stream received RST_STREAM: {:?}", reason); return Poll::Ready(Err(crate::Error::new_body_write(::h2::Error::from( reason, )))); } Poll::Ready(Err(err)) => { return Poll::Ready(Err(crate::Error::new_body_write(err))) } Poll::Pending => (), } match me.rx.as_mut().poll_next(cx) { Poll::Ready(Some(cursor)) => { me.h2_tx .send_data(SendBuf::Cursor(cursor), false) .map_err(crate::Error::new_body_write)?; } Poll::Ready(None) => { me.h2_tx .send_data(SendBuf::None, true) .map_err(crate::Error::new_body_write)?; return Poll::Ready(Ok(())); } Poll::Pending => { return Poll::Pending; } } } } } impl Future for UpgradedSendStreamTask where B: Buf, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().tick(cx) { Poll::Ready(Ok(())) => Poll::Ready(()), Poll::Ready(Err(err)) => { if let Some(tx) = self.error_tx.take() { let _oh_well = tx.send(err); } Poll::Ready(()) } Poll::Pending => Poll::Pending, } } } // ===== impl H2Upgraded ===== impl Read for H2Upgraded { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, mut read_buf: ReadBufCursor<'_>, ) -> Poll> { if self.buf.is_empty() { self.buf = loop { match ready!(self.recv_stream.poll_data(cx)) { None => return Poll::Ready(Ok(())), Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => { continue } Some(Ok(buf)) => { self.ping.record_data(buf.len()); break buf; } Some(Err(e)) => { return Poll::Ready(match e.reason() { Some(Reason::NO_ERROR) | Some(Reason::CANCEL) => Ok(()), Some(Reason::STREAM_CLOSED) => { Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)) } _ => Err(h2_to_io_error(e)), }) } } }; } let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); read_buf.put_slice(&self.buf[..cnt]); self.buf.advance(cnt); let _ = self.recv_stream.flow_control().release_capacity(cnt); Poll::Ready(Ok(())) } } impl Write for H2Upgraded { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { if buf.is_empty() { return Poll::Ready(Ok(0)); } match self.send_stream.tx.poll_ready(cx) { Poll::Ready(Ok(())) => {} Poll::Ready(Err(_task_dropped)) => { // if the task dropped, check if there was an error // otherwise i guess its a broken pipe return match Pin::new(&mut self.send_stream.error_rx).poll(cx) { Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), Poll::Ready(Err(_task_dropped)) => { Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Poll::Pending => Poll::Pending, }; } Poll::Pending => return Poll::Pending, } let n = buf.len(); match self.send_stream.tx.start_send(Cursor::new(buf.into())) { Ok(()) => Poll::Ready(Ok(n)), Err(_task_dropped) => { // if the task dropped, check if there was an error // otherwise i guess its a broken pipe match Pin::new(&mut self.send_stream.error_rx).poll(cx) { Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), Poll::Ready(Err(_task_dropped)) => { Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())) } Poll::Pending => Poll::Pending, } } } } fn poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { match self.send_stream.tx.poll_ready(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), Poll::Ready(Err(_task_dropped)) => { // if the task dropped, check if there was an error // otherwise it was a clean close match Pin::new(&mut self.send_stream.error_rx).poll(cx) { Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())), Poll::Pending => Poll::Pending, } } Poll::Pending => Poll::Pending, } } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { self.send_stream.tx.close_channel(); match Pin::new(&mut self.send_stream.error_rx).poll(cx) { Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())), Poll::Pending => Poll::Pending, } } } fn io_error(e: crate::Error) -> std::io::Error { std::io::Error::new(std::io::ErrorKind::Other, e) } fn h2_to_io_error(e: h2::Error) -> std::io::Error { if e.is_io() { e.into_io().unwrap() } else { std::io::Error::new(std::io::ErrorKind::Other, e) } }