Files
cli/vendor/russh/tests/test_data_stream.rs

143 lines
3.6 KiB
Rust
Raw Normal View History

use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use rand::RngCore;
use russh::keys::key;
use russh::server::{self, Auth, Msg, Server as _, Session};
use russh::{client, Channel};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// Window size handed to the test server's `server::Config` (`window_size` field).
///
/// The test payload is sized relative to this value (two windows plus a few bytes) so that
/// the transfer cannot fit into a single window. Presumably expressed in bytes — confirm
/// against the russh `Config` documentation.
pub const WINDOW_SIZE: u32 = 8 * 2048;
/// Starts the echo server on a free local address, waits until it accepts TCP
/// connections, then pushes a random payload through it with [`stream`].
///
/// # Errors
///
/// Returns an error if the server never starts listening, or if [`stream`] fails.
#[tokio::test]
async fn test_reader_and_writer() -> Result<(), anyhow::Error> {
    /// Upper bound on connection probes (10 ms apart, so roughly five seconds in total).
    const MAX_ATTEMPTS: u32 = 500;

    // `try_init` rather than `init`: `init` panics when a global logger has already been
    // installed (for instance by another test in the same binary). Logging is only a
    // debugging aid here, so a failed initialisation is deliberately ignored.
    let _ = env_logger::try_init();

    let addr = addr();
    let data = data();

    tokio::spawn(Server::run(addr));

    // Wait until the server is started. The wait is bounded: if the server task dies
    // (e.g. it could not bind the address) the test fails instead of spinning forever.
    let mut attempts: u32 = 0;
    while TcpStream::connect(addr).is_err() {
        attempts += 1;
        if attempts >= MAX_ATTEMPTS {
            anyhow::bail!("server at {} did not start listening in time", addr);
        }
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }

    stream(addr, &data).await?;
    Ok(())
}
async fn stream(addr: SocketAddr, data: &[u8]) -> Result<(), anyhow::Error> {
let config = Arc::new(client::Config::default());
let key = Arc::new(russh_keys::key::KeyPair::generate_ed25519());
let mut session = russh::client::connect(config, addr, Client).await?;
let mut channel = match session.authenticate_publickey("user", key).await {
Ok(true) => session.channel_open_session().await?,
Ok(false) => panic!("Authentication failed"),
Err(err) => return Err(err.into()),
};
let mut buf = Vec::<u8>::new();
let (mut writer, mut reader) = (channel.make_writer_ext(Some(1)), channel.make_reader());
let (r0, r1) = tokio::join!(
async {
writer.write_all(data).await?;
writer.shutdown().await?;
Ok::<_, anyhow::Error>(())
},
reader.read_to_end(&mut buf)
);
r0?;
let count = r1?;
assert_eq!(data.len(), count);
assert_eq!(data, buf);
Ok(())
}
/// Builds the random payload sent through the channel.
///
/// Its length is two full `WINDOW_SIZE`s plus seven bytes, so the transfer cannot fit in a
/// single window; the intent is to check whether window resizing works.
fn data() -> Vec<u8> {
    let len = 2 * WINDOW_SIZE as usize + 7;
    let mut payload = vec![0u8; len];
    rand::thread_rng().fill_bytes(&mut payload);
    payload
}
/// Finds an unused local address to bind our server to.
///
/// Binds a listener to port 0 on `127.0.0.1`, reads back the address that was actually
/// assigned, and drops the listener again before returning.
///
/// # Panics
///
/// Panics if binding fails or the bound address cannot be queried.
fn addr() -> SocketAddr {
    let probe = TcpListener::bind(("127.0.0.1", 0)).unwrap();
    probe.local_addr().unwrap()
}
/// Test server; it doubles as its own per-connection handler.
#[derive(Clone)]
struct Server;

impl Server {
    /// Runs the server on `addr` with a freshly generated ed25519 host key and
    /// `WINDOW_SIZE` as the configured window size.
    ///
    /// # Panics
    ///
    /// Panics if `run_on_address` returns an error.
    async fn run(addr: SocketAddr) {
        let host_key = russh_keys::key::KeyPair::generate_ed25519();
        let config = server::Config {
            keys: vec![host_key],
            window_size: WINDOW_SIZE,
            ..Default::default()
        };

        let mut server = Server;
        server
            .run_on_address(Arc::new(config), addr)
            .await
            .unwrap();
    }
}
impl russh::server::Server for Server {
type Handler = Self;
fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self::Handler {
self.clone()
}
}
#[async_trait::async_trait]
impl russh::server::Handler for Server {
type Error = anyhow::Error;
async fn auth_publickey(&mut self, _: &str, _: &key::PublicKey) -> Result<Auth, Self::Error> {
Ok(Auth::Accept)
}
async fn channel_open_session(
&mut self,
mut channel: Channel<Msg>,
_session: &mut Session,
) -> Result<bool, Self::Error> {
tokio::spawn(async move {
let (mut writer, mut reader) =
(channel.make_writer(), channel.make_reader_ext(Some(1)));
tokio::io::copy(&mut reader, &mut writer)
.await
.expect("Data transfer failed");
writer.shutdown().await.expect("Shutdown failed");
});
Ok(true)
}
}
/// Client-side handler used by the test.
struct Client;

#[async_trait::async_trait]
impl russh::client::Handler for Client {
    type Error = anyhow::Error;

    /// Trusts whatever host key the server presents.
    async fn check_server_key(
        &mut self,
        _server_public_key: &key::PublicKey,
    ) -> Result<bool, Self::Error> {
        Ok(true)
    }
}