chore: checkpoint before Python removal

This commit is contained in:
2026-03-26 22:33:59 +00:00
parent 683cec9307
commit e568ddf82a
29972 changed files with 11269302 additions and 2 deletions

1327
vendor/quinn/src/connection.rs vendored Normal file

File diff suppressed because it is too large Load Diff

875
vendor/quinn/src/endpoint.rs vendored Normal file
View File

@@ -0,0 +1,875 @@
use std::{
collections::VecDeque,
fmt,
future::Future,
io,
io::IoSliceMut,
mem,
net::{SocketAddr, SocketAddrV6},
pin::Pin,
str,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use crate::runtime::default_runtime;
use crate::{
Instant,
runtime::{AsyncUdpSocket, Runtime},
udp_transmit,
};
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use proto::{
self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
EndpointEvent, ServerConfig,
};
use rustc_hash::FxHashMap;
#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"),))]
use socket2::{Domain, Protocol, Socket, Type};
use tokio::sync::{Notify, futures::Notified, mpsc};
use tracing::{Instrument, Span};
use udp::{BATCH_SIZE, RecvMeta};
use crate::{
ConnectionEvent, EndpointConfig, IO_LOOP_BOUND, RECV_TIME_BOUND, VarInt,
connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter,
};
/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Shared endpoint state; cloning an `Endpoint` clones this handle
    pub(crate) inner: EndpointRef,
    /// Configuration used by `connect()` when no explicit config is supplied
    pub(crate) default_client_config: Option<ClientConfig>,
    /// Runtime used for timekeeping, task spawning, and socket wrapping
    runtime: Arc<dyn Runtime>,
}
impl Endpoint {
    /// Helper to construct an endpoint for use with outgoing connections only
    ///
    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
    /// IPv6 address respectively from an OS-assigned port.
    ///
    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
    /// address. For example:
    ///
    /// ```
    /// quinn::Endpoint::client((std::net::Ipv6Addr::UNSPECIFIED, 0).into());
    /// ```
    ///
    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn client(addr: SocketAddr) -> io::Result<Self> {
        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
        if addr.is_ipv6() {
            // Dual-stack is best-effort: failure is logged, not fatal (see doc comment above)
            if let Err(e) = socket.set_only_v6(false) {
                tracing::debug!(%e, "unable to make socket dual-stack");
            }
        }
        socket.bind(&addr.into())?;
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            None,
            runtime.wrap_udp_socket(socket.into())?,
            runtime,
        )
    }

    /// Returns relevant stats from this Endpoint
    // `EndpointStats` is `Copy`, so this returns a snapshot taken under the state lock
    pub fn stats(&self) -> EndpointStats {
        self.inner.state.lock().unwrap().stats
    }

    /// Helper to construct an endpoint for use with both incoming and outgoing connections
    ///
    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
    /// IPv6 address on Windows will not by default be able to communicate with IPv4
    /// addresses. Portable applications should bind an address that matches the family they wish to
    /// communicate within.
    #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] // `EndpointConfig::default()` is only available with these
    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
        let socket = std::net::UdpSocket::bind(addr)?;
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            Some(config),
            runtime.wrap_udp_socket(socket)?,
            runtime,
        )
    }

    /// Construct an endpoint with arbitrary configuration and socket
    #[cfg(not(wasm_browser))]
    pub fn new(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: std::net::UdpSocket,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let socket = runtime.wrap_udp_socket(socket)?;
        Self::new_with_abstract_socket(config, server_config, socket, runtime)
    }

    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
    ///
    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
    /// ownership is needed.
    pub fn new_with_abstract_socket(
        config: EndpointConfig,
        server_config: Option<ServerConfig>,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> io::Result<Self> {
        let addr = socket.local_addr()?;
        // If the socket may fragment datagrams itself, MTU discovery must be disabled
        let allow_mtud = !socket.may_fragment();
        let rc = EndpointRef::new(
            socket,
            proto::Endpoint::new(
                Arc::new(config),
                server_config.map(Arc::new),
                allow_mtud,
                None,
            ),
            addr.is_ipv6(),
            runtime.clone(),
        );
        // Spawn the I/O driver; it terminates once all handles are dropped and
        // connections are gone (see `EndpointDriver::poll`)
        let driver = EndpointDriver(rc.clone());
        runtime.spawn(Box::pin(
            async {
                if let Err(e) = driver.await {
                    tracing::error!("I/O error: {}", e);
                }
            }
            .instrument(Span::current()),
        ));
        Ok(Self {
            inner: rc,
            default_client_config: None,
            runtime,
        })
    }

    /// Get the next incoming connection attempt from a client
    ///
    /// Yields [`Incoming`]s, or `None` if the endpoint is [`close`](Self::close)d. [`Incoming`]
    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
    /// filter connection attempts or force address validation, or converted into an intermediate
    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
    pub fn accept(&self) -> Accept<'_> {
        Accept {
            endpoint: self,
            notify: self.inner.shared.incoming.notified(),
        }
    }

    /// Set the client configuration used by `connect`
    pub fn set_default_client_config(&mut self, config: ClientConfig) {
        self.default_client_config = Some(config);
    }

    /// Connect to a remote endpoint
    ///
    /// `server_name` must be covered by the certificate presented by the server. This prevents a
    /// connection from being intercepted by an attacker with a valid certificate for some other
    /// server.
    ///
    /// May fail immediately due to configuration errors, or in the future if the connection could
    /// not be established.
    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
        let config = match &self.default_client_config {
            Some(config) => config.clone(),
            None => return Err(ConnectError::NoDefaultClientConfig),
        };
        self.connect_with(config, addr, server_name)
    }

    /// Connect to a remote endpoint using a custom configuration.
    ///
    /// See [`connect()`] for details.
    ///
    /// [`connect()`]: Endpoint::connect
    pub fn connect_with(
        &self,
        config: ClientConfig,
        addr: SocketAddr,
        server_name: &str,
    ) -> Result<Connecting, ConnectError> {
        let mut endpoint = self.inner.state.lock().unwrap();
        // Refuse new connections once the driver is gone or `close()` was called
        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
            return Err(ConnectError::EndpointStopping);
        }
        if addr.is_ipv6() && !endpoint.ipv6 {
            return Err(ConnectError::InvalidRemoteAddress(addr));
        }
        // On a dual-stack (IPv6) socket, IPv4 peers must be addressed via mapped addresses
        let addr = if endpoint.ipv6 {
            SocketAddr::V6(ensure_ipv6(addr))
        } else {
            addr
        };
        let (ch, conn) = endpoint
            .inner
            .connect(self.runtime.now(), config, addr, server_name)?;
        let socket = endpoint.socket.clone();
        endpoint.stats.outgoing_handshakes += 1;
        Ok(endpoint
            .recv_state
            .connections
            .insert(ch, conn, socket, self.runtime.clone()))
    }

    /// Switch to a new UDP socket
    ///
    /// See [`Endpoint::rebind_abstract()`] for details.
    #[cfg(not(wasm_browser))]
    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
    }

    /// Switch to a new UDP socket
    ///
    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
    /// connections and connections to servers unreachable from the new address will be lost.
    ///
    /// On error, the old UDP socket is retained.
    pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
        let addr = socket.local_addr()?;
        let mut inner = self.inner.state.lock().unwrap();
        // Keep the old socket alive until traffic arrives on the new one (see `drive_recv`)
        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
        inner.ipv6 = addr.is_ipv6();
        // Update connection socket references
        for sender in inner.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
        }
        if let Some(driver) = inner.driver.take() {
            // Ensure the driver can register for wake-ups from the new socket
            driver.wake();
        }
        Ok(())
    }

    /// Replace the server configuration, affecting new incoming connections only
    ///
    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
        self.inner
            .state
            .lock()
            .unwrap()
            .inner
            .set_server_config(server_config.map(Arc::new))
    }

    /// Get the local `SocketAddr` the underlying socket is bound to
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.state.lock().unwrap().socket.local_addr()
    }

    /// Get the number of connections that are currently open
    pub fn open_connections(&self) -> usize {
        self.inner.state.lock().unwrap().inner.open_connections()
    }

    /// Close all of this endpoint's connections immediately and cease accepting new connections.
    ///
    /// See [`Connection::close()`] for details.
    ///
    /// [`Connection::close()`]: crate::Connection::close
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        let mut endpoint = self.inner.state.lock().unwrap();
        // Recording the close reason also makes `connect_with`/`Accept` reject new work
        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
        for sender in endpoint.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            });
        }
        // Wake pending `accept()` calls so they observe the close and return `None`
        self.inner.shared.incoming.notify_waiters();
    }

    /// Wait for all connections on the endpoint to be cleanly shut down
    ///
    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
    /// the idle timeout period.
    ///
    /// Does not proactively close existing connections or cause incoming connections to be
    /// rejected. Consider calling [`close()`] if that is desired.
    ///
    /// [`close()`]: Endpoint::close
    pub async fn wait_idle(&self) {
        loop {
            {
                let endpoint = &mut *self.inner.state.lock().unwrap();
                if endpoint.recv_state.connections.is_empty() {
                    break;
                }
                // Construct future while lock is held to avoid race
                self.inner.shared.idle.notified()
            }
            .await;
        }
    }
}
/// Statistics on [Endpoint] activity
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes sent from this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}
/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped, or when
/// an I/O error occurs.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
// Newtype wrapper so the driver holds a handle without counting toward `ref_count`
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    /// One driver iteration: receive datagrams, relay connection events, wake acceptors.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        // Store our waker so handle operations (rebind, handle drop) can reschedule us
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        let mut keep_going = false;
        keep_going |= endpoint.drive_recv(cx, now)?;
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        // Wake any pending `accept()` calls if connection attempts are queued
        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        // Terminate only when no user handles remain and all connections have drained
        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            drop(endpoint);
            // If there is more work to do schedule the endpoint task again.
            // `wake_by_ref()` is called outside the lock to minimize
            // lock contention on a multithreaded runtime.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}
impl Drop for EndpointDriver {
    fn drop(&mut self) {
        let mut endpoint = self.0.state.lock().unwrap();
        // Mark the driver as gone so `accept()` and `connect_with()` fail fast
        endpoint.driver_lost = true;
        self.0.shared.incoming.notify_waiters();
        // Drop all outgoing channels, signaling the termination of the endpoint to the associated
        // connections.
        endpoint.recv_state.connections.senders.clear();
    }
}
/// State shared between all `Endpoint` handles and the `EndpointDriver`
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state, guarded by a single mutex
    pub(crate) state: Mutex<State>,
    /// Notification primitives usable without holding `state`
    pub(crate) shared: Shared,
}
impl EndpointInner {
    /// Begin the server side of the handshake for `incoming`, registering the connection on
    /// success and transmitting any immediate response (e.g. a refusal) on failure.
    ///
    /// `server_config` overrides the endpoint's default when provided.
    pub(crate) fn accept(
        &self,
        incoming: proto::Incoming,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Result<Connecting, ConnectionError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let now = state.runtime.now();
        match state
            .inner
            .accept(incoming, now, &mut response_buffer, server_config)
        {
            Ok((handle, conn)) => {
                state.stats.accepted_handshakes += 1;
                let socket = state.socket.clone();
                let runtime = state.runtime.clone();
                Ok(state
                    .recv_state
                    .connections
                    .insert(handle, conn, socket, runtime))
            }
            Err(error) => {
                // proto may have produced a stateless response packet for the peer
                if let Some(transmit) = error.response {
                    respond(transmit, &response_buffer, &*state.socket);
                }
                Err(error.cause)
            }
        }
    }

    /// Reject `incoming`, sending a refusal packet to the peer.
    pub(crate) fn refuse(&self, incoming: proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.refused_handshakes += 1;
        let mut response_buffer = Vec::new();
        let transmit = state.inner.refuse(incoming, &mut response_buffer);
        respond(transmit, &response_buffer, &*state.socket);
    }

    /// Send a retry packet for `incoming`, demanding address validation from the client.
    pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let transmit = state.inner.retry(incoming, &mut response_buffer)?;
        respond(transmit, &response_buffer, &*state.socket);
        Ok(())
    }

    /// Discard `incoming` without sending any packet in response.
    pub(crate) fn ignore(&self, incoming: proto::Incoming) {
        let mut state = self.state.lock().unwrap();
        state.stats.ignored_handshakes += 1;
        state.inner.ignore(incoming);
    }
}
/// Mutable endpoint state, shared behind `EndpointInner::state`
#[derive(Debug)]
pub(crate) struct State {
    /// Socket currently used for all traffic
    socket: Arc<dyn AsyncUdpSocket>,
    /// During an active migration, abandoned_socket receives traffic
    /// until the first packet arrives on the new socket.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    /// Protocol-level endpoint state machine
    inner: proto::Endpoint,
    /// Receive buffers, incoming-connection queue, and per-connection channels
    recv_state: RecvState,
    /// Waker for the `EndpointDriver` task, once it has been polled
    driver: Option<Waker>,
    /// Whether the bound socket is IPv6 (determines address mapping for `connect`)
    ipv6: bool,
    /// Events flowing from connection tasks back to the endpoint
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    /// Set when the driver task is dropped; endpoint operations then fail fast
    driver_lost: bool,
    /// Runtime used for timekeeping and connection task spawning
    runtime: Arc<dyn Runtime>,
    /// Handshake counters reported by `Endpoint::stats`
    stats: EndpointStats,
}
/// Wakeup channels shared between endpoint handles and the driver
#[derive(Debug)]
pub(crate) struct Shared {
    /// Signaled when an incoming connection attempt is queued, and on close/driver loss
    incoming: Notify,
    /// Signaled when the last connection drains, for `Endpoint::wait_idle`
    idle: Notify,
}
impl State {
    /// Poll the current socket (and, during migration, the previous one) for datagrams
    ///
    /// Returns whether more receive work may remain, so the driver can reschedule itself.
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // We don't care about the `PollProgress` from old sockets.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                // An I/O error on the abandoned socket just retires it early
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
            // one anymore. TODO: Account for multiple outgoing connections.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    /// Relay events from connection tasks into the proto endpoint, bounded per poll
    ///
    /// Returns `true` when the `IO_LOOP_BOUND` budget was exhausted (more events may remain)
    /// and `false` once the queue is drained.
    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // Ignoring errors from dropped connections that haven't yet been cleaned up
            // NOTE(review): this unwrap assumes proto never returns an event for a channel
            // whose sender was just removed by the drained branch above — confirm upstream.
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}
impl Drop for State {
    fn drop(&mut self) {
        // Hand every connection attempt that was never surfaced to the application
        // back to the proto endpoint, so it can discard the associated state.
        while let Some(attempt) = self.recv_state.incoming.pop_front() {
            self.inner.ignore(attempt);
        }
    }
}
/// Best-effort transmission of an endpoint-generated, stateless response packet.
///
/// Sends only if there's kernel buffer space; otherwise the packet is dropped.
///
/// Such a packet is always an immediate, stateless reply to an unconnected peer —
/// one of:
///
/// - A version negotiation response due to an unknown version
/// - A `CLOSE` due to a malformed or unwanted connection attempt
/// - A stateless reset due to an unrecognized connection
/// - A `Retry` packet due to a connection attempt when `use_retry` is set
///
/// A well-behaved peer can be trusted to retry a few times, and every retry is
/// guaranteed to produce the same response from us. At worst, repeated send
/// failures make the peer's new connection attempt time out, which is acceptable
/// if we're under such heavy load that this code never finds room to transmit.
/// That is morally equivalent to the packet being lost to congestion further
/// along the link, which likewise relies on peer retries for recovery.
fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
    let payload = &response_buffer[..transmit.size];
    let _ = socket.try_send(&udp_transmit(&transmit, payload));
}
#[inline]
fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
match ecn {
udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
}
}
/// The endpoint's registry of live connections and its close state
#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed
    close: Option<(VarInt, Bytes)>,
}
impl ConnectionSet {
    /// Register a freshly created connection and build its `Connecting` future.
    ///
    /// If the endpoint has already been closed, the close is forwarded to the new
    /// connection immediately so it terminates like the rest.
    fn insert(
        &mut self,
        handle: ConnectionHandle,
        conn: proto::Connection,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> Connecting {
        let (send, recv) = mpsc::unbounded_channel();
        if let Some((error_code, ref reason)) = self.close {
            // unwrap is safe: `recv` is still alive in this scope
            send.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            })
            .unwrap();
        }
        self.senders.insert(handle, send);
        Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
    }

    /// Whether no connections remain registered
    fn is_empty(&self) -> bool {
        self.senders.is_empty()
    }
}
/// Coerce `x` into an IPv6 socket address.
///
/// IPv6 addresses pass through unchanged; IPv4 addresses are converted to their
/// IPv4-mapped IPv6 form (`::ffff:a.b.c.d`) with zero flowinfo and scope id.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    match x {
        SocketAddr::V6(addr) => addr,
        SocketAddr::V4(addr) => {
            let mapped = addr.ip().to_ipv6_mapped();
            SocketAddrV6::new(mapped, addr.port(), 0, 0)
        }
    }
}
pin_project! {
    /// Future produced by [`Endpoint::accept`]
    pub struct Accept<'a> {
        // Endpoint whose queue of incoming connection attempts is polled
        endpoint: &'a Endpoint,
        // Wakeup future; replaced with a fresh one after each spurious wakeup
        #[pin]
        notify: Notified<'a>,
    }
}
impl Future for Accept<'_> {
type Output = Option<Incoming>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
let mut endpoint = this.endpoint.inner.state.lock().unwrap();
if endpoint.driver_lost {
return Poll::Ready(None);
}
if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
// Release the mutex lock on endpoint so cloning it doesn't deadlock
drop(endpoint);
let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
return Poll::Ready(Some(incoming));
}
if endpoint.recv_state.connections.close.is_some() {
return Poll::Ready(None);
}
loop {
match this.notify.as_mut().poll(ctx) {
// `state` lock ensures we didn't race with readiness
Poll::Pending => return Poll::Pending,
// Spurious wakeup, get a new future
Poll::Ready(()) => this
.notify
.set(this.endpoint.inner.shared.incoming.notified()),
}
}
}
}
/// Handle to `EndpointInner` that tracks a user-handle count separately from the `Arc`,
/// so the driver (which holds one `Arc` but no count) can shut down once all users are gone
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
impl EndpointRef {
    /// Build the shared endpoint state around `socket` and the proto-level `inner`.
    pub(crate) fn new(
        socket: Arc<dyn AsyncUdpSocket>,
        inner: proto::Endpoint,
        ipv6: bool,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        // Channel carrying events from connection tasks back to the endpoint driver
        let (sender, events) = mpsc::unbounded_channel();
        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
        Self(Arc::new(EndpointInner {
            shared: Shared {
                incoming: Notify::new(),
                idle: Notify::new(),
            },
            state: Mutex::new(State {
                socket,
                prev_socket: None,
                inner,
                ipv6,
                events,
                driver: None,
                // Starts at 0: this initial ref belongs to the driver, which is not counted
                ref_count: 0,
                driver_lost: false,
                recv_state,
                runtime,
                stats: EndpointStats::default(),
            }),
        }))
    }
}
impl Clone for EndpointRef {
fn clone(&self) -> Self {
self.0.state.lock().unwrap().ref_count += 1;
Self(self.0.clone())
}
}
impl Drop for EndpointRef {
    fn drop(&mut self) {
        let endpoint = &mut *self.0.state.lock().unwrap();
        // checked_sub guards the driver's own uncounted handle (ref_count may be 0)
        if let Some(x) = endpoint.ref_count.checked_sub(1) {
            endpoint.ref_count = x;
            if x == 0 {
                // If the driver is about to be on its own, ensure it can shut down if the last
                // connection is gone.
                if let Some(task) = endpoint.driver.take() {
                    task.wake();
                }
            }
        }
    }
}
// Allow `EndpointRef` to be used wherever `&EndpointInner` is needed
impl std::ops::Deref for EndpointRef {
    type Target = EndpointInner;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// State directly involved in handling incoming packets
struct RecvState {
    /// Connection attempts awaiting a decision via `Endpoint::accept`
    incoming: VecDeque<proto::Incoming>,
    /// Registry of live connections and the endpoint's close state
    connections: ConnectionSet,
    /// Receive buffer, chunked into `BATCH_SIZE` I/O slices per poll
    recv_buf: Box<[u8]>,
    /// Bounds time spent receiving per driver cycle, for fairness with other work
    recv_limiter: WorkLimiter,
}
impl RecvState {
    /// Build receive-side state, sizing the buffer for a full batch of maximal
    /// (possibly GRO-coalesced) UDP payloads.
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &proto::Endpoint,
    ) -> Self {
        // Payload size is capped at 64 KiB (the UDP maximum) per segment
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    /// Receive batches of datagrams from `socket` and route each one through the proto
    /// endpoint: queueing new connection attempts, forwarding packets to existing
    /// connections, or emitting immediate stateless responses.
    ///
    /// Returns when the socket has no more data, the work limiter's time budget is
    /// spent, or an I/O error (other than ECONNRESET) occurs.
    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut proto::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
            // and iovs will be of size BATCH_SIZE, thus from_fn is called
            // exactly BATCH_SIZE times.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        // A GRO-coalesced receive holds several datagrams of `stride` bytes;
                        // split and handle them individually
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        // Endpoint is closing: refuse instead of queueing
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
                                    received_connection_packet = true;
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
                // attacker
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}
impl fmt::Debug for RecvState {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut dbg = f.debug_struct("RecvState");
        dbg.field("incoming", &self.incoming);
        dbg.field("connections", &self.connections);
        // `recv_buf` is deliberately omitted: it is far too large to print usefully
        dbg.field("recv_limiter", &self.recv_limiter);
        dbg.finish_non_exhaustive()
    }
}
/// Outcome of one `RecvState::poll_socket` call
#[derive(Default)]
struct PollProgress {
    /// Whether a datagram was routed to an existing connection
    received_connection_packet: bool,
    /// Whether datagram handling was interrupted early by the work limiter for fairness
    keep_going: bool,
}

