chore: checkpoint before Python removal

This commit is contained in:
2026-03-26 22:33:59 +00:00
parent 683cec9307
commit e568ddf82a
29972 changed files with 11269302 additions and 2 deletions

198
vendor/tokio-tungstenite/src/compat.rs vendored Normal file
View File

@@ -0,0 +1,198 @@
use log::*;
use std::{
io::{Read, Write},
pin::Pin,
task::{Context, Poll},
};
use futures_util::task;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tungstenite::Error as WsError;
pub(crate) enum ContextWaker {
Read,
Write,
}
#[derive(Debug)]
pub(crate) struct AllowStd<S> {
inner: S,
// We have the problem that external read operations (i.e. the Stream impl)
// can trigger both read (AsyncRead) and write (AsyncWrite) operations on
// the underyling stream. At the same time write operations (i.e. the Sink
// impl) can trigger write operations (AsyncWrite) too.
// Both the Stream and the Sink can be used on two different tasks, but it
// is required that AsyncRead and AsyncWrite are only ever used by a single
// task (or better: with a single waker) at a time.
//
// Doing otherwise would cause only the latest waker to be remembered, so
// in our case either the Stream or the Sink impl would potentially wait
// forever to be woken up because only the other one would've been woken
// up.
//
// To solve this we implement a waker proxy that has two slots (one for
// read, one for write) to store wakers. One waker proxy is always passed
// to the AsyncRead, the other to AsyncWrite so that they will only ever
// have to store a single waker, but internally we dispatch any wakeups to
// up to two actual wakers (one from the Sink impl and one from the Stream
// impl).
//
// write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
// AsyncRead. The read_waker slots of both are used for the Stream impl
// (and handshaking), the write_waker slots for the Sink impl.
write_waker_proxy: Arc<WakerProxy>,
read_waker_proxy: Arc<WakerProxy>,
}
// Internal trait used only in the Handshake module for registering
// the waker for the context used during handshaking. We're using the
// read waker slot for this, but any would do.
//
// Don't ever use this from multiple tasks at the same time!
pub(crate) trait SetWaker {
fn set_waker(&self, waker: &task::Waker);
}
impl<S> SetWaker for AllowStd<S> {
fn set_waker(&self, waker: &task::Waker) {
self.set_waker(ContextWaker::Read, waker);
}
}
impl<S> AllowStd<S> {
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
let res = Self {
inner,
write_waker_proxy: Default::default(),
read_waker_proxy: Default::default(),
};
// Register the handshake waker as read waker for both proxies,
// see also the SetWaker trait.
res.write_waker_proxy.read_waker.register(waker);
res.read_waker_proxy.read_waker.register(waker);
res
}
// Set the read or write waker for our proxies.
//
// Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
// impl on the WebSocketStream.
// Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
//
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
// WebSocketStream.
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
match kind {
ContextWaker::Read => {
self.write_waker_proxy.read_waker.register(waker);
self.read_waker_proxy.read_waker.register(waker);
}
ContextWaker::Write => {
self.write_waker_proxy.write_waker.register(waker);
self.read_waker_proxy.write_waker.register(waker);
}
}
}
}
// Proxy Waker that we pass to the internal AsyncRead/Write of the
// stream underlying the websocket. We have two slots here for the
// actual wakers to allow external read operations to trigger both
// reads and writes, and the same for writes.
#[derive(Debug, Default)]
struct WakerProxy {
read_waker: task::AtomicWaker,
write_waker: task::AtomicWaker,
}
impl task::ArcWake for WakerProxy {
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.read_waker.wake();
arc_self.write_waker.wake();
}
}
impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
{
trace!("{}:{} AllowStd.with_context", file!(), line!());
let waker = match kind {
ContextWaker::Read => task::waker_ref(&self.read_waker_proxy),
ContextWaker::Write => task::waker_ref(&self.write_waker_proxy),
};
let mut context = task::Context::from_waker(&waker);
f(&mut context, Pin::new(&mut self.inner))
}
pub(crate) fn get_mut(&mut self) -> &mut S {
&mut self.inner
}
pub(crate) fn get_ref(&self) -> &S {
&self.inner
}
}
impl<S> Read for AllowStd<S>
where
S: AsyncRead + Unpin,
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
trace!("{}:{} Read.read", file!(), line!());
let mut buf = ReadBuf::new(buf);
match self.with_context(ContextWaker::Read, |ctx, stream| {
trace!("{}:{} Read.with_context read -> poll_read", file!(), line!());
stream.poll_read(ctx, &mut buf)
}) {
Poll::Ready(Ok(_)) => Ok(buf.filled().len()),
Poll::Ready(Err(err)) => Err(err),
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
}
impl<S> Write for AllowStd<S>
where
S: AsyncWrite + Unpin,
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
trace!("{}:{} Write.write", file!(), line!());
match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!("{}:{} Write.with_context write -> poll_write", file!(), line!());
stream.poll_write(ctx, buf)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
fn flush(&mut self) -> std::io::Result<()> {
trace!("{}:{} Write.flush", file!(), line!());
match self.with_context(ContextWaker::Write, |ctx, stream| {
trace!("{}:{} Write.with_context flush -> poll_flush", file!(), line!());
stream.poll_flush(ctx)
}) {
Poll::Ready(r) => r,
Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
}
}
}
pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
trace!("WouldBlock");
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}

