//! Integration tests for TLS early data (0-RTT) through `tokio_rustls`.
#![cfg(feature = "early-data")]
|
||
|
|
|
||
|
|
use std::io::{self, Read, Write};
|
||
|
|
use std::net::{SocketAddr, TcpListener};
|
||
|
|
use std::sync::Arc;
|
||
|
|
use std::thread;
|
||
|
|
|
||
|
|
use rustls::pki_types::ServerName;
|
||
|
|
use rustls::{self, ClientConfig, ServerConnection, Stream};
|
||
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||
|
|
use tokio::net::TcpStream;
|
||
|
|
use tokio_rustls::client::TlsStream;
|
||
|
|
use tokio_rustls::TlsConnector;
|
||
|
|
|
||
|
|
async fn send<S: AsyncRead + AsyncWrite + Unpin>(
|
||
|
|
config: Arc<ClientConfig>,
|
||
|
|
addr: SocketAddr,
|
||
|
|
wrapper: impl Fn(TcpStream) -> S,
|
||
|
|
data: &[u8],
|
||
|
|
vectored: bool,
|
||
|
|
) -> io::Result<(TlsStream<S>, Vec<u8>)> {
|
||
|
|
let connector = TlsConnector::from(config).early_data(true);
|
||
|
|
let stream = wrapper(TcpStream::connect(&addr).await?);
|
||
|
|
let domain = ServerName::try_from("foobar.com").unwrap();
|
||
|
|
|
||
|
|
let mut stream = connector.connect(domain, stream).await?;
|
||
|
|
utils::write(&mut stream, data, vectored).await?;
|
||
|
|
stream.flush().await?;
|
||
|
|
stream.shutdown().await?;
|
||
|
|
|
||
|
|
let mut buf = Vec::new();
|
||
|
|
stream.read_to_end(&mut buf).await?;
|
||
|
|
|
||
|
|
Ok((stream, buf))
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
async fn test_0rtt() -> io::Result<()> {
|
||
|
|
test_0rtt_impl(|s| s, false).await
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
async fn test_0rtt_vectored() -> io::Result<()> {
|
||
|
|
test_0rtt_impl(|s| s, true).await
|
||
|
|
}
|
||
|
|
|
||
|
|
#[tokio::test]
|
||
|
|
async fn test_0rtt_vectored_flush_pending() -> io::Result<()> {
|
||
|
|
test_0rtt_impl(utils::FlushWrapper::new, false).await
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn test_0rtt_impl<S: AsyncRead + AsyncWrite + Unpin>(
|
||
|
|
wrapper: impl Fn(TcpStream) -> S,
|
||
|
|
vectored: bool,
|
||
|
|
) -> io::Result<()> {
|
||
|
|
let (mut server, mut client) = utils::make_configs();
|
||
|
|
server.max_early_data_size = 8192;
|
||
|
|
let server = Arc::new(server);
|
||
|
|
|
||
|
|
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||
|
|
let server_port = listener.local_addr().unwrap().port();
|
||
|
|
thread::spawn(move || loop {
|
||
|
|
let (mut sock, _addr) = listener.accept().unwrap();
|
||
|
|
|
||
|
|
let server = Arc::clone(&server);
|
||
|
|
thread::spawn(move || {
|
||
|
|
let mut conn = ServerConnection::new(server).unwrap();
|
||
|
|
conn.complete_io(&mut sock).unwrap();
|
||
|
|
|
||
|
|
if let Some(mut early_data) = conn.early_data() {
|
||
|
|
let mut buf = Vec::new();
|
||
|
|
early_data.read_to_end(&mut buf).unwrap();
|
||
|
|
let mut stream = Stream::new(&mut conn, &mut sock);
|
||
|
|
stream.write_all(b"EARLY:").unwrap();
|
||
|
|
stream.write_all(&buf).unwrap();
|
||
|
|
}
|
||
|
|
|
||
|
|
let mut stream = Stream::new(&mut conn, &mut sock);
|
||
|
|
stream.write_all(b"LATE:").unwrap();
|
||
|
|
loop {
|
||
|
|
let mut buf = [0; 1024];
|
||
|
|
let n = stream.read(&mut buf).unwrap();
|
||
|
|
if n == 0 {
|
||
|
|
conn.send_close_notify();
|
||
|
|
conn.complete_io(&mut sock).unwrap();
|
||
|
|
break;
|
||
|
|
}
|
||
|
|
stream.write_all(&buf[..n]).unwrap();
|
||
|
|
}
|
||
|
|
});
|
||
|
|
});
|
||
|
|
|
||
|
|
client.enable_early_data = true;
|
||
|
|
let client = Arc::new(client);
|
||
|
|
let addr = SocketAddr::from(([127, 0, 0, 1], server_port));
|
||
|
|
|
||
|
|
let (io, buf) = send(client.clone(), addr, &wrapper, b"hello", vectored).await?;
|
||
|
|
assert!(!io.get_ref().1.is_early_data_accepted());
|
||
|
|
assert_eq!("LATE:hello", String::from_utf8_lossy(&buf));
|
||
|
|
|
||
|
|
let (io, buf) = send(client, addr, wrapper, b"world!", vectored).await?;
|
||
|
|
assert!(io.get_ref().1.is_early_data_accepted());
|
||
|
|
assert_eq!("EARLY:world!LATE:", String::from_utf8_lossy(&buf));
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
// Include the shared test helpers. The tests above refer to
// `utils::make_configs`, `utils::write` and `utils::FlushWrapper`; the
// `utils::` paths suggest `utils.rs` declares a `utils` module (confirm
// against that file).
include!("utils.rs");
|