152
vendor/quinn/src/incoming.rs vendored Normal file
View File

@@ -0,0 +1,152 @@
use std::{
future::{Future, IntoFuture},
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use proto::{ConnectionError, ConnectionId, ServerConfig};
use thiserror::Error;
use crate::{
connection::{Connecting, Connection},
endpoint::EndpointRef,
};
/// An incoming connection for which the server has not yet begun its part of the handshake
// The `Option` is `Some` until a decision (accept/refuse/retry/ignore) consumes the state;
// `Drop` refuses the attempt if no decision was ever made
#[derive(Debug)]
pub struct Incoming(Option<State>);
impl Incoming {
    pub(crate) fn new(inner: proto::Incoming, endpoint: EndpointRef) -> Self {
        Self(Some(State { inner, endpoint }))
    }

    /// Attempt to accept this incoming connection (an error may still occur)
    // `take()` disarms the implicit refusal in `Drop`; unwrap is safe because every
    // consuming method takes `self` by value, so the state can only be taken once
    pub fn accept(mut self) -> Result<Connecting, ConnectionError> {
        let state = self.0.take().unwrap();
        state.endpoint.accept(state.inner, None)
    }

    /// Accept this incoming connection using a custom configuration
    ///
    /// See [`accept()`][Incoming::accept] for more details.
    pub fn accept_with(
        mut self,
        server_config: Arc<ServerConfig>,
    ) -> Result<Connecting, ConnectionError> {
        let state = self.0.take().unwrap();
        state.endpoint.accept(state.inner, Some(server_config))
    }

    /// Reject this incoming connection attempt
    pub fn refuse(mut self) {
        let state = self.0.take().unwrap();
        state.endpoint.refuse(state.inner);
    }

    /// Respond with a retry packet, requiring the client to retry with address validation
    ///
    /// Errors if `may_retry()` is false.
    pub fn retry(mut self) -> Result<(), RetryError> {
        let state = self.0.take().unwrap();
        state.endpoint.retry(state.inner).map_err(|e| {
            // On failure, rebuild the `Incoming` so the caller can still decide its fate
            RetryError(Box::new(Self(Some(State {
                inner: e.into_incoming(),
                endpoint: state.endpoint,
            }))))
        })
    }

    /// Ignore this incoming connection attempt, not sending any packet in response
    pub fn ignore(mut self) {
        let state = self.0.take().unwrap();
        state.endpoint.ignore(state.inner);
    }

    /// The local IP address which was used when the peer established the connection
    pub fn local_ip(&self) -> Option<IpAddr> {
        self.0.as_ref().unwrap().inner.local_ip()
    }

    /// The peer's UDP address
    pub fn remote_address(&self) -> SocketAddr {
        self.0.as_ref().unwrap().inner.remote_address()
    }

    /// Whether the socket address that is initiating this connection has been validated
    ///
    /// This means that the sender of the initial packet has proved that they can receive traffic
    /// sent to `self.remote_address()`.
    ///
    /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true.
    /// The inverse is not guaranteed.
    pub fn remote_address_validated(&self) -> bool {
        self.0.as_ref().unwrap().inner.remote_address_validated()
    }

    /// Whether it is legal to respond with a retry packet
    ///
    /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true.
    /// The inverse is not guaranteed.
    pub fn may_retry(&self) -> bool {
        self.0.as_ref().unwrap().inner.may_retry()
    }

    /// The original destination CID when initiating the connection
    pub fn orig_dst_cid(&self) -> ConnectionId {
        *self.0.as_ref().unwrap().inner.orig_dst_cid()
    }
}
impl Drop for Incoming {
    fn drop(&mut self) {
        // Implicit reject, similar to Connection's implicit close
        // `self.0` is `None` here whenever an explicit decision already consumed the state
        if let Some(state) = self.0.take() {
            state.endpoint.refuse(state.inner);
        }
    }
}
/// The live contents of an undecided `Incoming`
#[derive(Debug)]
struct State {
    /// Protocol-level state for this connection attempt
    inner: proto::Incoming,
    /// Endpoint the attempt arrived on; used to carry out the eventual decision
    endpoint: EndpointRef,
}
/// Error for attempting to retry an [`Incoming`] which already bears a token from a previous retry
// Boxed to keep the error small; the original `Incoming` is recoverable via `into_incoming`
#[derive(Debug, Error)]
#[error("retry() with validated Incoming")]
pub struct RetryError(Box<Incoming>);
impl RetryError {
    /// Get the [`Incoming`]
    pub fn into_incoming(self) -> Incoming {
        *self.0
    }
}
/// Basic adapter to let [`Incoming`] be `await`-ed like a [`Connecting`]
// Holds the result of `Incoming::accept()`: either the in-progress handshake
// or the error it produced immediately
#[derive(Debug)]
pub struct IncomingFuture(Result<Connecting, ConnectionError>);
impl Future for IncomingFuture {
    type Output = Result<Connection, ConnectionError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match &mut self.0 {
            // Delegate to the in-progress handshake
            Ok(ref mut connecting) => Pin::new(connecting).poll(cx),
            // accept() failed immediately; yield the error on every poll
            Err(e) => Poll::Ready(Err(e.clone())),
        }
    }
}
// `await`ing an `Incoming` directly is shorthand for `accept()` + `await`
impl IntoFuture for Incoming {
    type Output = Result<Connection, ConnectionError>;
    type IntoFuture = IncomingFuture;

    fn into_future(self) -> Self::IntoFuture {
        IncomingFuture(self.accept())
    }
}

136
vendor/quinn/src/lib.rs vendored Normal file
View File

@@ -0,0 +1,136 @@
//! QUIC transport protocol implementation
//!
//! [QUIC](https://en.wikipedia.org/wiki/QUIC) is a modern transport protocol addressing
//! shortcomings of TCP, such as head-of-line blocking, poor security, slow handshakes, and
//! inefficient congestion control. This crate provides a portable userspace implementation. It
//! builds on top of quinn-proto, which implements protocol logic independent of any particular
//! runtime.
//!
//! The entry point of this crate is the [`Endpoint`].
//!
//! # About QUIC
//!
//! A QUIC connection is an association between two endpoints. The endpoint which initiates the
//! connection is termed the client, and the endpoint which accepts it is termed the server. A
//! single endpoint may function as both client and server for different connections, for example
//! in a peer-to-peer application. To communicate application data, each endpoint may open streams
//! up to a limit dictated by its peer. Typically, that limit is increased as old streams are
//! finished.
//!
//! Streams may be unidirectional or bidirectional, and are cheap to create and disposable. For
//! example, a traditionally datagram-oriented application could use a new stream for every
//! message it wants to send, no longer needing to worry about MTUs. Bidirectional streams behave
//! much like a traditional TCP connection, and are useful for sending messages that have an
//! immediate response, such as an HTTP request. Stream data is delivered reliably, and there is no
//! ordering enforced between data on different streams.
//!
//! By avoiding head-of-line blocking and providing unified congestion control across all streams
//! of a connection, QUIC is able to provide higher throughput and lower latency than one or
//! multiple TCP connections between the same two hosts, while providing more useful behavior than
//! raw UDP sockets.
//!
//! Quinn also exposes unreliable datagrams, which are a low-level primitive preferred when
//! automatic fragmentation and retransmission of certain data is not desired.
//!
//! QUIC uses encryption and identity verification built directly on TLS 1.3. Just as with a TLS
//! server, it is useful for a QUIC server to be identified by a certificate signed by a trusted
//! authority. If this is infeasible--for example, if servers are short-lived or not associated
//! with a domain name--then as with TLS, self-signed certificates can be used to provide
//! encryption alone.
#![warn(missing_docs)]
#![warn(unreachable_pub)]
#![warn(clippy::use_self)]
use std::sync::Arc;
mod connection;
mod endpoint;
mod incoming;
mod mutex;
mod recv_stream;
mod runtime;
mod send_stream;
mod work_limiter;
#[cfg(not(wasm_browser))]
pub(crate) use std::time::{Duration, Instant};
#[cfg(wasm_browser)]
pub(crate) use web_time::{Duration, Instant};
#[cfg(feature = "bloom")]
pub use proto::BloomTokenLog;
pub use proto::{
AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ClosedStream, ConfigError,
ConnectError, ConnectionClose, ConnectionError, ConnectionId, ConnectionIdGenerator,
ConnectionStats, Dir, EcnCodepoint, EndpointConfig, FrameStats, FrameType, IdleTimeout,
MtuDiscoveryConfig, NoneTokenLog, NoneTokenStore, PathStats, ServerConfig, Side, StdSystemTime,
StreamId, TimeSource, TokenLog, TokenMemoryCache, TokenReuseError, TokenStore, Transmit,
TransportConfig, TransportErrorCode, UdpStats, ValidationTokenConfig, VarInt,
VarIntBoundsExceeded, Written, congestion, crypto,
};
#[cfg(feature = "qlog")]
pub use proto::{QlogConfig, QlogStream};
#[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
pub use rustls;
pub use udp;
pub use crate::connection::{
AcceptBi, AcceptUni, Connecting, Connection, OpenBi, OpenUni, ReadDatagram, SendDatagram,
SendDatagramError, ZeroRttAccepted,
};
pub use crate::endpoint::{Accept, Endpoint, EndpointStats};
pub use crate::incoming::{Incoming, IncomingFuture, RetryError};
pub use crate::recv_stream::{ReadError, ReadExactError, ReadToEndError, RecvStream, ResetError};
#[cfg(feature = "runtime-async-std")]
pub use crate::runtime::AsyncStdRuntime;
#[cfg(feature = "runtime-smol")]
pub use crate::runtime::SmolRuntime;
#[cfg(feature = "runtime-tokio")]
pub use crate::runtime::TokioRuntime;
pub use crate::runtime::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller, default_runtime};
pub use crate::send_stream::{SendStream, StoppedError, WriteError};
#[cfg(test)]
mod tests;
/// Messages delivered to a connection from elsewhere in the endpoint
///
/// NOTE(review): the consumer lives in `connection.rs`, outside this chunk.
#[derive(Debug)]
enum ConnectionEvent {
    /// Request that the connection be closed with an application error code and reason
    Close {
        error_code: VarInt,
        reason: bytes::Bytes,
    },
    /// An event produced by the protocol state machine in `quinn-proto`
    Proto(proto::ConnectionEvent),
    /// Switch the connection over to a new UDP socket
    Rebind(Arc<dyn AsyncUdpSocket>),
}
/// Pair a protocol-level [`proto::Transmit`] with its encoded `buffer` to form the
/// [`udp::Transmit`] handed to the socket layer
fn udp_transmit<'a>(t: &proto::Transmit, buffer: &'a [u8]) -> udp::Transmit<'a> {
    let proto::Transmit {
        destination,
        ecn,
        segment_size,
        src_ip,
        ..
    } = *t;
    udp::Transmit {
        destination,
        // Translate the ECN codepoint into the socket layer's equivalent type
        ecn: ecn.map(udp_ecn),
        contents: buffer,
        segment_size,
        src_ip,
    }
}
fn udp_ecn(ecn: proto::EcnCodepoint) -> udp::EcnCodepoint {
match ecn {
proto::EcnCodepoint::Ect0 => udp::EcnCodepoint::Ect0,
proto::EcnCodepoint::Ect1 => udp::EcnCodepoint::Ect1,
proto::EcnCodepoint::Ce => udp::EcnCodepoint::Ce,
}
}
/// Maximum number of datagrams processed in send/recv calls to make before moving on to other processing
///
/// This helps ensure we don't starve anything when the CPU is slower than the link.
/// Value is selected by picking a low number which didn't degrade throughput in benchmarks.
const IO_LOOP_BOUND: usize = 160;
/// The maximum amount of time that should be spent in `recvmsg()` calls per endpoint iteration
///
/// 50us are chosen so that an endpoint iteration with a 50us sendmsg limit blocks
/// the runtime for a maximum of about 100us.
/// Going much lower does not yield any noticeable difference, since a single `recvmmsg`
/// batch of size 32 was observed to take 30us on some systems.
///
/// Complements [`IO_LOOP_BOUND`], which limits work by count rather than by time.
const RECV_TIME_BOUND: Duration = Duration::from_micros(50);