98
vendor/tokio-tungstenite/src/connect.rs vendored Normal file
View File

@@ -0,0 +1,98 @@
//! Connection helper.
use tokio::net::TcpStream;
use tungstenite::{
error::{Error, UrlError},
handshake::client::{Request, Response},
protocol::WebSocketConfig,
};
use crate::{domain, stream::MaybeTlsStream, Connector, IntoClientRequest, WebSocketStream};
/// Connect to a given URL.
///
/// Accepts any request that implements [`IntoClientRequest`], which is often just `&str`, but can
/// be a variety of types such as `httparse::Request` or [`tungstenite::http::Request`] for more
/// complex uses.
///
/// ```no_run
/// # use tungstenite::client::IntoClientRequest;
///
/// # async fn test() {
/// use tungstenite::http::{Method, Request};
/// use tokio_tungstenite::connect_async;
///
/// let mut request = "wss://api.example.com".into_client_request().unwrap();
/// request.headers_mut().insert("api-key", "42".parse().unwrap());
///
/// let (stream, response) = connect_async(request).await.unwrap();
/// # }
/// ```
pub async fn connect_async<R>(
request: R,
) -> Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>
where
R: IntoClientRequest + Unpin,
{
connect_async_with_config(request, None, false).await
}
/// The same as `connect_async()` but the one can specify a websocket configuration.
/// Please refer to `connect_async()` for more details. `disable_nagle` specifies if
/// the Nagle's algorithm must be disabled, i.e. `set_nodelay(true)`. If you don't know
/// what the Nagle's algorithm is, better leave it set to `false`.
pub async fn connect_async_with_config<R>(
request: R,
config: Option<WebSocketConfig>,
disable_nagle: bool,
) -> Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>
where
R: IntoClientRequest + Unpin,
{
connect(request.into_client_request()?, config, disable_nagle, None).await
}
/// The same as `connect_async()` but the one can specify a websocket configuration,
/// and a TLS connector to use. Please refer to `connect_async()` for more details.
/// `disable_nagle` specifies if the Nagle's algorithm must be disabled, i.e.
/// `set_nodelay(true)`. If you don't know what the Nagle's algorithm is, better
/// leave it to `false`.
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub async fn connect_async_tls_with_config<R>(
request: R,
config: Option<WebSocketConfig>,
disable_nagle: bool,
connector: Option<Connector>,
) -> Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error>
where
R: IntoClientRequest + Unpin,
{
connect(request.into_client_request()?, config, disable_nagle, connector).await
}
async fn connect(
request: Request,
config: Option<WebSocketConfig>,
disable_nagle: bool,
connector: Option<Connector>,
) -> Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response), Error> {
let domain = domain(&request)?;
let port = request
.uri()
.port_u16()
.or_else(|| match request.uri().scheme_str() {
Some("wss") => Some(443),
Some("ws") => Some(80),
_ => None,
})
.ok_or(Error::Url(UrlError::UnsupportedUrlScheme))?;
let addr = format!("{domain}:{port}");
let socket = TcpStream::connect(addr).await.map_err(Error::Io)?;
if disable_nagle {
socket.set_nodelay(true)?;
}
crate::tls::client_async_tls_with_config(request, socket, config, connector).await
}

