Files
cli/vendor/tokio-tungstenite/examples/server-custom-accept.rs

184 lines
6.0 KiB
Rust

//! A chat server that broadcasts a message to all connections.
//!
//! This is a simple line-based server which accepts WebSocket connections,
//! reads lines from those connections, and broadcasts the lines to all other
//! connected clients.
//!
//! You can test this out by running:
//!
//! cargo run --example server 127.0.0.1:12345
//!
//! And then in another window run:
//!
//! cargo run --example client ws://127.0.0.1:12345/socket
//!
//! You can run the second command in multiple windows and then chat between the
//! two, seeing the messages from the other client as they're received. For all
//! connected clients they'll all join the same room and see everyone else's
//! messages.
use std::{
collections::HashMap,
convert::Infallible,
env,
net::SocketAddr,
sync::{Arc, Mutex},
};
use hyper::{
body::Incoming,
header::{
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
UPGRADE,
},
server::conn::http1,
service::service_fn,
upgrade::Upgraded,
Method, Request, Response, StatusCode, Version,
};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use futures_channel::mpsc::{unbounded, UnboundedSender};
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt};
use tokio_tungstenite::{
tungstenite::{
handshake::derive_accept_key,
protocol::{Message, Role},
},
WebSocketStream,
};
type Tx = UnboundedSender<Message>;
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
type Body = http_body_util::Full<hyper::body::Bytes>;
async fn handle_connection(
peer_map: PeerMap,
ws_stream: WebSocketStream<TokioIo<Upgraded>>,
addr: SocketAddr,
) {
println!("WebSocket connection established: {}", addr);
// Insert the write part of this peer to the peer map.
let (tx, rx) = unbounded();
peer_map.lock().unwrap().insert(addr, tx);
let (outgoing, incoming) = ws_stream.split();
let broadcast_incoming = incoming.try_for_each(|msg| {
println!("Received a message from {}: {}", addr, msg.to_text().unwrap());
let peers = peer_map.lock().unwrap();
// We want to broadcast the message to everyone except ourselves.
let broadcast_recipients =
peers.iter().filter(|(peer_addr, _)| peer_addr != &&addr).map(|(_, ws_sink)| ws_sink);
for recp in broadcast_recipients {
recp.unbounded_send(msg.clone()).unwrap();
}
future::ok(())
});
let receive_from_others = rx.map(Ok).forward(outgoing);
pin_mut!(broadcast_incoming, receive_from_others);
future::select(broadcast_incoming, receive_from_others).await;
println!("{} disconnected", &addr);
peer_map.lock().unwrap().remove(&addr);
}
async fn handle_request(
peer_map: PeerMap,
mut req: Request<Incoming>,
addr: SocketAddr,
) -> Result<Response<Body>, Infallible> {
println!("Received a new, potentially ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
for (ref header, _value) in req.headers() {
println!("* {}", header);
}
let upgrade = HeaderValue::from_static("Upgrade");
let websocket = HeaderValue::from_static("websocket");
let headers = req.headers();
let key = headers.get(SEC_WEBSOCKET_KEY);
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
if req.method() != Method::GET
|| req.version() < Version::HTTP_11
|| !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap()))
})
.unwrap_or(false)
|| !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
|| !headers.get(SEC_WEBSOCKET_VERSION).map(|h| h == "13").unwrap_or(false)
|| key.is_none()
|| req.uri() != "/socket"
{
return Ok(Response::new(Body::from("Hello World!")));
}
let ver = req.version();
tokio::task::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
let upgraded = TokioIo::new(upgraded);
handle_connection(
peer_map,
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
addr,
)
.await;
}
Err(e) => println!("upgrade error: {}", e),
}
});
let mut res = Response::new(Body::default());
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.version_mut() = ver;
res.headers_mut().append(CONNECTION, upgrade);
res.headers_mut().append(UPGRADE, websocket);
res.headers_mut().append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
// Let's add an additional header to our response to the client.
res.headers_mut().append("MyCustomHeader", ":)".parse().unwrap());
res.headers_mut().append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap());
Ok(res)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let state = PeerMap::new(Mutex::new(HashMap::new()));
let addr =
env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string()).parse::<SocketAddr>()?;
let listener = TcpListener::bind(addr).await?;
loop {
let (stream, remote_addr) = listener.accept().await?;
let state = state.clone();
tokio::spawn(async move {
let io = TokioIo::new(stream);
let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr));
let conn = http1::Builder::new().serve_connection(io, service).with_upgrades();
if let Err(err) = conn.await {
eprintln!("failed to serve connection: {err:?}");
}
});
}
}