163
vendor/quinn/src/mutex.rs vendored Normal file
View File

@@ -0,0 +1,163 @@
use std::{
fmt::Debug,
ops::{Deref, DerefMut},
};
#[cfg(feature = "lock_tracking")]
mod tracking {
    use super::*;
    use crate::{Duration, Instant};
    use std::collections::VecDeque;
    use tracing::warn;
    /// Lock-protected value plus a bounded history of recent lock holders
    #[derive(Debug)]
    struct Inner<T> {
        // Most recent holders first; length capped at `MAX_LOCK_OWNERS`
        last_lock_owner: VecDeque<(&'static str, Duration)>,
        value: T,
    }
    /// A Mutex which optionally allows to track the time a lock was held and
    /// emit warnings in case of excessive lock times
    pub(crate) struct Mutex<T> {
        inner: std::sync::Mutex<Inner<T>>,
    }
    impl<T: Debug> std::fmt::Debug for Mutex<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            std::fmt::Debug::fmt(&self.inner, f)
        }
    }
    impl<T> Mutex<T> {
        /// Wrap `value` in a tracked mutex with an empty holder history
        pub(crate) fn new(value: T) -> Self {
            Self {
                inner: std::sync::Mutex::new(Inner {
                    last_lock_owner: VecDeque::new(),
                    value,
                }),
            }
        }
        /// Acquires the lock for a certain purpose
        ///
        /// The purpose will be recorded in the list of last lock owners
        pub(crate) fn lock(&self, purpose: &'static str) -> MutexGuard<'_, T> {
            // We don't bother dispatching through Runtime::now because they're pure performance
            // diagnostics.
            let now = Instant::now();
            let guard = self.inner.lock().unwrap();
            let lock_time = Instant::now();
            let elapsed = lock_time.duration_since(now);
            // Warn when merely *acquiring* the lock took over 1ms; the recorded
            // history shows who was holding it in the meantime
            if elapsed > Duration::from_millis(1) {
                warn!(
                    "Locking the connection for {} took {:?}. Last owners: {:?}",
                    purpose, elapsed, guard.last_lock_owner
                );
            }
            MutexGuard {
                guard,
                start_time: lock_time,
                purpose,
            }
        }
    }
    /// Guard that records how long the lock was held when dropped
    pub(crate) struct MutexGuard<'a, T> {
        guard: std::sync::MutexGuard<'a, Inner<T>>,
        // When the lock was acquired, for measuring hold duration on drop
        start_time: Instant,
        purpose: &'static str,
    }
    impl<T> Drop for MutexGuard<'_, T> {
        fn drop(&mut self) {
            // Evict the oldest record so the history stays bounded
            if self.guard.last_lock_owner.len() == MAX_LOCK_OWNERS {
                self.guard.last_lock_owner.pop_back();
            }
            let duration = self.start_time.elapsed();
            // Warn when the lock was *held* for over 1ms (distinct from the
            // acquisition warning in `lock`)
            if duration > Duration::from_millis(1) {
                warn!(
                    "Utilizing the connection for {} took {:?}",
                    self.purpose, duration
                );
            }
            self.guard
                .last_lock_owner
                .push_front((self.purpose, duration));
        }
    }
    impl<T> Deref for MutexGuard<'_, T> {
        type Target = T;
        fn deref(&self) -> &Self::Target {
            &self.guard.value
        }
    }
    impl<T> DerefMut for MutexGuard<'_, T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.guard.value
        }
    }
    /// Number of most-recent lock holders remembered for diagnostics
    const MAX_LOCK_OWNERS: usize = 20;
}
#[cfg(feature = "lock_tracking")]
pub(crate) use tracking::Mutex;
#[cfg(not(feature = "lock_tracking"))]
mod non_tracking {
    use super::*;
    /// Thin wrapper over `std::sync::Mutex` with the same interface as the
    /// tracking variant; the `purpose` argument is ignored in this build
    #[derive(Debug)]
    pub(crate) struct Mutex<T> {
        inner: std::sync::Mutex<T>,
    }
    impl<T> Mutex<T> {
        /// Wrap `value` in a mutex
        pub(crate) fn new(value: T) -> Self {
            let inner = std::sync::Mutex::new(value);
            Self { inner }
        }
        /// Acquires the lock for a certain purpose
        ///
        /// Without the `lock_tracking` feature no history is kept; the purpose
        /// exists only so call sites are identical in both builds.
        pub(crate) fn lock(&self, _purpose: &'static str) -> MutexGuard<'_, T> {
            let guard = self.inner.lock().unwrap();
            MutexGuard { guard }
        }
    }
    /// Guard wrapping `std::sync::MutexGuard`; releases the lock on drop
    pub(crate) struct MutexGuard<'a, T> {
        guard: std::sync::MutexGuard<'a, T>,
    }
    impl<T> Deref for MutexGuard<'_, T> {
        type Target = T;
        fn deref(&self) -> &Self::Target {
            &*self.guard
        }
    }
    impl<T> DerefMut for MutexGuard<'_, T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut *self.guard
        }
    }
}
#[cfg(not(feature = "lock_tracking"))]
pub(crate) use non_tracking::Mutex;

699
vendor/quinn/src/recv_stream.rs vendored Normal file
View File