View File

@@ -0,0 +1,179 @@
#[cfg(feature = "handshake")]
use crate::compat::SetWaker;
use crate::{compat::AllowStd, WebSocketStream};
use log::*;
use std::{
future::Future,
io::{Read, Write},
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::WebSocket;
#[cfg(feature = "handshake")]
use tungstenite::{
handshake::{
client::Response, server::Callback, HandshakeError as Error, HandshakeRole,
MidHandshake as WsHandshake,
},
ClientHandshake, ServerHandshake,
};
pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
let ws = start.await;
WebSocketStream::new(ws)
}
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
struct SkippedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
impl<F, S> Future for SkippedHandshakeFuture<F, S>
where
F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = WebSocket<AllowStd<S>>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.get_mut().0.take().expect("future polled after completion");
trace!("Setting context when skipping handshake");
let stream = AllowStd::new(inner.stream, ctx.waker());
Poll::Ready((inner.f)(stream))
}
}
#[cfg(feature = "handshake")]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);
#[cfg(feature = "handshake")]
enum StartedHandshake<Role: HandshakeRole> {
Done(Role::FinalResult),
Mid(WsHandshake<Role>),
}
#[cfg(feature = "handshake")]
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
#[cfg(feature = "handshake")]
struct StartedHandshakeFutureInner<F, S> {
f: F,
stream: S,
}
#[cfg(feature = "handshake")]
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: SetWaker + Unpin,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let start = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));
match start.await? {
StartedHandshake::Done(r) => Ok(r),
StartedHandshake::Mid(s) => {
let res: Result<Role::FinalResult, Error<Role>> = MidHandshake::<Role>(Some(s)).await;
res
}
}
}
#[cfg(feature = "handshake")]
pub(crate) async fn client_handshake<F, S>(
stream: S,
f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
F: FnOnce(
AllowStd<S>,
) -> Result<
<ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
Error<ClientHandshake<AllowStd<S>>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let result = handshake(stream, f).await?;
let (s, r) = result;
Ok((WebSocketStream::new(s), r))
}
#[cfg(feature = "handshake")]
pub(crate) async fn server_handshake<C, F, S>(
stream: S,
f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
C: Callback + Unpin,
F: FnOnce(
AllowStd<S>,
) -> Result<
<ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
Error<ServerHandshake<AllowStd<S>, C>>,
> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let s: WebSocket<AllowStd<S>> = handshake(stream, f).await?;
Ok(WebSocketStream::new(s))
}
#[cfg(feature = "handshake")]
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
Role: HandshakeRole,
Role::InternalStream: SetWaker + Unpin,
F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
S: Unpin,
AllowStd<S>: Read + Write,
{
type Output = Result<StartedHandshake<Role>, Error<Role>>;
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = self.0.take().expect("future polled after completion");
trace!("Setting ctx when starting handshake");
let stream = AllowStd::new(inner.stream, ctx.waker());
match (inner.f)(stream) {
Ok(r) => Poll::Ready(Ok(StartedHandshake::Done(r))),
Err(Error::Interrupted(mid)) => Poll::Ready(Ok(StartedHandshake::Mid(mid))),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
}
}
}
#[cfg(feature = "handshake")]
impl<Role> Future for MidHandshake<Role>
where
Role: HandshakeRole + Unpin,
Role::InternalStream: SetWaker + Unpin,
{
type Output = Result<Role::FinalResult, Error<Role>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut s = self.as_mut().0.take().expect("future polled after completion");
let machine = s.get_mut();
trace!("Setting context in handshake");
machine.get_mut().set_waker(cx.waker());
match s.handshake() {
Ok(stream) => Poll::Ready(Ok(stream)),
Err(Error::Failure(e)) => Poll::Ready(Err(Error::Failure(e))),
Err(Error::Interrupted(mid)) => {
self.0 = Some(mid);
Poll::Pending
}
}
}
}