@@ -0,0 +1,699 @@
use std::{
future::{Future, poll_fn},
io,
pin::Pin,
task::{Context, Poll, ready},
};
use bytes::Bytes;
use proto::{Chunk, Chunks, ClosedStream, ConnectionError, ReadableError, StreamId};
use thiserror::Error;
use tokio::io::ReadBuf;
use crate::{VarInt, connection::ConnectionRef};
/// A stream that can only be used to receive data
///
/// `stop(0)` is implicitly called on drop unless:
/// - A variant of [`ReadError`] has been yielded by a read call
/// - [`stop()`] was called explicitly
///
/// # Cancellation
///
/// A `read` method is said to be *cancel-safe* when dropping its future before the future becomes
/// ready cannot lead to loss of stream data. This is true of methods which succeed immediately when
/// any progress is made, and is not true of methods which might need to perform multiple reads
/// internally before succeeding. Each `read` method documents whether it is cancel-safe.
///
/// # Common issues
///
/// ## Data never received on a locally-opened stream
///
/// Peers are not notified of streams until they or a later-numbered stream are used to send
/// data. If a bidirectional stream is locally opened but never used to send, then the peer may
/// never see it. Application protocols should always arrange for the endpoint which will first
/// transmit on a stream to be the endpoint responsible for opening it.
///
/// ## Data never received on a remotely-opened stream
///
/// Verify that the stream you are receiving is the same one that the server is sending on, e.g. by
/// logging the [`id`] of each. Streams are always accepted in the same order as they are created,
/// i.e. ascending order by [`StreamId`]. For example, even if a sender first transmits on
/// bidirectional stream 1, the first stream yielded by [`Connection::accept_bi`] on the receiver
/// will be bidirectional stream 0.
///
/// [`ReadError`]: crate::ReadError
/// [`stop()`]: RecvStream::stop
/// [`SendStream::finish`]: crate::SendStream::finish
/// [`WriteError::Stopped`]: crate::WriteError::Stopped
/// [`id`]: RecvStream::id
/// [`Connection::accept_bi`]: crate::Connection::accept_bi
#[derive(Debug)]
pub struct RecvStream {
    // Shared handle to the connection state this stream belongs to
    conn: ConnectionRef,
    stream: StreamId,
    // Whether the stream was opened during 0-RTT; such reads may be replayed
    is_0rtt: bool,
    // Set once EOF or a reset has been observed; subsequent reads short-circuit
    all_data_read: bool,
    // Error code cached after a reset has been observed
    reset: Option<VarInt>,
}
impl RecvStream {
    /// Create a new `RecvStream` wrapper for `stream` on `conn`
    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
        Self {
            conn,
            stream,
            is_0rtt,
            all_data_read: false,
            reset: None,
        }
    }
    /// Read data contiguously from the stream.
    ///
    /// Yields the number of bytes read into `buf` on success, or `None` if the stream was finished.
    ///
    /// This operation is cancel-safe.
    pub async fn read(&mut self, buf: &mut [u8]) -> Result<Option<usize>, ReadError> {
        Read {
            stream: self,
            buf: ReadBuf::new(buf),
        }
        .await
    }
    /// Read an exact number of bytes contiguously from the stream.
    ///
    /// See [`read()`] for details. This operation is *not* cancel-safe.
    ///
    /// [`read()`]: RecvStream::read
    pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), ReadExactError> {
        ReadExact {
            stream: self,
            buf: ReadBuf::new(buf),
        }
        .await
    }
    /// Attempts to read from the stream into the provided buffer
    ///
    /// On success, returns `Poll::Ready(Ok(num_bytes_read))` and places data into `buf`. If this
    /// returns zero bytes read (and `buf` has a non-zero length), that indicates that the remote
    /// side has [`finish`]ed the stream and the local side has already read all bytes.
    ///
    /// If no data is available for reading, this returns `Poll::Pending` and arranges for the
    /// current task (via `cx.waker()`) to be notified when the stream becomes readable or is
    /// closed.
    ///
    /// [`finish`]: crate::SendStream::finish
    pub fn poll_read(
        &mut self,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, ReadError>> {
        // Delegate to the `ReadBuf` variant and report how much was filled
        let mut buf = ReadBuf::new(buf);
        ready!(self.poll_read_buf(cx, &mut buf))?;
        Poll::Ready(Ok(buf.filled().len()))
    }
    /// Attempts to read from the stream into the provided buffer, which may be uninitialized
    ///
    /// On success, returns `Poll::Ready(Ok(()))` and places data into the unfilled portion of
    /// `buf`. If this does not write any bytes to `buf` (and `buf.remaining()` is non-zero), that
    /// indicates that the remote side has [`finish`]ed the stream and the local side has already
    /// read all bytes.
    ///
    /// If no data is available for reading, this returns `Poll::Pending` and arranges for the
    /// current task (via `cx.waker()`) to be notified when the stream becomes readable or is
    /// closed.
    ///
    /// [`finish`]: crate::SendStream::finish
    pub fn poll_read_buf(
        &mut self,
        cx: &mut Context,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<(), ReadError>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        self.poll_read_generic(cx, true, |chunks| {
            let mut read = false;
            loop {
                if buf.remaining() == 0 {
                    // We know `read` is `true` because `buf.remaining()` was not 0 before
                    return ReadStatus::Readable(());
                }
                match chunks.next(buf.remaining()) {
                    Ok(Some(chunk)) => {
                        buf.put_slice(&chunk.bytes);
                        read = true;
                    }
                    res => return (if read { Some(()) } else { None }, res.err()).into(),
                }
            }
        })
        .map(|res| res.map(|_| ()))
    }
    /// Read the next segment of data
    ///
    /// Yields `None` if the stream was finished. Otherwise, yields a segment of data and its
    /// offset in the stream. If `ordered` is `true`, the chunk's offset will be immediately after
    /// the last data yielded by `read()` or `read_chunk()`. If `ordered` is `false`, segments may
    /// be received in any order, and the `Chunk`'s `offset` field can be used to determine
    /// ordering in the caller. Unordered reads are less prone to head-of-line blocking within a
    /// stream, but require the application to manage reassembling the original data.
    ///
    /// Slightly more efficient than `read` due to not copying. Chunk boundaries do not correspond
    /// to peer writes, and hence cannot be used as framing.
    ///
    /// This operation is cancel-safe.
    pub async fn read_chunk(
        &mut self,
        max_length: usize,
        ordered: bool,
    ) -> Result<Option<Chunk>, ReadError> {
        ReadChunk {
            stream: self,
            max_length,
            ordered,
        }
        .await
    }
    /// Attempts to read a chunk from the stream.
    ///
    /// On success, returns `Poll::Ready(Ok(Some(chunk)))`. If `Poll::Ready(Ok(None))`
    /// is returned, it implies that EOF has been reached.
    ///
    /// If no data is available for reading, the method returns `Poll::Pending`
    /// and arranges for the current task (via cx.waker()) to receive a notification
    /// when the stream becomes readable or is closed.
    fn poll_read_chunk(
        &mut self,
        cx: &mut Context,
        max_length: usize,
        ordered: bool,
    ) -> Poll<Result<Option<Chunk>, ReadError>> {
        self.poll_read_generic(cx, ordered, |chunks| match chunks.next(max_length) {
            Ok(Some(chunk)) => ReadStatus::Readable(chunk),
            res => (None, res.err()).into(),
        })
    }
    /// Read the next segments of data
    ///
    /// Fills `bufs` with the segments of data beginning immediately after the
    /// last data yielded by `read` or `read_chunk`, or `None` if the stream was
    /// finished.
    ///
    /// Slightly more efficient than `read` due to not copying. Chunk boundaries
    /// do not correspond to peer writes, and hence cannot be used as framing.
    ///
    /// This operation is cancel-safe.
    pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
        ReadChunks { stream: self, bufs }.await
    }
    /// Foundation of [`Self::read_chunks`]
    fn poll_read_chunks(
        &mut self,
        cx: &mut Context,
        bufs: &mut [Bytes],
    ) -> Poll<Result<Option<usize>, ReadError>> {
        if bufs.is_empty() {
            return Poll::Ready(Ok(Some(0)));
        }
        self.poll_read_generic(cx, true, |chunks| {
            let mut read = 0;
            loop {
                if read >= bufs.len() {
                    // We know `read > 0` because `bufs` cannot be empty here
                    return ReadStatus::Readable(read);
                }
                match chunks.next(usize::MAX) {
                    Ok(Some(chunk)) => {
                        bufs[read] = chunk.bytes;
                        read += 1;
                    }
                    res => return (if read == 0 { None } else { Some(read) }, res.err()).into(),
                }
            }
        })
    }
    /// Convenience method to read all remaining data into a buffer
    ///
    /// Fails with [`ReadToEndError::TooLong`] on reading more than `size_limit` bytes, discarding
    /// all data read. Uses unordered reads to be more efficient than using `AsyncRead` would
    /// allow. `size_limit` should be set to limit worst-case memory use.
    ///
    /// If unordered reads have already been made, the resulting buffer may have gaps containing
    /// arbitrary data.
    ///
    /// This operation is *not* cancel-safe.
    ///
    /// [`ReadToEndError::TooLong`]: crate::ReadToEndError::TooLong
    pub async fn read_to_end(&mut self, size_limit: usize) -> Result<Vec<u8>, ReadToEndError> {
        ReadToEnd {
            stream: self,
            size_limit,
            read: Vec::new(),
            // `start` begins at MAX so the first chunk's offset always lowers it
            start: u64::MAX,
            end: 0,
        }
        .await
    }
    /// Stop accepting data
    ///
    /// Discards unread data and notifies the peer to stop transmitting. Once stopped, further
    /// attempts to operate on a stream will yield `ClosedStream` errors.
    pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        let mut conn = self.conn.state.lock("RecvStream::stop");
        // After 0-RTT rejection there is nothing to stop; report success
        if self.is_0rtt && conn.check_0rtt().is_err() {
            return Ok(());
        }
        conn.inner.recv_stream(self.stream).stop(error_code)?;
        conn.wake();
        // Marking all data read suppresses the implicit stop(0) in `drop`
        self.all_data_read = true;
        Ok(())
    }
    /// Check if this stream has been opened during 0-RTT.
    ///
    /// In which case any non-idempotent request should be considered dangerous at the application
    /// level. Because read data is subject to replay attacks.
    pub fn is_0rtt(&self) -> bool {
        self.is_0rtt
    }
    /// Get the identity of this stream
    pub fn id(&self) -> StreamId {
        self.stream
    }
    /// Completes when the stream has been reset by the peer or otherwise closed
    ///
    /// Yields `Some` with the reset error code when the stream is reset by the peer. Yields `None`
    /// when the stream was previously [`stop()`](Self::stop)ed, or when the stream was
    /// [`finish()`](crate::SendStream::finish)ed by the peer and all data has been received, after
    /// which it is no longer meaningful for the stream to be reset.
    ///
    /// This operation is cancel-safe.
    pub async fn received_reset(&mut self) -> Result<Option<VarInt>, ResetError> {
        poll_fn(|cx| {
            let mut conn = self.conn.state.lock("RecvStream::reset");
            if self.is_0rtt && conn.check_0rtt().is_err() {
                return Poll::Ready(Err(ResetError::ZeroRttRejected));
            }
            // Return the cached code if a reset was already observed by a read
            if let Some(code) = self.reset {
                return Poll::Ready(Ok(Some(code)));
            }
            match conn.inner.recv_stream(self.stream).received_reset() {
                Err(_) => Poll::Ready(Ok(None)),
                Ok(Some(error_code)) => {
                    // Stream state has just now been freed, so the connection may need to issue new
                    // stream ID flow control credit
                    conn.wake();
                    Poll::Ready(Ok(Some(error_code)))
                }
                Ok(None) => {
                    if let Some(e) = &conn.error {
                        return Poll::Ready(Err(e.clone().into()));
                    }
                    // Resets always notify readers, since a reset is an immediate read error. We
                    // could introduce a dedicated channel to reduce the risk of spurious wakeups,
                    // but that increased complexity is probably not justified, as an application
                    // that is expecting a reset is not likely to receive large amounts of data.
                    conn.blocked_readers.insert(self.stream, cx.waker().clone());
                    Poll::Pending
                }
            }
        })
        .await
    }
    /// Handle common logic related to reading out of a receive stream
    ///
    /// This takes an `FnMut` closure that takes care of the actual reading process, matching
    /// the detailed read semantics for the calling function with a particular return type.
    /// The closure can read from the passed `&mut Chunks` and has to return the status after
    /// reading: the amount of data read, and the status after the final read call.
    fn poll_read_generic<T, U>(
        &mut self,
        cx: &mut Context,
        ordered: bool,
        mut read_fn: T,
    ) -> Poll<Result<Option<U>, ReadError>>
    where
        T: FnMut(&mut Chunks) -> ReadStatus<U>,
    {
        use proto::ReadError::*;
        // EOF or reset already delivered; avoid taking the connection lock again
        if self.all_data_read {
            return Poll::Ready(Ok(None));
        }
        let mut conn = self.conn.state.lock("RecvStream::poll_read");
        if self.is_0rtt {
            conn.check_0rtt().map_err(|()| ReadError::ZeroRttRejected)?;
        }
        // If we stored an error during a previous call, return it now. This can happen if a
        // `read_fn` both wants to return data and also returns an error in its final stream status.
        let status = match self.reset {
            Some(code) => ReadStatus::Failed(None, Reset(code)),
            None => {
                let mut recv = conn.inner.recv_stream(self.stream);
                let mut chunks = recv.read(ordered)?;
                let status = read_fn(&mut chunks);
                // `finalize` may free stream state; wake the connection if that
                // produced something to transmit
                if chunks.finalize().should_transmit() {
                    conn.wake();
                }
                status
            }
        };
        // Translate the closure's outcome into the public result, caching
        // EOF/reset state for subsequent calls
        match status {
            ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
            ReadStatus::Finished(read) => {
                self.all_data_read = true;
                Poll::Ready(Ok(read))
            }
            ReadStatus::Failed(read, Blocked) => match read {
                Some(val) => Poll::Ready(Ok(Some(val))),
                None => {
                    if let Some(ref x) = conn.error {
                        return Poll::Ready(Err(ReadError::ConnectionLost(x.clone())));
                    }
                    // Register to be woken once data arrives or the stream closes
                    conn.blocked_readers.insert(self.stream, cx.waker().clone());
                    Poll::Pending
                }
            },
            ReadStatus::Failed(read, Reset(error_code)) => match read {
                None => {
                    self.all_data_read = true;
                    self.reset = Some(error_code);
                    Poll::Ready(Err(ReadError::Reset(error_code)))
                }
                done => {
                    // Deliver the data now; the cached reset is surfaced on the next call
                    self.reset = Some(error_code);
                    Poll::Ready(Ok(done))
                }
            },
        }
    }
}
/// Outcome of a single read closure in [`RecvStream::poll_read_generic`]
enum ReadStatus<T> {
    /// Data was produced and more may remain
    Readable(T),
    /// The stream ended cleanly, possibly after yielding some final data
    Finished(Option<T>),
    /// Reading failed (blocked or reset), possibly after yielding some data
    Failed(Option<T>, proto::ReadError),
}
impl<T> From<(Option<T>, Option<proto::ReadError>)> for ReadStatus<T> {
    /// Translate the result of a final `Chunks::next` call (data read so far
    /// plus optional error) into a `ReadStatus`
    fn from((read, err): (Option<T>, Option<proto::ReadError>)) -> Self {
        match err {
            Some(e) => Self::Failed(read, e),
            None => Self::Finished(read),
        }
    }
}
/// Future produced by [`RecvStream::read_to_end()`].
///
/// [`RecvStream::read_to_end()`]: crate::RecvStream::read_to_end
struct ReadToEnd<'a> {
    stream: &'a mut RecvStream,
    // Chunks collected so far, each with its stream offset (unordered)
    read: Vec<(Bytes, u64)>,
    // Lowest offset seen; initialized to u64::MAX by `read_to_end`
    start: u64,
    // Highest end offset seen
    end: u64,
    // Maximum permitted buffer size before failing with `TooLong`
    size_limit: usize,
}
impl Future for ReadToEnd<'_> {
    type Output = Result<Vec<u8>, ReadToEndError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        loop {
            match ready!(self.stream.poll_read_chunk(cx, usize::MAX, false))? {
                Some(chunk) => {
                    // Track the overall span of data received so far; unordered
                    // chunks may arrive with arbitrary offsets.
                    self.start = self.start.min(chunk.offset);
                    let end = chunk.bytes.len() as u64 + chunk.offset;
                    self.end = self.end.max(end);
                    // Enforce `size_limit` against the full span rather than the
                    // current chunk alone: a late chunk at a low offset shrinks
                    // `start` and can retroactively grow the final buffer beyond
                    // what `end - start` for this chunk would suggest.
                    if (self.end - self.start) > self.size_limit as u64 {
                        return Poll::Ready(Err(ReadToEndError::TooLong));
                    }
                    self.read.push((chunk.bytes, chunk.offset));
                }
                None => {
                    if self.end == 0 {
                        // Never received anything
                        return Poll::Ready(Ok(Vec::new()));
                    }
                    // Reassemble: one buffer spanning [start, end), each chunk
                    // copied to its offset. Gaps left by prior unordered reads
                    // remain zero-filled.
                    let start = self.start;
                    let mut buffer = vec![0; (self.end - start) as usize];
                    for (data, offset) in self.read.drain(..) {
                        let offset = (offset - start) as usize;
                        buffer[offset..offset + data.len()].copy_from_slice(&data);
                    }
                    return Poll::Ready(Ok(buffer));
                }
            }
        }
    }
}
/// Errors from [`RecvStream::read_to_end`]
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadToEndError {
    /// An error occurred during reading
    #[error("read error: {0}")]
    Read(#[from] ReadError),
    /// The stream is larger than the user-supplied limit
    ///
    /// All data read so far is discarded when this is returned.
    #[error("stream too long")]
    TooLong,
}
#[cfg(feature = "futures-io")]
impl futures_io::AsyncRead for RecvStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // Delegate to `poll_read_buf`; per `AsyncRead`, a zero-byte read on a
        // non-empty buffer signals EOF
        let mut buf = ReadBuf::new(buf);
        ready!(Self::poll_read_buf(self.get_mut(), cx, &mut buf))?;
        Poll::Ready(Ok(buf.filled().len()))
    }
}
impl tokio::io::AsyncRead for RecvStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Delegate to `poll_read_buf`; `ReadError` converts into `io::Error`
        // via the `From` impl below
        ready!(Self::poll_read_buf(self.get_mut(), cx, buf))?;
        Poll::Ready(Ok(()))
    }
}
impl Drop for RecvStream {
    fn drop(&mut self) {
        let mut conn = self.conn.state.lock("RecvStream::drop");
        // clean up any previously registered wakers
        conn.blocked_readers.remove(&self.stream);
        // Nothing to signal if the connection already failed, or if this was a
        // rejected 0-RTT stream
        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
            return;
        }
        if !self.all_data_read {
            // Implicit `stop(0)` promised by the `RecvStream` docs
            // Ignore ClosedStream errors
            let _ = conn.inner.recv_stream(self.stream).stop(0u32.into());
            conn.wake();
        }
    }
}
/// Errors that arise from reading from a stream.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadError {
    /// The peer abandoned transmitting data on this stream
    ///
    /// Carries an application-defined error code.
    #[error("stream reset by peer: error {0}")]
    Reset(VarInt),
    /// The connection was lost
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The stream has already been stopped, finished, or reset
    #[error("closed stream")]
    ClosedStream,
    /// Attempted an ordered read following an unordered read
    ///
    /// Performing an unordered read allows discontinuities to arise in the receive buffer of a
    /// stream which cannot be recovered, making further ordered reads impossible.
    #[error("ordered read after unordered read")]
    IllegalOrderedRead,
    /// This was a 0-RTT stream and the server rejected it
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<ReadableError> for ReadError {
    // Map proto-level readability errors onto their public equivalents
    fn from(e: ReadableError) -> Self {
        match e {
            ReadableError::ClosedStream => Self::ClosedStream,
            ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
        }
    }
}
impl From<ResetError> for ReadError {
    // `ResetError` is a strict subset of `ReadError`; map variants one-to-one
    fn from(e: ResetError) -> Self {
        match e {
            ResetError::ConnectionLost(e) => Self::ConnectionLost(e),
            ResetError::ZeroRttRejected => Self::ZeroRttRejected,
        }
    }
}
impl From<ReadError> for io::Error {
fn from(x: ReadError) -> Self {
use ReadError::*;
let kind = match x {
Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
IllegalOrderedRead => io::ErrorKind::InvalidInput,
};
Self::new(kind, x)
}
}
/// Errors that arise while waiting for a stream to be reset
///
/// Returned by [`RecvStream::received_reset`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ResetError {
    /// The connection was lost
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// This was a 0-RTT stream and the server rejected it
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<ResetError> for io::Error {
fn from(x: ResetError) -> Self {
use ResetError::*;
let kind = match x {
ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) => io::ErrorKind::NotConnected,
};
Self::new(kind, x)
}
}
/// Future produced by [`RecvStream::read()`].
///
/// [`RecvStream::read()`]: crate::RecvStream::read
struct Read<'a> {
    stream: &'a mut RecvStream,
    // Destination buffer; tracks how much has been filled across polls
    buf: ReadBuf<'a>,
}
impl Future for Read<'_> {
    type Output = Result<Option<usize>, ReadError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        ready!(this.stream.poll_read_buf(cx, &mut this.buf))?;
        let filled = this.buf.filled().len();
        // Zero bytes into a non-empty buffer means the stream finished
        if filled == 0 && this.buf.capacity() != 0 {
            Poll::Ready(Ok(None))
        } else {
            Poll::Ready(Ok(Some(filled)))
        }
    }
}
/// Future produced by [`RecvStream::read_exact()`].
///
/// [`RecvStream::read_exact()`]: crate::RecvStream::read_exact
struct ReadExact<'a> {
    stream: &'a mut RecvStream,
    // Destination buffer; the future completes only once it is entirely filled
    buf: ReadBuf<'a>,
}
impl Future for ReadExact<'_> {
    type Output = Result<(), ReadExactError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        let mut remaining = this.buf.remaining();
        // Keep reading until the buffer is full; a read that makes no progress
        // means the stream finished before enough bytes arrived
        while remaining > 0 {
            ready!(this.stream.poll_read_buf(cx, &mut this.buf))?;
            let new = this.buf.remaining();
            if new == remaining {
                return Poll::Ready(Err(ReadExactError::FinishedEarly(this.buf.filled().len())));
            }
            remaining = new;
        }
        Poll::Ready(Ok(()))
    }
}
/// Errors that arise from reading an exact number of bytes from a stream.
///
/// Returned by [`RecvStream::read_exact`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
    /// The stream finished before all bytes were read
    #[error("stream finished early ({0} bytes read)")]
    FinishedEarly(usize),
    /// A read error occurred
    #[error(transparent)]
    ReadError(#[from] ReadError),
}
/// Future produced by [`RecvStream::read_chunk()`].
///
/// [`RecvStream::read_chunk()`]: crate::RecvStream::read_chunk
struct ReadChunk<'a> {
    stream: &'a mut RecvStream,
    // Maximum number of bytes the yielded chunk may contain
    max_length: usize,
    // Whether the chunk must follow the previous ordered read position
    ordered: bool,
}
impl Future for ReadChunk<'_> {
    type Output = Result<Option<Chunk>, ReadError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Forward directly to the stream; `max_length` and `ordered` are
        // disjoint `Copy` fields, so they can be read alongside the mutable
        // borrow of `stream`
        let this = self.get_mut();
        this.stream.poll_read_chunk(cx, this.max_length, this.ordered)
    }
}
/// Future produced by [`RecvStream::read_chunks()`].
///
/// [`RecvStream::read_chunks()`]: crate::RecvStream::read_chunks
struct ReadChunks<'a> {
    stream: &'a mut RecvStream,
    // Output slots; each is overwritten with one received chunk
    bufs: &'a mut [Bytes],
}
impl Future for ReadChunks<'_> {
    type Output = Result<Option<usize>, ReadError>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Forward to the stream; yields the number of `bufs` slots filled
        let this = self.get_mut();
        this.stream.poll_read_chunks(cx, this.bufs)
    }
}

200
vendor/quinn/src/runtime.rs vendored Normal file
View File

@@ -0,0 +1,200 @@
use std::{
fmt::Debug,
future::Future,
io::{self, IoSliceMut},
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use udp::{RecvMeta, Transmit};
use crate::Instant;
/// Abstracts I/O and timer operations for runtime independence
pub trait Runtime: Send + Sync + Debug + 'static {
/// Construct a timer that will expire at `i`
fn new_timer(&self, i: Instant) -> Pin<Box<dyn AsyncTimer>>;
/// Drive `future` to completion in the background
fn spawn(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>);
/// Convert `t` into the socket type used by this runtime
#[cfg(not(wasm_browser))]
fn wrap_udp_socket(&self, t: std::net::UdpSocket) -> io::Result<Arc<dyn AsyncUdpSocket>>;
/// Look up the current time
///
/// Allows simulating the flow of time for testing.
fn now(&self) -> Instant {
Instant::now()
}
}
/// Abstract implementation of an async timer for runtime independence
pub trait AsyncTimer: Send + Debug + 'static {
/// Update the timer to expire at `i`
fn reset(self: Pin<&mut Self>, i: Instant);
/// Check whether the timer has expired, and register to be woken if not
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()>;
}
/// Abstract implementation of a UDP socket for runtime independence
pub trait AsyncUdpSocket: Send + Sync + Debug + 'static {
/// Create a [`UdpPoller`] that can register a single task for write-readiness notifications
///
/// A `poll_send` method on a single object can usually store only one [`Waker`] at a time,
/// i.e. allow at most one caller to wait for an event. This method allows any number of
/// interested tasks to construct their own [`UdpPoller`] object. They can all then wait for the
/// same event and be notified concurrently, because each [`UdpPoller`] can store a separate
/// [`Waker`].
///
/// [`Waker`]: std::task::Waker
fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn UdpPoller>>;
/// Send UDP datagrams from `transmits`, or return `WouldBlock` and clear the underlying
/// socket's readiness, or return an I/O error
///
/// If this returns [`io::ErrorKind::WouldBlock`], [`UdpPoller::poll_writable`] must be called
/// to register the calling task to be woken when a send should be attempted again.
fn try_send(&self, transmit: &Transmit) -> io::Result<()>;
/// Receive UDP datagrams, or register to be woken if receiving may succeed in the future
fn poll_recv(
&self,
cx: &mut Context,
bufs: &mut [IoSliceMut<'_>],
meta: &mut [RecvMeta],
) -> Poll<io::Result<usize>>;
/// Look up the local IP address and port used by this socket
fn local_addr(&self) -> io::Result<SocketAddr>;
/// Maximum number of datagrams that a [`Transmit`] may encode
fn max_transmit_segments(&self) -> usize {
1
}
/// Maximum number of datagrams that might be described by a single [`RecvMeta`]
fn max_receive_segments(&self) -> usize {
1
}
/// Whether datagrams might get fragmented into multiple parts
///
/// Sockets should prevent this for best performance. See e.g. the `IPV6_DONTFRAG` socket
/// option.
fn may_fragment(&self) -> bool {
true
}
}
/// An object polled to detect when an associated [`AsyncUdpSocket`] is writable
///
/// Any number of `UdpPoller`s may exist for a single [`AsyncUdpSocket`]. Each `UdpPoller` is
/// responsible for notifying at most one task when that socket becomes writable.
pub trait UdpPoller: Send + Sync + Debug + 'static {
/// Check whether the associated socket is likely to be writable
///
/// Must be called after [`AsyncUdpSocket::try_send`] returns [`io::ErrorKind::WouldBlock`] to
/// register the task associated with `cx` to be woken when a send should be attempted
/// again. Unlike in [`Future::poll`], a [`UdpPoller`] may be reused indefinitely no matter how
/// many times `poll_writable` returns [`Poll::Ready`].
fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>>;
}
pin_project_lite::pin_project! {
    /// Helper adapting a function `MakeFut` that constructs a single-use future `Fut` into a
    /// [`UdpPoller`] that may be reused indefinitely
    struct UdpPollHelper<MakeFut, Fut> {
        // Factory invoked whenever a fresh readiness future is needed
        make_fut: MakeFut,
        // Currently pending readiness future, if one is in flight
        #[pin]
        fut: Option<Fut>,
    }
}
impl<MakeFut, Fut> UdpPollHelper<MakeFut, Fut> {
    /// Construct a [`UdpPoller`] that calls `make_fut` to get the future to poll, storing it until
    /// it yields [`Poll::Ready`], then creating a new one on the next
    /// [`poll_writable`](UdpPoller::poll_writable)
    // Gated on the runtime features because only the runtime-specific socket
    // implementations construct this helper
    #[cfg(any(
        feature = "runtime-async-std",
        feature = "runtime-smol",
        feature = "runtime-tokio",
    ))]
    fn new(make_fut: MakeFut) -> Self {
        Self {
            make_fut,
            fut: None,
        }
    }
}
impl<MakeFut, Fut> UdpPoller for UdpPollHelper<MakeFut, Fut>
where
    MakeFut: Fn() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
{
    /// Poll the stored single-use future, lazily (re)creating it as needed so this
    /// poller can be reused indefinitely, unlike a plain [`Future`].
    fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let mut proj = self.project();
        // Build a fresh future on first use and after each completion below.
        if proj.fut.is_none() {
            proj.fut.set(Some((proj.make_fut)()));
        }
        // `Fut` may be `!Unpin`, so we cannot move it out of the pinned `Option` to get
        // an `&mut Fut`; poll through the pin projection instead. The `unwrap` cannot
        // fail: the slot was just populated if it was empty.
        let poll = proj.fut.as_mut().as_pin_mut().unwrap().poll(cx);
        if poll.is_ready() {
            // Polling an arbitrary `Future` again after it has completed is a logic
            // error, so drop it; the next call constructs a replacement.
            proj.fut.set(None);
        }
        poll
    }
}
impl<MakeFut, Fut> Debug for UdpPollHelper<MakeFut, Fut> {
    /// Render the type opaquely: neither field is constrained to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("UdpPollHelper");
        builder.finish_non_exhaustive()
    }
}
/// Automatically select an appropriate runtime from those enabled at compile time
///
/// If `runtime-tokio` is enabled and this function is called from within a Tokio runtime context,
/// then `TokioRuntime` is returned. Otherwise, if `runtime-async-std` is enabled, `AsyncStdRuntime`
/// is returned. Otherwise, if `runtime-smol` is enabled, `SmolRuntime` is returned.
/// Otherwise, `None` is returned.
#[allow(clippy::needless_return)] // Be sure we return the right thing
pub fn default_runtime() -> Option<Arc<dyn Runtime>> {
    #[cfg(feature = "runtime-tokio")]
    {
        // Tokio is only usable when we're actually inside a Tokio runtime context;
        // otherwise fall through to the other candidates.
        if ::tokio::runtime::Handle::try_current().is_ok() {
            return Some(Arc::new(TokioRuntime));
        }
    }
    #[cfg(feature = "runtime-async-std")]
    {
        return Some(Arc::new(AsyncStdRuntime));
    }
    // smol is only reachable when async-std is absent: the block above already
    // returned unconditionally in that configuration, hence the `not(...)` gate.
    #[cfg(all(feature = "runtime-smol", not(feature = "runtime-async-std")))]
    {
        return Some(Arc::new(SmolRuntime));
    }
    #[cfg(not(any(feature = "runtime-async-std", feature = "runtime-smol")))]
    None
}
#[cfg(feature = "runtime-tokio")]
mod tokio;
// Due to MSRV, we must specify `self::` where there's crate/module ambiguity
#[cfg(feature = "runtime-tokio")]
pub use self::tokio::TokioRuntime;
#[cfg(feature = "async-io")]
mod async_io;
// Due to MSRV, we must specify `self::` where there's crate/module ambiguity
#[cfg(any(feature = "runtime-smol", feature = "runtime-async-std"))]
pub use self::async_io::*;

147
vendor/quinn/src/runtime/async_io.rs vendored Normal file
View File

@@ -0,0 +1,147 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Instant,
};
#[cfg(any(feature = "runtime-smol", feature = "runtime-async-std"))]
use std::{io, sync::Arc, task::ready};
#[cfg(any(feature = "runtime-smol", feature = "runtime-async-std"))]
use async_io::Async;
use async_io::Timer;
use super::AsyncTimer;
#[cfg(any(feature = "runtime-smol", feature = "runtime-async-std"))]
use super::{AsyncUdpSocket, Runtime, UdpPollHelper};
#[cfg(feature = "runtime-smol")]
// Due to MSRV, we must specify `self::` where there's crate/module ambiguity
pub use self::smol::SmolRuntime;
#[cfg(feature = "runtime-smol")]
mod smol {
    use super::*;
    /// A Quinn runtime for smol
    #[derive(Debug)]
    pub struct SmolRuntime;
    impl Runtime for SmolRuntime {
        fn new_timer(&self, t: Instant) -> Pin<Box<dyn AsyncTimer>> {
            // `async_io::Timer` implements `AsyncTimer` (impl further down this file)
            Box::pin(Timer::at(t))
        }
        fn spawn(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>) {
            // Detach so the task keeps running without anyone awaiting its handle
            ::smol::spawn(future).detach();
        }
        fn wrap_udp_socket(
            &self,
            sock: std::net::UdpSocket,
        ) -> io::Result<Arc<dyn AsyncUdpSocket>> {
            Ok(Arc::new(UdpSocket::new(sock)?))
        }
    }
}
#[cfg(feature = "runtime-async-std")]
// Due to MSRV, we must specify `self::` where there's crate/module ambiguity
pub use self::async_std::AsyncStdRuntime;
#[cfg(feature = "runtime-async-std")]
mod async_std {
    use super::*;
    /// A Quinn runtime for async-std
    #[derive(Debug)]
    pub struct AsyncStdRuntime;
    impl Runtime for AsyncStdRuntime {
        fn new_timer(&self, t: Instant) -> Pin<Box<dyn AsyncTimer>> {
            // `async_io::Timer` implements `AsyncTimer` (impl further down this file)
            Box::pin(Timer::at(t))
        }
        fn spawn(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>) {
            // The join handle is discarded; the task runs independently
            ::async_std::task::spawn(future);
        }
        fn wrap_udp_socket(
            &self,
            sock: std::net::UdpSocket,
        ) -> io::Result<Arc<dyn AsyncUdpSocket>> {
            Ok(Arc::new(UdpSocket::new(sock)?))
        }
    }
}
impl AsyncTimer for Timer {
    /// Move the deadline; delegates to `Timer::set_at`
    fn reset(mut self: Pin<&mut Self>, t: Instant) {
        self.set_at(t)
    }
    /// Drive the timer; the `Instant` it yields on expiry is discarded
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
        Future::poll(self, cx).map(|_| ())
    }
}
#[cfg(any(feature = "runtime-smol", feature = "runtime-async-std"))]
#[derive(Debug)]
struct UdpSocket {
    // Readiness-based async wrapper around the OS socket
    io: Async<std::net::UdpSocket>,
    // Platform-specific UDP state; provides send/recv and GSO/GRO capability queries
    inner: udp::UdpSocketState,
}
#[cfg(any(feature = "runtime-smol", feature = "runtime-async-std"))]
impl UdpSocket {
    fn new(sock: std::net::UdpSocket) -> io::Result<Self> {
        Ok(Self {
            // Field order matters here: `UdpSocketState::new` only borrows `sock`,
            // while `Async::new_nonblocking` consumes it afterwards.
            inner: udp::UdpSocketState::new((&sock).into())?,
            io: Async::new_nonblocking(sock)?,
        })
    }
}
#[cfg(any(feature = "runtime-smol", feature = "runtime-async-std"))]
impl AsyncUdpSocket for UdpSocket {
    fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn super::UdpPoller>> {
        // Each poller owns a clone of this socket and constructs a fresh `writable()`
        // future whenever the previous one completes
        Box::pin(UdpPollHelper::new(move || {
            let socket = self.clone();
            async move { socket.io.writable().await }
        }))
    }
    fn try_send(&self, transmit: &udp::Transmit) -> io::Result<()> {
        self.inner.send((&self.io).into(), transmit)
    }
    fn poll_recv(
        &self,
        cx: &mut Context,
        bufs: &mut [io::IoSliceMut<'_>],
        meta: &mut [udp::RecvMeta],
    ) -> Poll<io::Result<usize>> {
        loop {
            ready!(self.io.poll_readable(cx))?;
            // Errors from `recv` are swallowed and readiness is re-polled.
            // NOTE(review): this treats every recv error (not just WouldBlock) as
            // transient — confirm this is intentional
            if let Ok(res) = self.inner.recv((&self.io).into(), bufs, meta) {
                return Poll::Ready(Ok(res));
            }
        }
    }
    fn local_addr(&self) -> io::Result<std::net::SocketAddr> {
        self.io.as_ref().local_addr()
    }
    fn may_fragment(&self) -> bool {
        self.inner.may_fragment()
    }
    fn max_transmit_segments(&self) -> usize {
        self.inner.max_gso_segments()
    }
    fn max_receive_segments(&self) -> usize {
        self.inner.gro_segments()
    }
}

102
vendor/quinn/src/runtime/tokio.rs vendored Normal file
View File

@@ -0,0 +1,102 @@
use std::{
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll, ready},
time::Instant,
};
use tokio::{
io::Interest,
time::{Sleep, sleep_until},
};
use super::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPollHelper};
/// A Quinn runtime for Tokio
#[derive(Debug)]
pub struct TokioRuntime;
impl Runtime for TokioRuntime {
    fn new_timer(&self, t: Instant) -> Pin<Box<dyn AsyncTimer>> {
        // `tokio::time::Sleep` implements `AsyncTimer` (impl below)
        Box::pin(sleep_until(t.into()))
    }
    fn spawn(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>) {
        // The `JoinHandle` is dropped, detaching the task
        tokio::spawn(future);
    }
    fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result<Arc<dyn AsyncUdpSocket>> {
        Ok(Arc::new(UdpSocket {
            // Field order matters: `UdpSocketState::new` borrows `sock`, then
            // `from_std` consumes it
            inner: udp::UdpSocketState::new((&sock).into())?,
            io: tokio::net::UdpSocket::from_std(sock)?,
        }))
    }
    fn now(&self) -> Instant {
        // Routed through Tokio's clock; NOTE(review): presumably so paused/mocked
        // time (`tokio::time::pause`) is honored — confirm
        tokio::time::Instant::now().into_std()
    }
}
impl AsyncTimer for Sleep {
    /// Move the deadline; delegates to `Sleep::reset`, converting into Tokio's `Instant`
    fn reset(self: Pin<&mut Self>, t: Instant) {
        Self::reset(self, t.into())
    }
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
        Future::poll(self, cx)
    }
}
#[derive(Debug)]
struct UdpSocket {
    // Socket registered with the Tokio reactor
    io: tokio::net::UdpSocket,
    // Platform-specific UDP state; provides send/recv and GSO/GRO capability queries
    inner: udp::UdpSocketState,
}
impl AsyncUdpSocket for UdpSocket {
    fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn super::UdpPoller>> {
        // Each poller owns a clone of this socket and constructs a fresh `writable()`
        // future whenever the previous one completes
        Box::pin(UdpPollHelper::new(move || {
            let socket = self.clone();
            async move { socket.io.writable().await }
        }))
    }
    fn try_send(&self, transmit: &udp::Transmit) -> io::Result<()> {
        // `try_io` lets Tokio clear its cached readiness if the send hits `WouldBlock`
        self.io.try_io(Interest::WRITABLE, || {
            self.inner.send((&self.io).into(), transmit)
        })
    }
    fn poll_recv(
        &self,
        cx: &mut Context,
        bufs: &mut [std::io::IoSliceMut<'_>],
        meta: &mut [udp::RecvMeta],
    ) -> Poll<io::Result<usize>> {
        loop {
            ready!(self.io.poll_recv_ready(cx))?;
            // Errors from the inner recv are swallowed and readiness is re-polled.
            // NOTE(review): this treats every recv error (not just WouldBlock) as
            // transient — confirm this is intentional
            if let Ok(res) = self.io.try_io(Interest::READABLE, || {
                self.inner.recv((&self.io).into(), bufs, meta)
            }) {
                return Poll::Ready(Ok(res));
            }
        }
    }
    fn local_addr(&self) -> io::Result<std::net::SocketAddr> {
        self.io.local_addr()
    }
    fn may_fragment(&self) -> bool {
        self.inner.may_fragment()
    }
    fn max_transmit_segments(&self) -> usize {
        self.inner.max_gso_segments()
    }
    fn max_receive_segments(&self) -> usize {
        self.inner.gro_segments()
    }
}

443
vendor/quinn/src/send_stream.rs vendored Normal file
View File

@@ -0,0 +1,443 @@
use std::{
future::{Future, poll_fn},
io,
pin::{Pin, pin},
task::{Context, Poll},
};
use bytes::Bytes;
use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
use thiserror::Error;
use crate::{
VarInt,
connection::{ConnectionRef, State},
};
/// A stream that can only be used to send data
///
/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
/// connection is closed.
///
/// # Cancellation
///
/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
/// ready will always result in no data being written to the stream. This is true of methods which
/// succeed immediately when any progress is made, and is not true of methods which might need to
/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
/// cancel-safe.
///
/// [`reset()`]: SendStream::reset
/// [`finish()`]: SendStream::finish
#[derive(Debug)]
pub struct SendStream {
    // Shared handle to the connection's state
    conn: ConnectionRef,
    // Identity of this stream within the connection
    stream: StreamId,
    // Whether this stream was opened in 0-RTT (see `Connecting::into_0rtt()`)
    is_0rtt: bool,
}
impl SendStream {
    /// Construct a handle for `stream` on the connection behind `conn`
    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
        Self {
            conn,
            stream,
            is_0rtt,
        }
    }
    /// Write a buffer into this stream, returning how many bytes were written
    ///
    /// Unless this method errors, it waits until some amount of `buf` can be written into this
    /// stream, and then writes as much as it can without waiting again. Due to congestion and flow
    /// control, this may be shorter than `buf.len()`. On success this yields the length of the
    /// prefix that was written.
    ///
    /// # Cancel safety
    ///
    /// This method is cancellation safe. If this does not resolve, no bytes were written.
    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
        poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
    }
    /// Write a buffer into this stream in its entirety
    ///
    /// This method repeatedly calls [`write`](Self::write) until all bytes are written, or an
    /// error occurs.
    ///
    /// # Cancel safety
    ///
    /// This method is *not* cancellation safe. Even if this does not resolve, some prefix of `buf`
    /// may have been written when previously polled.
    pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
        while !buf.is_empty() {
            let written = self.write(buf).await?;
            // Advance past the written prefix and retry with the remainder
            buf = &buf[written..];
        }
        Ok(())
    }
    /// Write a slice of [`Bytes`] into this stream, returning how much was written
    ///
    /// Bytes to try to write are provided to this method as an array of cheaply cloneable chunks.
    /// Unless this method errors, it waits until some amount of those bytes can be written into
    /// this stream, and then writes as much as it can without waiting again. Due to congestion and
    /// flow control, this may be less than the total number of bytes.
    ///
    /// On success, this method both mutates `bufs` and yields an informative [`Written`] struct
    /// indicating how much was written:
    ///
    /// - [`Bytes`] chunks that were fully written are mutated to be [empty](Bytes::is_empty).
    /// - If a [`Bytes`] chunk was partially written, it is [split to](Bytes::split_to) contain
    ///   only the suffix of bytes that were not written.
    /// - The yielded [`Written`] struct indicates how many chunks were fully written as well as
    ///   how many bytes were written.
    ///
    /// # Cancel safety
    ///
    /// This method is cancellation safe. If this does not resolve, no bytes were written.
    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
        poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
    }
    /// Write a single [`Bytes`] into this stream in its entirety
    ///
    /// Bytes to write are provided to this method as an single cheaply cloneable chunk. This
    /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
    /// or an error occurs.
    ///
    /// # Cancel safety
    ///
    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
    /// been written when previously polled.
    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
        self.write_all_chunks(&mut [buf]).await?;
        Ok(())
    }
    /// Write a slice of [`Bytes`] into this stream in its entirety
    ///
    /// Bytes to write are provided to this method as an array of cheaply cloneable chunks. This
    /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
    /// or an error occurs. This method mutates `bufs` by mutating all chunks to be
    /// [empty](Bytes::is_empty).
    ///
    /// # Cancel safety
    ///
    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
    /// been written when previously polled.
    pub async fn write_all_chunks(&mut self, mut bufs: &mut [Bytes]) -> Result<(), WriteError> {
        while !bufs.is_empty() {
            let written = self.write_chunks(bufs).await?;
            // Drop fully-written chunks; a partially-written chunk stays at the front
            bufs = &mut bufs[written.chunks..];
        }
        Ok(())
    }
    /// Run `write_fn` against the protocol-level stream under the connection lock,
    /// translating errors and registering a waker when the stream is blocked
    fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
    where
        F: FnOnce(&mut proto::SendStream) -> Result<R, proto::WriteError>,
    {
        use proto::WriteError::*;
        let mut conn = self.conn.state.lock("SendStream::poll_write");
        // 0-RTT streams fail fast once we learn the server rejected 0-RTT
        if self.is_0rtt {
            conn.check_0rtt()
                .map_err(|()| WriteError::ZeroRttRejected)?;
        }
        // A connection-level failure trumps any per-stream state
        if let Some(ref x) = conn.error {
            return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
        }
        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
            Ok(result) => result,
            Err(Blocked) => {
                // Out of flow-control/congestion credit: remember this task's waker so
                // it is woken when the stream can accept more data
                conn.blocked_writers.insert(self.stream, cx.waker().clone());
                return Poll::Pending;
            }
            Err(Stopped(error_code)) => {
                return Poll::Ready(Err(WriteError::Stopped(error_code)));
            }
            Err(ClosedStream) => {
                return Poll::Ready(Err(WriteError::ClosedStream));
            }
        };
        // Data was accepted: wake the connection task so it can act on it
        conn.wake();
        Poll::Ready(Ok(result))
    }
    /// Notify the peer that no more data will ever be written to this stream
    ///
    /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
    /// may still be called after `finish` to abandon transmission of any stream data that might
    /// still be buffered.
    ///
    /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
    ///
    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
    /// called. This error is harmless and serves only to indicate that the caller may have
    /// incorrect assumptions about the stream's state.
    pub fn finish(&mut self) -> Result<(), ClosedStream> {
        let mut conn = self.conn.state.lock("finish");
        match conn.inner.send_stream(self.stream).finish() {
            Ok(()) => {
                conn.wake();
                Ok(())
            }
            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
            // Harmless. If the application needs to know about stopped streams at this point, it
            // should call `stopped`.
            Err(FinishError::Stopped(_)) => Ok(()),
        }
    }
    /// Close the send stream immediately.
    ///
    /// No new data can be written after calling this method. Locally buffered data is dropped, and
    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
    /// already been made to finish the stream, the peer may still receive all written data.
    ///
    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
    /// called. This error is harmless and serves only to indicate that the caller may have
    /// incorrect assumptions about the stream's state.
    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        let mut conn = self.conn.state.lock("SendStream::reset");
        // A rejected 0-RTT stream is already dead; nothing to reset
        if self.is_0rtt && conn.check_0rtt().is_err() {
            return Ok(());
        }
        conn.inner.send_stream(self.stream).reset(error_code)?;
        conn.wake();
        Ok(())
    }
    /// Set the priority of the send stream
    ///
    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
    /// higher priority will be transmitted before data from streams with lower priority. Changing
    /// the priority of a stream with pending data may only take effect after that data has been
    /// transmitted. Using many different priority levels per connection may have a negative
    /// impact on performance.
    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
        let mut conn = self.conn.state.lock("SendStream::set_priority");
        conn.inner.send_stream(self.stream).set_priority(priority)?;
        Ok(())
    }
    /// Get the priority of the send stream
    pub fn priority(&self) -> Result<i32, ClosedStream> {
        let mut conn = self.conn.state.lock("SendStream::priority");
        conn.inner.send_stream(self.stream).priority()
    }
    /// Completes when the peer stops the stream or reads the stream to completion
    ///
    /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
    /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
    /// of all stream data (although not necessarily the processing of it), after which the peer
    /// closing the stream is no longer meaningful.
    ///
    /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
    /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
    /// may introduce more latency than using an application-level response of some sort.
    pub fn stopped(
        &self,
    ) -> impl Future<Output = Result<Option<VarInt>, StoppedError>> + Send + Sync + 'static {
        // Capture plain copies so the returned future is 'static and detached from `self`
        let conn = self.conn.clone();
        let stream = self.stream;
        let is_0rtt = self.is_0rtt;
        async move {
            loop {
                // The `Notify::notified` future needs to be created while the lock is being held,
                // otherwise a wakeup could be missed if triggered inbetween releasing the lock
                // and creating the future.
                // The lock may only be held in a block without `await`s, otherwise the future
                // becomes `!Send`. `Notify::notified` is lifetime-bound to `Notify`, therefore
                // we need to declare `notify` outside of the block, and initialize it inside.
                let notify;
                {
                    let mut conn = conn.state.lock("SendStream::stopped");
                    if let Some(output) = send_stream_stopped(&mut conn, stream, is_0rtt) {
                        return output;
                    }
                    notify = conn.stopped.entry(stream).or_default().clone();
                    notify.notified()
                }
                .await
            }
        }
    }
    /// Get the identity of this stream
    pub fn id(&self) -> StreamId {
        self.stream
    }
    /// Attempt to write bytes from buf into the stream.
    ///
    /// On success, returns Poll::Ready(Ok(num_bytes_written)).
    ///
    /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
    /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
    /// stream becomes writable or is closed.
    pub fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize, WriteError>> {
        // Building a fresh `write` future per poll is sound because `write` is
        // cancel-safe (see its docs above): dropping it between polls loses no data.
        pin!(self.get_mut().write(buf)).as_mut().poll(cx)
    }
}
/// Check if a send stream is stopped.
///
/// Returns `Some` if the stream is stopped or the connection is closed.
/// Returns `None` if the stream is not stopped.
fn send_stream_stopped(
    conn: &mut State,
    stream: StreamId,
    is_0rtt: bool,
) -> Option<Result<Option<VarInt>, StoppedError>> {
    // A 0-RTT stream whose early data was rejected can never be observed; fail fast.
    if is_0rtt && conn.check_0rtt().is_err() {
        return Some(Err(StoppedError::ZeroRttRejected));
    }
    let status = conn.inner.send_stream(stream).stopped();
    match status {
        // Stream closed with all data acknowledged by the peer.
        Err(ClosedStream { .. }) => Some(Ok(None)),
        // The peer stopped the stream with this application error code.
        Ok(Some(code)) => Some(Ok(Some(code))),
        // Not stopped; report a connection-level error if one has occurred, else keep waiting.
        Ok(None) => match conn.error.clone() {
            Some(err) => Some(Err(err.into())),
            None => None,
        },
    }
}
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        // Delegate to the inherent `poll_write`, converting `WriteError` into `io::Error`
        self.poll_write(cx, buf).map_err(Into::into)
    }
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        // No-op: there is no intermediate buffer to flush at this layer
        Poll::Ready(Ok(()))
    }
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        // Close maps to `finish`; completes immediately, without waiting for acks
        Poll::Ready(self.get_mut().finish().map_err(Into::into))
    }
}
impl tokio::io::AsyncWrite for SendStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Delegate to the inherent `poll_write`, converting `WriteError` into `io::Error`
        self.poll_write(cx, buf).map_err(Into::into)
    }
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        // No-op: there is no intermediate buffer to flush at this layer
        Poll::Ready(Ok(()))
    }
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        // Shutdown maps to `finish`; completes immediately, without waiting for acks
        Poll::Ready(self.get_mut().finish().map_err(Into::into))
    }
}
impl Drop for SendStream {
    fn drop(&mut self) {
        let mut conn = self.conn.state.lock("SendStream::drop");
        // clean up any previously registered wakers
        conn.blocked_writers.remove(&self.stream);
        // Nothing further to do if the connection already failed, or this was a
        // 0-RTT stream that the server rejected
        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
            return;
        }
        // Implicitly finish on drop (see the type-level docs): previously written
        // data keeps being (re)transmitted until acknowledged
        match conn.inner.send_stream(self.stream).finish() {
            Ok(()) => conn.wake(),
            Err(FinishError::Stopped(reason)) => {
                // Peer already stopped the stream; echo its code back as a reset
                if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
                    conn.wake();
                }
            }
            // Already finished or reset, which is fine.
            Err(FinishError::ClosedStream) => {}
        }
    }
}
/// Errors that arise from writing to a stream
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
    /// The peer is no longer accepting data on this stream
    ///
    /// Carries an application-defined error code.
    #[error("sending stopped by peer: error {0}")]
    Stopped(VarInt),
    /// The connection was lost
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The stream has already been finished or reset
    ///
    /// Returned by writes attempted after `finish()` or `reset()`.
    #[error("closed stream")]
    ClosedStream,
    /// This was a 0-RTT stream and the server rejected it
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<ClosedStream> for WriteError {
    #[inline]
    fn from(_: ClosedStream) -> Self {
        // `ClosedStream` carries no payload worth preserving
        Self::ClosedStream
    }
}
impl From<StoppedError> for WriteError {
    fn from(x: StoppedError) -> Self {
        // Both variants have direct counterparts in `WriteError`
        match x {
            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
        }
    }
}
impl From<WriteError> for io::Error {
fn from(x: WriteError) -> Self {
use WriteError::*;
let kind = match x {
Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
};
Self::new(kind, x)
}
}
/// Errors that arise while monitoring for a send stream stop from the peer
///
/// Returned by futures produced by [`SendStream::stopped`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum StoppedError {
    /// The connection was lost
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// This was a 0-RTT stream and the server rejected it
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<StoppedError> for io::Error {
fn from(x: StoppedError) -> Self {
use StoppedError::*;
let kind = match x {
ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) => io::ErrorKind::NotConnected,
};
Self::new(kind, x)
}
}