440
vendor/tokio-tungstenite/src/lib.rs vendored Normal file
View File

@@ -0,0 +1,440 @@
//! Async WebSocket usage.
//!
//! This library is an implementation of WebSocket handshakes and streams. It
//! is based on the crate which implements all required WebSocket protocol
//! logic. So this crate basically just brings tokio support / tokio integration
//! to it.
//!
//! Each WebSocket stream implements the required `Stream` and `Sink` traits,
//! so the socket is just a stream of messages coming in and going out.
#![deny(missing_docs, unused_must_use, unused_mut, unused_imports, unused_import_braces)]
pub use tungstenite;
mod compat;
#[cfg(feature = "connect")]
mod connect;
mod handshake;
#[cfg(feature = "stream")]
mod stream;
#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
mod tls;
use std::io::{Read, Write};
use compat::{cvt, AllowStd, ContextWaker};
use futures_util::{
sink::{Sink, SinkExt},
stream::{FusedStream, Stream},
};
use log::*;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "handshake")]
use tungstenite::{
client::IntoClientRequest,
handshake::{
client::{ClientHandshake, Response},
server::{Callback, NoCallback},
HandshakeError,
},
};
use tungstenite::{
error::Error as WsError,
protocol::{Message, Role, WebSocket, WebSocketConfig},
};
#[cfg(any(feature = "native-tls", feature = "__rustls-tls", feature = "connect"))]
pub use tls::Connector;
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub use tls::{client_async_tls, client_async_tls_with_config};
#[cfg(feature = "connect")]
pub use connect::{connect_async, connect_async_with_config};
#[cfg(all(any(feature = "native-tls", feature = "__rustls-tls"), feature = "connect"))]
pub use connect::connect_async_tls_with_config;
#[cfg(feature = "stream")]
pub use stream::MaybeTlsStream;
use tungstenite::protocol::CloseFrame;
/// Creates a WebSocket handshake from a request and a stream.
/// For convenience, the user may call this with a url string, a URL,
/// or a `Request`. Calling with `Request` allows the user to add
/// a WebSocket protocol or other custom headers.
///
/// Internally, this custom creates a handshake representation and returns
/// a future representing the resolution of the WebSocket handshake. The
/// returned future will resolve to either `WebSocketStream<S>` or `Error`
/// depending on whether the handshake is successful.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to the remote server.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
R: IntoClientRequest + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
client_async_with_config(request, stream, None).await
}
/// The same as `client_async()` but the one can specify a websocket configuration.
/// Please refer to `client_async()` for more details.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
R: IntoClientRequest + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let f = handshake::client_handshake(stream, move |allow_std| {
let request = request.into_client_request()?;
let cli_handshake = ClientHandshake::start(allow_std, request, config)?;
cli_handshake.handshake()
});
f.await.map_err(|e| match e {
HandshakeError::Failure(e) => e,
e => WsError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
})
}
/// Accepts a new WebSocket connection with the provided stream.
///
/// This function will internally call `server::accept` to create a
/// handshake representation and returns a future representing the
/// resolution of the WebSocket handshake. The returned future will resolve
/// to either `WebSocketStream<S>` or `Error` depending if it's successful
/// or not.
///
/// This is typically used after a socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of the accepting a client's websocket connection.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
accept_hdr_async(stream, NoCallback).await
}
/// The same as `accept_async()` but the one can specify a websocket configuration.
/// Please refer to `accept_async()` for more details.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
stream: S,
config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
accept_hdr_async_with_config(stream, NoCallback, config).await
}
/// Accepts a new WebSocket connection with the provided stream.
///
/// This function does the same as `accept_async()` but accepts an extra callback
/// for header processing. The callback receives headers of the incoming
/// requests and is able to add extra headers to the reply.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
C: Callback + Unpin,
{
accept_hdr_async_with_config(stream, callback, None).await
}
/// The same as `accept_hdr_async()` but the one can specify a websocket configuration.
/// Please refer to `accept_hdr_async()` for more details.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
stream: S,
callback: C,
config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
C: Callback + Unpin,
{
let f = handshake::server_handshake(stream, move |allow_std| {
tungstenite::accept_hdr_with_config(allow_std, callback, config)
});
f.await.map_err(|e| match e {
HandshakeError::Failure(e) => e,
e => WsError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
})
}
/// A wrapper around an underlying raw stream which implements the WebSocket
/// protocol.
///
/// A `WebSocketStream<S>` represents a handshake that has been completed
/// successfully and both the server and the client are ready for receiving
/// and sending data. Message from a `WebSocketStream<S>` are accessible
/// through the respective `Stream` and `Sink`. Check more information about
/// them in `futures-rs` crate documentation or have a look on the examples
/// and unit tests for this crate.
#[derive(Debug)]
pub struct WebSocketStream<S> {
inner: WebSocket<AllowStd<S>>,
closing: bool,
ended: bool,
/// Tungstenite is probably ready to receive more data.
///
/// `false` once start_send hits `WouldBlock` errors.
/// `true` initially and after `flush`ing.
ready: bool,
}
impl<S> WebSocketStream<S> {
/// Convert a raw socket into a WebSocketStream without performing a
/// handshake.
pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_raw_socket(allow_std, role, config)
})
.await
}
/// Convert a raw socket into a WebSocketStream without performing a
/// handshake.
pub async fn from_partially_read(
stream: S,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_partially_read(allow_std, part, role, config)
})
.await
}
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
Self { inner: ws, closing: false, ended: false, ready: true }
}
fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
where
S: Unpin,
F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write,
{
trace!("{}:{} WebSocketStream.with_context", file!(), line!());
if let Some((kind, ctx)) = ctx {
self.inner.get_mut().set_waker(kind, ctx.waker());
}
f(&mut self.inner)
}
/// Returns a shared reference to the inner stream.
pub fn get_ref(&self) -> &S
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.inner.get_ref().get_ref()
}
/// Returns a mutable reference to the inner stream.
pub fn get_mut(&mut self) -> &mut S
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.inner.get_mut().get_mut()
}
/// Returns a reference to the configuration of the tungstenite stream.
pub fn get_config(&self) -> &WebSocketConfig {
self.inner.get_config()
}
/// Close the underlying web socket
pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.send(Message::Close(msg)).await
}
}
impl<T> Stream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
type Item = Result<Message, WsError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
trace!("{}:{} Stream.poll_next", file!(), line!());
// The connection has been closed or a critical error has occurred.
// We have already returned the error to the user, the `Stream` is unusable,
// so we assume that the stream has been "fused".
if self.ended {
return Poll::Ready(None);
}
match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!());
cvt(s.read())
})) {
Ok(v) => Poll::Ready(Some(Ok(v))),
Err(e) => {
self.ended = true;
if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
Poll::Ready(None)
} else {
Poll::Ready(Some(Err(e)))
}
}
}
}
}
impl<T> FusedStream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn is_terminated(&self) -> bool {
self.ended
}
}
impl<T> Sink<Message> for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
type Error = WsError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.ready {
Poll::Ready(Ok(()))
} else {
// Currently blocked so try to flush the blockage away
(*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
self.ready = true;
r
})
}
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match (*self).with_context(None, |s| s.write(item)) {
Ok(()) => {
self.ready = true;
Ok(())
}
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued so not an error
// but `poll_ready` will now start trying to flush the block
self.ready = false;
Ok(())
}
Err(e) => {
self.ready = true;
debug!("websocket start_send error: {}", e);
Err(e)
}
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
self.ready = true;
match r {
// WebSocket connection has just been closed. Flushing completed, not an error.
Err(WsError::ConnectionClosed) => Ok(()),
other => other,
}
})
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.ready = true;
let res = if self.closing {
// After queueing it, we call `flush` to drive the close handshake to completion.
(*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
} else {
(*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
};
match res {
Ok(()) => Poll::Ready(Ok(())),
Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
trace!("WouldBlock");
self.closing = true;
Poll::Pending
}
Err(err) => {
debug!("websocket close error: {}", err);
Poll::Ready(Err(err))
}
}
}
}
/// Get a domain from an URL.
#[cfg(any(feature = "connect", feature = "native-tls", feature = "__rustls-tls"))]
#[inline]
fn domain(request: &tungstenite::handshake::client::Request) -> Result<String, WsError> {
match request.uri().host() {
// rustls expects IPv6 addresses without the surrounding [] brackets
#[cfg(feature = "__rustls-tls")]
Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
Some(d) => Ok(d.to_string()),
None => Err(WsError::Url(tungstenite::error::UrlError::NoHostName)),
}
}
#[cfg(test)]
mod tests {
#[cfg(feature = "connect")]
use crate::stream::MaybeTlsStream;
use crate::{compat::AllowStd, WebSocketStream};
use std::io::{Read, Write};
#[cfg(feature = "connect")]
use tokio::io::{AsyncReadExt, AsyncWriteExt};
fn is_read<T: Read>() {}
fn is_write<T: Write>() {}
#[cfg(feature = "connect")]
fn is_async_read<T: AsyncReadExt>() {}
#[cfg(feature = "connect")]
fn is_async_write<T: AsyncWriteExt>() {}
fn is_unpin<T: Unpin>() {}
#[test]
fn web_socket_stream_has_traits() {
is_read::<AllowStd<tokio::net::TcpStream>>();
is_write::<AllowStd<tokio::net::TcpStream>>();
#[cfg(feature = "connect")]
is_async_read::<MaybeTlsStream<tokio::net::TcpStream>>();
#[cfg(feature = "connect")]
is_async_write::<MaybeTlsStream<tokio::net::TcpStream>>();
is_unpin::<WebSocketStream<tokio::net::TcpStream>>();
#[cfg(feature = "connect")]
is_unpin::<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>>();
}
}