947
vendor/quinn/src/tests.rs vendored Executable file
View File

@@ -0,0 +1,947 @@
#![cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
#[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
use rustls::crypto::aws_lc_rs::default_provider;
#[cfg(feature = "rustls-ring")]
use rustls::crypto::ring::default_provider;
use std::{
convert::TryInto,
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
str,
sync::Arc,
};
use crate::runtime::TokioRuntime;
use crate::{Duration, Instant};
use bytes::Bytes;
use proto::{RandomConnectionIdGenerator, crypto::rustls::QuicClientConfig};
use rand::{RngCore, SeedableRng, rngs::StdRng};
use rustls::{
RootCertStore,
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
};
use tokio::runtime::{Builder, Runtime};
use tracing::{error_span, info};
use tracing_futures::Instrument as _;
use tracing_subscriber::EnvFilter;
use super::{ClientConfig, Endpoint, EndpointConfig, RecvStream, SendStream, TransportConfig};
// Verifies that a handshake to an unreachable peer fails with `TimedOut` once the
// configured idle timeout elapses, and not dramatically later.
#[test]
fn handshake_timeout() {
    let _guard = subscribe();
    let runtime = rt_threaded();
    let client = {
        let _guard = runtime.enter();
        Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap()
    };
    // Avoid NoRootAnchors error
    let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
    let mut roots = RootCertStore::empty();
    roots.add(cert.cert.into()).unwrap();
    let mut client_config = crate::ClientConfig::with_root_certificates(Arc::new(roots)).unwrap();
    const IDLE_TIMEOUT: Duration = Duration::from_millis(500);
    let mut transport_config = crate::TransportConfig::default();
    transport_config
        .max_idle_timeout(Some(IDLE_TIMEOUT.try_into().unwrap()))
        // Short RTT estimate keeps retransmission timers from dominating the test
        .initial_rtt(Duration::from_millis(10));
    client_config.transport_config(Arc::new(transport_config));
    let start = Instant::now();
    runtime.block_on(async move {
        // Port 1 on localhost is assumed to have no QUIC listener
        match client
            .connect_with(
                client_config,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1),
                "localhost",
            )
            .unwrap()
            .await
        {
            Err(crate::ConnectionError::TimedOut) => {}
            Err(e) => panic!("unexpected error: {e:?}"),
            Ok(_) => panic!("unexpected success"),
        }
    });
    let dt = start.elapsed();
    // Must give up after the idle timeout but well before twice that long
    assert!(dt > IDLE_TIMEOUT && dt < 2 * IDLE_TIMEOUT);
}
// Verifies that closing an endpoint causes pending outgoing connections to resolve
// with `LocallyClosed`.
#[tokio::test]
async fn close_endpoint() {
    let _guard = subscribe();
    // Avoid NoRootAnchors error
    let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
    let mut roots = RootCertStore::empty();
    roots.add(cert.cert.into()).unwrap();
    let mut endpoint =
        Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap();
    endpoint
        .set_default_client_config(ClientConfig::with_root_certificates(Arc::new(roots)).unwrap());
    // First connection attempt is parked on a background task
    let conn = endpoint
        .connect(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234),
            "localhost",
        )
        .unwrap();
    tokio::spawn(async move {
        let _ = conn.await;
    });
    // Second attempt is awaited directly after the endpoint is closed
    let conn = endpoint
        .connect(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234),
            "localhost",
        )
        .unwrap();
    endpoint.close(0u32.into(), &[]);
    match conn.await {
        Err(crate::ConnectionError::LocallyClosed) => (),
        Err(e) => panic!("unexpected error: {e}"),
        Ok(_) => {
            panic!("unexpected success");
        }
    }
}
// Verifies that an endpoint reports the same local address as the socket it wraps.
#[test]
fn local_addr() {
    let socket = UdpSocket::bind((Ipv6Addr::LOCALHOST, 0)).unwrap();
    let addr = socket.local_addr().unwrap();
    let runtime = rt_basic();
    let ep = {
        let _guard = runtime.enter();
        Endpoint::new(Default::default(), None, socket, Arc::new(TokioRuntime)).unwrap()
    };
    assert_eq!(
        addr,
        ep.local_addr()
            .expect("Could not obtain our local endpoint")
    );
}
// Verifies that data finished by the sender can still be read by the receiver even
// when the receiver starts reading after a delay.
#[test]
fn read_after_close() {
    let _guard = subscribe();
    let runtime = rt_basic();
    let endpoint = {
        let _guard = runtime.enter();
        endpoint()
    };
    const MSG: &[u8] = b"goodbye!";
    let endpoint2 = endpoint.clone();
    // Server side: send one message on a unidirectional stream and finish it
    runtime.spawn(async move {
        let new_conn = endpoint2
            .accept()
            .await
            .expect("endpoint")
            .await
            .expect("connection");
        let mut s = new_conn.open_uni().await.unwrap();
        s.write_all(MSG).await.unwrap();
        s.finish().unwrap();
        // Wait for the stream to be closed, one way or another.
        _ = s.stopped().await;
    });
    runtime.block_on(async move {
        let new_conn = endpoint
            .connect(endpoint.local_addr().unwrap(), "localhost")
            .unwrap()
            .await
            .expect("connect");
        // Deliberately delay reading until after the sender has finished the stream
        tokio::time::sleep(Duration::from_millis(100)).await;
        let mut stream = new_conn.accept_uni().await.expect("incoming streams");
        let msg = stream.read_to_end(usize::MAX).await.expect("read_to_end");
        assert_eq!(msg, MSG);
    });
}
// Verifies that both ends of a connection derive identical keying material for the
// same label and context (RFC 5705-style export).
#[test]
fn export_keying_material() {
    let _guard = subscribe();
    let runtime = rt_basic();
    let endpoint = {
        let _guard = runtime.enter();
        endpoint()
    };
    runtime.block_on(async move {
        // Connect the endpoint to itself, driving both sides concurrently
        let outgoing_conn_fut = tokio::spawn({
            let endpoint = endpoint.clone();
            async move {
                endpoint
                    .connect(endpoint.local_addr().unwrap(), "localhost")
                    .unwrap()
                    .await
                    .expect("connect")
            }
        });
        let incoming_conn_fut = tokio::spawn({
            let endpoint = endpoint.clone();
            async move {
                endpoint
                    .accept()
                    .await
                    .expect("endpoint")
                    .await
                    .expect("connection")
            }
        });
        let outgoing_conn = outgoing_conn_fut.await.unwrap();
        let incoming_conn = incoming_conn_fut.await.unwrap();
        // Same label (b"asdf") and context (b"qwer") must yield the same bytes
        let mut i_buf = [0u8; 64];
        incoming_conn
            .export_keying_material(&mut i_buf, b"asdf", b"qwer")
            .unwrap();
        let mut o_buf = [0u8; 64];
        outgoing_conn
            .export_keying_material(&mut o_buf, b"asdf", b"qwer")
            .unwrap();
        assert_eq!(&i_buf[..], &o_buf[..]);
    });
}
// Verifies per-source admission control: the server refuses one client by address,
// address-validates others via retry, and accepts validated connections.
#[tokio::test]
async fn ip_blocking() {
    let _guard = subscribe();
    let endpoint_factory = EndpointFactory::new();
    let client_1 = endpoint_factory.endpoint();
    let client_1_addr = client_1.local_addr().unwrap();
    let client_2 = endpoint_factory.endpoint();
    let server = endpoint_factory.endpoint();
    let server_addr = server.local_addr().unwrap();
    let server_task = tokio::spawn(async move {
        loop {
            let accepting = server.accept().await.unwrap();
            if accepting.remote_address() == client_1_addr {
                // Blocklisted source: refuse outright
                accepting.refuse();
            } else if accepting.remote_address_validated() {
                accepting.await.expect("connection");
            } else {
                // Unvalidated source: force an address-validation retry first
                accepting.retry().unwrap();
            }
        }
    });
    tokio::join!(
        async move {
            let e = client_1
                .connect(server_addr, "localhost")
                .unwrap()
                .await
                .expect_err("server should have blocked this");
            assert!(
                matches!(e, crate::ConnectionError::ConnectionClosed(_)),
                "wrong error"
            );
        },
        async move {
            // client_2 goes through retry + validation and must succeed
            client_2
                .connect(server_addr, "localhost")
                .unwrap()
                .await
                .expect("connect");
        }
    );
    server_task.abort();
}
/// Construct an endpoint suitable for connecting to itself
fn endpoint() -> Endpoint {
    let factory = EndpointFactory::new();
    factory.endpoint()
}
/// Construct a self-connectable endpoint using the given transport configuration
fn endpoint_with_config(transport_config: TransportConfig) -> Endpoint {
    let factory = EndpointFactory::new();
    factory.endpoint_with_config(transport_config)
}
/// Constructs endpoints suitable for connecting to themselves and each other
struct EndpointFactory {
    // Self-signed "localhost" certificate shared by every endpoint this factory makes
    cert: rcgen::CertifiedKey<rcgen::KeyPair>,
    endpoint_config: EndpointConfig,
}
impl EndpointFactory {
    fn new() -> Self {
        Self {
            cert: rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(),
            endpoint_config: EndpointConfig::default(),
        }
    }
    /// Endpoint with default transport parameters
    fn endpoint(&self) -> Endpoint {
        self.endpoint_with_config(TransportConfig::default())
    }
    /// Endpoint that both serves the factory's certificate and trusts it as a client,
    /// so it (and its siblings) can connect to itself as "localhost"
    fn endpoint_with_config(&self, transport_config: TransportConfig) -> Endpoint {
        let key = PrivateKeyDer::Pkcs8(self.cert.signing_key.serialize_der().into());
        // Server and client share one transport config
        let transport_config = Arc::new(transport_config);
        let mut server_config =
            crate::ServerConfig::with_single_cert(vec![self.cert.cert.der().clone()], key).unwrap();
        server_config.transport_config(transport_config.clone());
        let mut roots = rustls::RootCertStore::empty();
        roots.add(self.cert.cert.der().clone()).unwrap();
        let mut endpoint = Endpoint::new(
            self.endpoint_config.clone(),
            Some(server_config),
            UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap(),
            Arc::new(TokioRuntime),
        )
        .unwrap();
        let mut client_config = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap();
        client_config.transport_config(transport_config);
        endpoint.set_default_client_config(client_config);
        endpoint
    }
}
// Exercises 0-RTT: the first connection establishes a session (0-RTT must be
// unavailable without cached keys); the second resumes it and sends data in 0-RTT.
#[tokio::test]
async fn zero_rtt() {
    let _guard = subscribe();
    let endpoint = endpoint();
    const MSG0: &[u8] = b"zero";
    const MSG1: &[u8] = b"one";
    let endpoint2 = endpoint.clone();
    // Server: accept both connections, echoing a 0.5-RTT message immediately and a
    // 1-RTT message once the handshake is established
    tokio::spawn(async move {
        for _ in 0..2 {
            let incoming = endpoint2.accept().await.unwrap().accept().unwrap();
            let (connection, established) = incoming.into_0rtt().unwrap_or_else(|_| unreachable!());
            let c = connection.clone();
            tokio::spawn(async move {
                while let Ok(mut x) = c.accept_uni().await {
                    let msg = x.read_to_end(usize::MAX).await.unwrap();
                    assert_eq!(msg, MSG0);
                }
            });
            info!("sending 0.5-RTT");
            let mut s = connection.open_uni().await.expect("open_uni");
            s.write_all(MSG0).await.expect("write");
            s.finish().unwrap();
            established.await;
            info!("sending 1-RTT");
            let mut s = connection.open_uni().await.expect("open_uni");
            s.write_all(MSG1).await.expect("write");
            // The peer might close the connection before ACKing
            let _ = s.finish();
        }
    });
    // First connection: no cached session yet, so `into_0rtt` must fail
    let connection = endpoint
        .connect(endpoint.local_addr().unwrap(), "localhost")
        .unwrap()
        .into_0rtt()
        .err()
        .expect("0-RTT succeeded without keys")
        .await
        .expect("connect");
    {
        let mut stream = connection.accept_uni().await.expect("incoming streams");
        let msg = stream.read_to_end(usize::MAX).await.expect("read_to_end");
        assert_eq!(msg, MSG0);
        // Read a 1-RTT message to ensure the handshake completes fully, allowing the server's
        // NewSessionTicket frame to be received.
        let mut stream = connection.accept_uni().await.expect("incoming streams");
        let msg = stream.read_to_end(usize::MAX).await.expect("read_to_end");
        assert_eq!(msg, MSG1);
        drop(connection);
    }
    info!("initial connection complete");
    // Second connection: cached keys present, so 0-RTT must be offered
    let (connection, zero_rtt) = endpoint
        .connect(endpoint.local_addr().unwrap(), "localhost")
        .unwrap()
        .into_0rtt()
        .unwrap_or_else(|_| panic!("missing 0-RTT keys"));
    // Send something ASAP to use 0-RTT
    let c = connection.clone();
    tokio::spawn(async move {
        let mut s = c.open_uni().await.expect("0-RTT open uni");
        info!("sending 0-RTT");
        s.write_all(MSG0).await.expect("0-RTT write");
        s.finish().unwrap();
    });
    let mut stream = connection.accept_uni().await.expect("incoming streams");
    let msg = stream.read_to_end(usize::MAX).await.expect("read_to_end");
    assert_eq!(msg, MSG0);
    // The server must have accepted the early data
    assert!(zero_rtt.await);
    drop((stream, connection));
    endpoint.wait_idle().await;
}
#[test]
#[cfg_attr(
    any(target_os = "solaris", target_os = "illumos"),
    ignore = "Fails on Solaris and Illumos"
)]
fn echo_v6() {
    // IPv6 wildcard client talking to an IPv6 loopback server.
    let client_addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0);
    let server_addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 0);
    run_echo(EchoArgs {
        client_addr,
        server_addr,
        nr_streams: 1,
        stream_size: 10 * 1024,
        receive_window: None,
        stream_receive_window: None,
    });
}
#[test]
#[cfg_attr(target_os = "solaris", ignore = "Sometimes hangs in poll() on Solaris")]
fn echo_v4() {
    // IPv4 wildcard client talking to an IPv4 loopback server.
    let client_addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
    let server_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    run_echo(EchoArgs {
        client_addr,
        server_addr,
        nr_streams: 1,
        stream_size: 10 * 1024,
        receive_window: None,
        stream_receive_window: None,
    });
}
#[test]
#[cfg_attr(target_os = "solaris", ignore = "Hangs in poll() on Solaris")]
fn echo_dualstack() {
    // IPv6 wildcard client reaching an IPv4 loopback server (dual-stack path).
    let client_addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0);
    let server_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
    run_echo(EchoArgs {
        client_addr,
        server_addr,
        nr_streams: 1,
        stream_size: 10 * 1024,
        receive_window: None,
        stream_receive_window: None,
    });
}
#[test]
#[ignore]
#[cfg_attr(target_os = "solaris", ignore = "Hangs in poll() on Solaris")]
fn stress_receive_window() {
    // Tiny connection-level window with a huge per-stream window: stresses
    // connection flow-control accounting across many streams.
    run_echo(EchoArgs {
        client_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
        server_addr: SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0),
        nr_streams: 50,
        stream_size: 25 * 1024 + 11,
        receive_window: Some(37),
        stream_receive_window: Some(100 * 1024 * 1024),
    });
}
#[test]
#[ignore]
#[cfg_attr(target_os = "solaris", ignore = "Hangs in poll() on Solaris")]
fn stress_stream_receive_window() {
    // Note that there is no point in running this with too many streams,
    // since the window is only active within a stream.
    // Huge connection window, tiny per-stream window.
    run_echo(EchoArgs {
        client_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
        server_addr: SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0),
        nr_streams: 2,
        stream_size: 250 * 1024 + 11,
        receive_window: Some(100 * 1024 * 1024),
        stream_receive_window: Some(37),
    });
}
#[test]
#[ignore]
#[cfg_attr(target_os = "solaris", ignore = "Hangs in poll() on Solaris")]
fn stress_both_windows() {
    // Both the connection window and the per-stream window are tiny,
    // forcing constant flow-control back-pressure on every stream.
    run_echo(EchoArgs {
        client_addr: SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
        server_addr: SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0),
        nr_streams: 50,
        stream_size: 25 * 1024 + 11,
        receive_window: Some(37),
        stream_receive_window: Some(37),
    });
}
/// Drives one complete client/server echo exchange described by `args`.
///
/// Creates two independent endpoints (so client and server can use different
/// addresses), opens `args.nr_streams` bidirectional streams, writes
/// `args.stream_size` bytes of deterministic data on each, and asserts the
/// echoed bytes match what was sent.
fn run_echo(args: EchoArgs) {
    let _guard = subscribe();
    let runtime = rt_basic();
    let handle = {
        // Use small receive windows
        let mut transport_config = TransportConfig::default();
        if let Some(receive_window) = args.receive_window {
            transport_config.receive_window(receive_window.try_into().unwrap());
        }
        if let Some(stream_receive_window) = args.stream_receive_window {
            transport_config.stream_receive_window(stream_receive_window.try_into().unwrap());
        }
        transport_config.max_concurrent_bidi_streams(1_u8.into());
        transport_config.max_concurrent_uni_streams(1_u8.into());
        let transport_config = Arc::new(transport_config);
        // We don't use the `endpoint` helper here because we want two different endpoints with
        // different addresses.
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let cert = CertificateDer::from(cert.cert);
        let mut server_config =
            crate::ServerConfig::with_single_cert(vec![cert.clone()], key.into()).unwrap();
        server_config.transport = transport_config.clone();
        // Bind before spawning so the ephemeral server port is known to the client.
        let server_sock = UdpSocket::bind(args.server_addr).unwrap();
        let server_addr = server_sock.local_addr().unwrap();
        let server = {
            let _guard = runtime.enter();
            let _guard = error_span!("server").entered();
            Endpoint::new(
                Default::default(),
                Some(server_config),
                server_sock,
                Arc::new(TokioRuntime),
            )
            .unwrap()
        };
        // The client trusts only the freshly generated self-signed certificate.
        let mut roots = rustls::RootCertStore::empty();
        roots.add(cert).unwrap();
        let mut client_crypto =
            rustls::ClientConfig::builder_with_provider(default_provider().into())
                .with_safe_default_protocol_versions()
                .unwrap()
                .with_root_certificates(roots)
                .with_no_client_auth();
        // Write TLS secrets for debugging captured traffic (honors SSLKEYLOGFILE).
        client_crypto.key_log = Arc::new(rustls::KeyLogFile::new());
        let mut client = {
            let _guard = runtime.enter();
            let _guard = error_span!("client").entered();
            Endpoint::client(args.client_addr).unwrap()
        };
        let mut client_config =
            ClientConfig::new(Arc::new(QuicClientConfig::try_from(client_crypto).unwrap()));
        client_config.transport_config(transport_config);
        client.set_default_client_config(client_config);
        let handle = runtime.spawn(async move {
            let incoming = server.accept().await.unwrap();
            // Note for anyone modifying the platform support in this test:
            // If `local_ip` gets available on additional platforms - which
            // requires modifying this test - please update the list of supported
            // platforms in the doc comment of `quinn_udp::RecvMeta::dst_ip`.
            if cfg!(target_os = "linux")
                || cfg!(target_os = "android")
                || cfg!(target_os = "freebsd")
                || cfg!(target_os = "openbsd")
                || cfg!(target_os = "netbsd")
                || cfg!(target_os = "macos")
                || cfg!(target_os = "windows")
            {
                let local_ip = incoming.local_ip().expect("Local IP must be available");
                assert!(local_ip.is_loopback());
            } else {
                assert_eq!(None, incoming.local_ip());
            }
            let new_conn = incoming.await.unwrap();
            // Echo every bidirectional stream the client opens.
            tokio::spawn(async move {
                while let Ok(stream) = new_conn.accept_bi().await {
                    tokio::spawn(echo(stream));
                }
            });
            server.wait_idle().await;
        });
        info!("connecting from {} to {}", args.client_addr, server_addr);
        runtime.block_on(
            async move {
                let new_conn = client
                    .connect(server_addr, "localhost")
                    .unwrap()
                    .await
                    .expect("connect");
                /// This is just an arbitrary number to generate deterministic test data
                const SEED: u64 = 0x12345678;
                for i in 0..args.nr_streams {
                    println!("Opening stream {i}");
                    let (mut send, mut recv) = new_conn.open_bi().await.expect("stream open");
                    let msg = gen_data(args.stream_size, SEED);
                    let send_task = async {
                        send.write_all(&msg).await.expect("write");
                        send.finish().unwrap();
                    };
                    let recv_task = async { recv.read_to_end(usize::MAX).await.expect("read") };
                    // Send and receive concurrently so the small flow-control
                    // windows can't deadlock the exchange.
                    let (_, data) = tokio::join!(send_task, recv_task);
                    assert_eq!(data[..], msg[..], "Data mismatch");
                }
                new_conn.close(0u32.into(), b"done");
                client.wait_idle().await;
            }
            .instrument(error_span!("client")),
        );
        handle
    };
    runtime.block_on(handle).unwrap();
}
/// Parameters controlling a single `run_echo` exchange.
struct EchoArgs {
    /// Local address the client endpoint binds to
    client_addr: SocketAddr,
    /// Local address the server endpoint binds to
    server_addr: SocketAddr,
    /// Number of bidirectional echo streams to run, one after another
    nr_streams: usize,
    /// Payload size per stream, in bytes
    stream_size: usize,
    /// Connection-level receive window override (`TransportConfig::receive_window`)
    receive_window: Option<u64>,
    /// Per-stream receive window override (`TransportConfig::stream_receive_window`)
    stream_receive_window: Option<u64>,
}
async fn echo((mut send, mut recv): (SendStream, RecvStream)) {
loop {
// These are 32 buffers, for reading approximately 32kB at once
#[rustfmt::skip]
let mut bufs = [
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
];
match recv.read_chunks(&mut bufs).await.expect("read chunks") {
Some(n) => {
send.write_all_chunks(&mut bufs[..n])
.await
.expect("write chunks");
}
None => break,
}
}
let _ = send.finish();
}
fn gen_data(size: usize, seed: u64) -> Vec<u8> {
let mut rng: StdRng = SeedableRng::seed_from_u64(seed);
let mut buf = vec![0; size];
rng.fill_bytes(&mut buf);
buf
}
/// Installs a scoped tracing subscriber that filters via the default env
/// filter and writes through `TestWriter`, so log output is captured per-test.
fn subscribe() -> tracing::subscriber::DefaultGuard {
    let subscriber = tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(EnvFilter::from_default_env())
        .with_writer(|| TestWriter)
        .finish();
    tracing::subscriber::set_default(subscriber)
}
/// Forwards log output through `print!` so the test harness captures it.
struct TestWriter;