80
vendor/tokio-tungstenite/src/stream.rs vendored Normal file
View File

@@ -0,0 +1,80 @@
//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! There is no dependency on actual TLS implementations. Everything like
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits.
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// A stream that might be protected with TLS.
#[non_exhaustive]
#[derive(Debug)]
pub enum MaybeTlsStream<S> {
/// Unencrypted socket stream.
Plain(S),
/// Encrypted socket stream using `native-tls`.
#[cfg(feature = "native-tls")]
NativeTls(tokio_native_tls::TlsStream<S>),
/// Encrypted socket stream using `rustls`.
#[cfg(feature = "__rustls-tls")]
Rustls(tokio_rustls::client::TlsStream<S>),
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "native-tls")]
MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(feature = "__rustls-tls")]
MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
}
}
}

237
vendor/tokio-tungstenite/src/tls.rs vendored Normal file
View File

@@ -0,0 +1,237 @@
//! Connection helper.
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::{
client::uri_mode, error::Error, handshake::client::Response, protocol::WebSocketConfig,
};
use crate::{client_async_with_config, IntoClientRequest, WebSocketStream};
pub use crate::stream::MaybeTlsStream;
/// A connector that can be used when establishing connections, allowing to control whether
/// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the
/// `Plain` variant.
#[non_exhaustive]
#[derive(Clone)]
pub enum Connector {
/// Plain (non-TLS) connector.
Plain,
/// `native-tls` TLS connector.
#[cfg(feature = "native-tls")]
NativeTls(native_tls_crate::TlsConnector),
/// `rustls` TLS connector.
#[cfg(feature = "__rustls-tls")]
Rustls(std::sync::Arc<rustls::ClientConfig>),
}
mod encryption {
#[cfg(feature = "native-tls")]
pub mod native_tls {
use native_tls_crate::TlsConnector;
use tokio_native_tls::TlsConnector as TokioTlsConnector;
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::{error::TlsError, stream::Mode, Error};
use crate::stream::MaybeTlsStream;
pub async fn wrap_stream<S>(
socket: S,
domain: String,
mode: Mode,
tls_connector: Option<TlsConnector>,
) -> Result<MaybeTlsStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => {
let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
let connector = try_connector.map_err(TlsError::Native)?;
let stream = TokioTlsConnector::from(connector);
let connected = stream.connect(&domain, socket).await;
match connected {
Err(e) => Err(Error::Tls(e.into())),
Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
}
}
}
}
}
#[cfg(feature = "__rustls-tls")]
pub mod rustls {
pub use rustls::ClientConfig;
use rustls::RootCertStore;
use rustls_pki_types::ServerName;
use tokio_rustls::TlsConnector as TokioTlsConnector;
use std::{convert::TryFrom, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::{error::TlsError, stream::Mode, Error};
use crate::stream::MaybeTlsStream;
pub async fn wrap_stream<S>(
socket: S,
domain: String,
mode: Mode,
tls_connector: Option<Arc<ClientConfig>>,
) -> Result<MaybeTlsStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => {
let config = match tls_connector {
Some(config) => config,
None => {
#[allow(unused_mut)]
let mut root_store = RootCertStore::empty();
#[cfg(feature = "rustls-tls-native-roots")]
{
let rustls_native_certs::CertificateResult {
certs, errors, ..
} = rustls_native_certs::load_native_certs();
if !errors.is_empty() {
log::warn!(
"native root CA certificate loading errors: {errors:?}"
);
}
// Not finding any native root CA certificates is not fatal if the
// "rustls-tls-webpki-roots" feature is enabled.
#[cfg(not(feature = "rustls-tls-webpki-roots"))]
if certs.is_empty() {
return Err(std::io::Error::new(std::io::ErrorKind::NotFound, format!("no native root CA certificates found (errors: {errors:?})")).into());
}
let total_number = certs.len();
let (number_added, number_ignored) =
root_store.add_parsable_certificates(certs);
log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}
Arc::new(
ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth(),
)
}
};
let domain = ServerName::try_from(domain.as_str())
.map_err(|_| TlsError::InvalidDnsName)?
.to_owned();
let stream = TokioTlsConnector::from(config);
let connected = stream.connect(domain, socket).await;
match connected {
Err(e) => Err(Error::Io(e)),
Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
}
}
}
}
}
pub mod plain {
use tokio::io::{AsyncRead, AsyncWrite};
use tungstenite::{
error::{Error, UrlError},
stream::Mode,
};
use crate::stream::MaybeTlsStream;
pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
where
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
{
match mode {
Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
}
}
}
}
/// Creates a WebSocket handshake from a request and a stream,
/// upgrading the stream to TLS if required.
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
pub async fn client_async_tls<R, S>(
request: R,
stream: S,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
MaybeTlsStream<S>: Unpin,
{
client_async_tls_with_config(request, stream, None, None).await
}
/// The same as `client_async_tls()` but the one can specify a websocket configuration,
/// and an optional connector. If no connector is specified, a default one will
/// be created.
///
/// Please refer to `client_async_tls()` for more details.
pub async fn client_async_tls_with_config<R, S>(
request: R,
stream: S,
config: Option<WebSocketConfig>,
connector: Option<Connector>,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
R: IntoClientRequest + Unpin,
S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
MaybeTlsStream<S>: Unpin,
{
let request = request.into_client_request()?;
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
let domain = crate::domain(&request)?;
// Make sure we check domain and mode first. URL must be valid.
let mode = uri_mode(request.uri())?;
let stream = match connector {
Some(conn) => match conn {
#[cfg(feature = "native-tls")]
Connector::NativeTls(conn) => {
self::encryption::native_tls::wrap_stream(stream, domain, mode, Some(conn)).await
}
#[cfg(feature = "__rustls-tls")]
Connector::Rustls(conn) => {
self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
}
Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
},
None => {
#[cfg(feature = "native-tls")]
{
self::encryption::native_tls::wrap_stream(stream, domain, mode, None).await
}
#[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
{
self::encryption::rustls::wrap_stream(stream, domain, mode, None).await
}
#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
{
self::encryption::plain::wrap_stream(stream, mode).await
}
}
}?;
client_async_with_config(request, stream, config).await
}