impl std::io::Write for TestWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Log output is expected to be valid UTF-8; panicking here surfaces
        // a logging bug rather than silently mangling the text.
        let text = str::from_utf8(buf).expect("tried to log invalid UTF-8");
        print!("{text}");
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        io::stdout().flush()
    }
}
/// Current-thread tokio runtime with all drivers (time, I/O) enabled.
fn rt_basic() -> Runtime {
    let mut builder = Builder::new_current_thread();
    builder.enable_all();
    builder.build().unwrap()
}
/// Multi-threaded tokio runtime with all drivers (time, I/O) enabled.
fn rt_threaded() -> Runtime {
    let mut builder = Builder::new_multi_thread();
    builder.enable_all();
    builder.build().unwrap()
}
// Verifies that data sent after the client rebinds its UDP socket is still
// received on the new socket.
#[tokio::test]
async fn rebind_recv() {
    let _guard = subscribe();
    // Self-signed identity shared by the server config and the client's trust store.
    let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
    let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
    let cert = CertificateDer::from(cert.cert);
    let mut roots = rustls::RootCertStore::empty();
    roots.add(cert.clone()).unwrap();
    let mut client = Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap();
    let mut client_config = ClientConfig::with_root_certificates(Arc::new(roots)).unwrap();
    client_config.transport_config(Arc::new({
        let mut cfg = TransportConfig::default();
        cfg.max_concurrent_uni_streams(1u32.into());
        cfg
    }));
    client.set_default_client_config(client_config);
    let server_config =
        crate::ServerConfig::with_single_cert(vec![cert.clone()], key.into()).unwrap();
    let server = {
        let _guard = tracing::error_span!("server").entered();
        Endpoint::server(
            server_config,
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
        )
        .unwrap()
    };
    let server_addr = server.local_addr().unwrap();
    const MSG: &[u8; 5] = b"hello";
    // Notify pairs enforce the ordering: connect -> rebind -> server write.
    let write_send = Arc::new(tokio::sync::Notify::new());
    let write_recv = write_send.clone();
    let connected_send = Arc::new(tokio::sync::Notify::new());
    let connected_recv = connected_send.clone();
    let server = tokio::spawn(async move {
        let connection = server.accept().await.unwrap().await.unwrap();
        info!("got conn");
        connected_send.notify_one();
        // Hold off writing until the client has rebound its socket.
        write_recv.notified().await;
        let mut stream = connection.open_uni().await.unwrap();
        stream.write_all(MSG).await.unwrap();
        stream.finish().unwrap();
        // Wait for the stream to be closed, one way or another.
        _ = stream.stopped().await;
    });
    let connection = {
        let _guard = tracing::error_span!("client").entered();
        client
            .connect(server_addr, "localhost")
            .unwrap()
            .await
            .unwrap()
    };
    info!("connected");
    connected_recv.notified().await;
    // Swap the client's UDP socket mid-connection.
    client
        .rebind(UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap())
        .unwrap();
    info!("rebound");
    write_send.notify_one();
    // The server's message must arrive on the rebound socket.
    let mut stream = connection.accept_uni().await.unwrap();
    assert_eq!(stream.read_to_end(MSG.len()).await.unwrap(), MSG);
    server.await.unwrap();
}
// Verifies that `open_uni` unblocks as stream-ID credit is released when
// previously accepted streams are dropped by the peer.
#[tokio::test]
async fn stream_id_flow_control() {
    let _guard = subscribe();
    // Allow only one concurrent unidirectional stream.
    let mut cfg = TransportConfig::default();
    cfg.max_concurrent_uni_streams(1u32.into());
    let endpoint = endpoint_with_config(cfg);
    // The endpoint connects to itself; drive connect and accept concurrently.
    let (client, server) = tokio::join!(
        endpoint
            .connect(endpoint.local_addr().unwrap(), "localhost")
            .unwrap(),
        async { endpoint.accept().await.unwrap().await }
    );
    let client = client.unwrap();
    let server = server.unwrap();
    // If `open_uni` doesn't get unblocked when the previous stream is dropped, this will time out.
    tokio::join!(
        async {
            client.open_uni().await.unwrap();
        },
        async {
            client.open_uni().await.unwrap();
        },
        async {
            client.open_uni().await.unwrap();
        },
        async {
            // Accepting (and immediately dropping) streams on the server side
            // frees up stream IDs for the client's remaining `open_uni` calls.
            server.accept_uni().await.unwrap();
            server.accept_uni().await.unwrap();
        }
    );
}
// Two concurrent `read_datagram` callers on one connection: each datagram
// must be delivered to exactly one of them, and both must eventually resolve.
#[tokio::test]
async fn two_datagram_readers() {
    let _guard = subscribe();
    let endpoint = endpoint();
    // Loopback connection: the endpoint connects to itself.
    let (client, server) = tokio::join!(
        endpoint
            .connect(endpoint.local_addr().unwrap(), "localhost")
            .unwrap(),
        async { endpoint.accept().await.unwrap().await }
    );
    let client = client.unwrap();
    let server = server.unwrap();
    let done = tokio::sync::Notify::new();
    let (a, b, ()) = tokio::join!(
        async {
            let x = client.read_datagram().await.unwrap();
            done.notify_waiters();
            x
        },
        async {
            let x = client.read_datagram().await.unwrap();
            done.notify_waiters();
            x
        },
        async {
            server.send_datagram(b"one"[..].into()).unwrap();
            // Wait until one reader has received a datagram before sending
            // the second, so both readers get a chance to be woken.
            done.notified().await;
            server.send_datagram_wait(b"two"[..].into()).await.unwrap();
        }
    );
    // Both datagrams arrived, split across the two readers in either order.
    assert!(*a == *b"one" || *b == *b"one");
    assert!(*a == *b"two" || *b == *b"two");
}
// With zero-length connection IDs, connections can only be demultiplexed by
// remote address; two clients must still be able to connect concurrently.
#[tokio::test]
async fn multiple_conns_with_zero_length_cids() {
    let _guard = subscribe();
    let mut factory = EndpointFactory::new();
    // Length-0 CIDs for every endpoint created by this factory.
    factory
        .endpoint_config
        .cid_generator(|| Box::new(RandomConnectionIdGenerator::new(0)));
    let server = {
        let _guard = error_span!("server").entered();
        factory.endpoint()
    };
    let server_addr = server.local_addr().unwrap();
    let client1 = {
        let _guard = error_span!("client1").entered();
        factory.endpoint()
    };
    let client2 = {
        let _guard = error_span!("client2").entered();
        factory.endpoint()
    };
    // Each client connects, then waits for the server to close its connection.
    let client1 = async move {
        let conn = client1
            .connect(server_addr, "localhost")
            .unwrap()
            .await
            .unwrap();
        conn.closed().await;
    }
    .instrument(error_span!("client1"));
    let client2 = async move {
        let conn = client2
            .connect(server_addr, "localhost")
            .unwrap()
            .await
            .unwrap();
        conn.closed().await;
    }
    .instrument(error_span!("client2"));
    let server = async move {
        let client1 = server.accept().await.unwrap().await.unwrap();
        let client2 = server.accept().await.unwrap().await.unwrap();
        // Both connections are now concurrently live.
        client1.close(42u32.into(), &[]);
        client2.close(42u32.into(), &[]);
    }
    .instrument(error_span!("server"));
    tokio::join!(client1, client2, server);
}
// Multiple `stopped` futures taken from the same `SendStream` must all
// resolve with the peer's stop error code, even one awaited after the stream
// itself was dropped.
#[tokio::test]
async fn stream_stopped() {
    let _guard = subscribe();
    let factory = EndpointFactory::new();
    let server = {
        let _guard = error_span!("server").entered();
        factory.endpoint()
    };
    let server_addr = server.local_addr().unwrap();
    let client = {
        let _guard = error_span!("client1").entered();
        factory.endpoint()
    };
    let client = async move {
        let conn = client
            .connect(server_addr, "localhost")
            .unwrap()
            .await
            .unwrap();
        let mut stream = conn.open_uni().await.unwrap();
        // Take three independent `stopped` futures from the same stream.
        let stopped1 = stream.stopped();
        let stopped2 = stream.stopped();
        let stopped3 = stream.stopped();
        stream.write_all(b"hi").await.unwrap();
        // spawn one of the futures into a task
        let stopped1 = tokio::task::spawn(stopped1);
        // verify that both futures resolved
        // (42 is the error code the server passes to `stop` below)
        let (stopped1, stopped2) = tokio::join!(stopped1, stopped2);
        assert!(matches!(stopped1, Ok(Ok(Some(val))) if val == 42u32.into()));
        assert!(matches!(stopped2, Ok(Some(val)) if val == 42u32.into()));
        // drop the stream
        drop(stream);
        // verify that a future also resolves after dropping the stream
        let stopped3 = stopped3.await;
        assert_eq!(stopped3, Ok(Some(42u32.into())));
    };
    // The whole client side must finish quickly; a hung `stopped` future
    // fails the test via this timeout.
    let client =
        tokio::time::timeout(Duration::from_millis(100), client).instrument(error_span!("client"));
    let server = async move {
        let conn = server.accept().await.unwrap().await.unwrap();
        let mut stream = conn.accept_uni().await.unwrap();
        let mut buf = [0u8; 2];
        stream.read_exact(&mut buf).await.unwrap();
        // Reject the stream; this is what the client's futures observe.
        stream.stop(42u32.into()).unwrap();
        conn
    }
    .instrument(error_span!("server"));
    let (client, conn) = tokio::join!(client, server);
    client.expect("timeout");
    drop(conn);
}
// A `stopped` future that was already polled (waker registered) must still
// resolve after the `SendStream` it came from is dropped.
#[tokio::test]
async fn stream_stopped_2() {
    let _guard = subscribe();
    let endpoint = endpoint();
    // Loopback connection to ourselves.
    let (conn, _server_conn) = tokio::try_join!(
        endpoint
            .connect(endpoint.local_addr().unwrap(), "localhost")
            .unwrap(),
        async { endpoint.accept().await.unwrap().await }
    )
    .unwrap();
    let send_stream = conn.open_uni().await.unwrap();
    let stopped = tokio::time::timeout(Duration::from_millis(100), send_stream.stopped())
        .instrument(error_span!("stopped"));
    tokio::pin!(stopped);
    // poll the future once so that the waker is registered.
    tokio::select! {
        biased;
        _x = &mut stopped => {},
        _x = std::future::ready(()) => {}
    }
    // drop the send stream
    drop(send_stream);
    // make sure the stopped future still resolves
    // (`None`: the stream went away without the peer ever stopping it)
    let res = stopped.await;
    assert_eq!(res, Ok(Ok(None)));
}

224
vendor/quinn/src/work_limiter.rs vendored Normal file
View File

@@ -0,0 +1,224 @@
use crate::{Duration, Instant};
/// Dynamically limits how much of one kind of work is performed per cycle.
///
/// On a sampled subset of cycles the limiter measures roughly how long a
/// single work item takes and derives how many items fit into the desired
/// cycle time; those estimates are exponentially smoothed across all measured
/// cycles. Unmeasured cycles simply reuse the last computed limit.
///
/// What counts as one "work item" is up to the caller (e.g. bytes or
/// datagrams per cycle); the limiter behaves best when items take roughly
/// constant time to complete.
#[derive(Debug)]
pub(crate) struct WorkLimiter {
    /// Whether this cycle is being timed or replaying historic estimates
    mode: Mode,
    /// Cycle counter, used to select which cycles are measured
    cycle: u16,
    /// When the current cycle began; `Some` only while measuring
    start_time: Option<Instant>,
    /// Work items completed so far in the current cycle
    completed: usize,
    /// Work items permitted per cycle, derived from measurements
    allowed: usize,
    /// Target duration of one cycle
    desired_cycle_time: Duration,
    /// Smoothed estimate of nanoseconds required per work item
    smoothed_time_per_work_item_nanos: f64,
}

impl WorkLimiter {
    pub(crate) fn new(desired_cycle_time: Duration) -> Self {
        Self {
            mode: Mode::Measure,
            cycle: 0,
            start_time: None,
            completed: 0,
            allowed: 0,
            desired_cycle_time,
            smoothed_time_per_work_item_nanos: 0.0,
        }
    }

    /// Begins a new work cycle, timestamping it when in measurement mode.
    pub(crate) fn start_cycle(&mut self, now: impl Fn() -> Instant) {
        self.completed = 0;
        if self.mode == Mode::Measure {
            self.start_time = Some(now());
        }
    }

    /// Whether another work item may be performed within this cycle
    ///
    /// Measured cycles compare elapsed wall time against the target cycle
    /// time; unmeasured cycles compare completed items against the cached
    /// limit. Requires work to be tracked via [`Self::record_work`].
    pub(crate) fn allow_work(&mut self, now: impl Fn() -> Instant) -> bool {
        match self.mode {
            Mode::Measure => (now() - self.start_time.unwrap()) < self.desired_cycle_time,
            Mode::HistoricData => self.completed < self.allowed,
        }
    }

    /// Adds `work` completed items to the current cycle's tally
    ///
    /// Must be called between `start_cycle` and `finish_cycle`.
    pub(crate) fn record_work(&mut self, work: usize) {
        self.completed += work;
    }

    /// Ends the current work cycle
    ///
    /// For measured cycles this refreshes the per-item time estimate and the
    /// per-cycle work limit. Smoothing matches QUIC RTT estimation: the new
    /// sample is weighted 1/8 against 7/8 of the previous average.
    pub(crate) fn finish_cycle(&mut self, now: impl Fn() -> Instant) {
        // A cycle that did no work carries no usable timing data; discard it.
        if self.completed == 0 {
            return;
        }

        if self.mode == Mode::Measure {
            let elapsed = now() - self.start_time.unwrap();
            let per_item_nanos = elapsed.as_nanos() as f64 / self.completed as f64;

            // Seed the estimate on the first measurement, smooth afterwards.
            // Clamp to at least 1ns so the division below can never be by 0.
            self.smoothed_time_per_work_item_nanos = if self.allowed == 0 {
                per_item_nanos
            } else {
                (7.0 * self.smoothed_time_per_work_item_nanos + per_item_nanos) / 8.0
            }
            .max(1.0);

            // Permit at least one item per cycle so callers always progress.
            self.allowed = ((self.desired_cycle_time.as_nanos() as f64
                / self.smoothed_time_per_work_item_nanos) as usize)
                .max(1);
            self.start_time = None;
        }

        self.cycle = self.cycle.wrapping_add(1);
        self.mode = if self.cycle % SAMPLING_INTERVAL == 0 {
            Mode::Measure
        } else {
            Mode::HistoricData
        };
    }
}

/// We take a measurement sample once every `SAMPLING_INTERVAL` cycles
const SAMPLING_INTERVAL: u16 = 256;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    /// Time the cycle and refresh the estimates
    Measure,
    /// Enforce the previously computed per-cycle limit
    HistoricData,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;

    #[test]
    fn limit_work() {
        const CYCLE_TIME: Duration = Duration::from_millis(500);
        const BATCH_WORK_ITEMS: usize = 12;
        const BATCH_TIME: Duration = Duration::from_millis(100);
        // With each batch consuming BATCH_TIME, this many batches fit in one cycle.
        const EXPECTED_INITIAL_BATCHES: usize =
            (CYCLE_TIME.as_nanos() / BATCH_TIME.as_nanos()) as usize;
        const EXPECTED_ALLOWED_WORK_ITEMS: usize = EXPECTED_INITIAL_BATCHES * BATCH_WORK_ITEMS;
        let mut limiter = WorkLimiter::new(CYCLE_TIME);
        reset_time();
        // The initial cycle is measuring
        limiter.start_cycle(get_time);
        let mut initial_batches = 0;
        while limiter.allow_work(get_time) {
            limiter.record_work(BATCH_WORK_ITEMS);
            advance_time(BATCH_TIME);
            initial_batches += 1;
        }
        limiter.finish_cycle(get_time);
        assert_eq!(initial_batches, EXPECTED_INITIAL_BATCHES);
        assert_eq!(limiter.allowed, EXPECTED_ALLOWED_WORK_ITEMS);
        let initial_time_per_work_item = limiter.smoothed_time_per_work_item_nanos;
        // The next cycles are using historic data
        // (each batch size divides the allowed total evenly, so equality holds).
        const BATCH_SIZES: [usize; 4] = [1, 2, 3, 5];
        for &batch_size in &BATCH_SIZES {
            limiter.start_cycle(get_time);
            let mut allowed_work = 0;
            while limiter.allow_work(get_time) {
                limiter.record_work(batch_size);
                allowed_work += batch_size;
            }
            limiter.finish_cycle(get_time);
            assert_eq!(allowed_work, EXPECTED_ALLOWED_WORK_ITEMS);
        }
        // After `SAMPLING_INTERVAL`, we get into measurement mode again
        for _ in 0..(SAMPLING_INTERVAL as usize - BATCH_SIZES.len() - 1) {
            limiter.start_cycle(get_time);
            limiter.record_work(1);
            limiter.finish_cycle(get_time);
        }
        // We now do more work per cycle, and expect the estimate of allowed
        // work items to go up
        const BATCH_WORK_ITEMS_2: usize = 96;
        const TIME_PER_WORK_ITEMS_2_NANOS: f64 =
            CYCLE_TIME.as_nanos() as f64 / (EXPECTED_INITIAL_BATCHES * BATCH_WORK_ITEMS_2) as f64;
        // The new sample should be blended into the old estimate with 1/8 weight.
        let expected_updated_time_per_work_item =
            (initial_time_per_work_item * 7.0 + TIME_PER_WORK_ITEMS_2_NANOS) / 8.0;
        let expected_updated_allowed_work_items =
            (CYCLE_TIME.as_nanos() as f64 / expected_updated_time_per_work_item) as usize;
        limiter.start_cycle(get_time);
        let mut initial_batches = 0;
        while limiter.allow_work(get_time) {
            limiter.record_work(BATCH_WORK_ITEMS_2);
            advance_time(BATCH_TIME);
            initial_batches += 1;
        }
        limiter.finish_cycle(get_time);
        assert_eq!(initial_batches, EXPECTED_INITIAL_BATCHES);
        assert_eq!(limiter.allowed, expected_updated_allowed_work_items);
    }

    thread_local! {
        /// Mocked time
        pub static TIME: RefCell<Instant> = RefCell::new(Instant::now());
    }

    /// Resets the mocked clock to the current real time.
    fn reset_time() {
        TIME.with(|t| {
            *t.borrow_mut() = Instant::now();
        })
    }

    /// Reads the mocked clock.
    fn get_time() -> Instant {
        TIME.with(|t| *t.borrow())
    }

    /// Advances the mocked clock by `duration`.
    fn advance_time(duration: Duration) {
        TIME.with(|t| {
            *t.borrow_mut() += duration;
        })
    }
}