Flatten conditional branches; eliminate unwrap().
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -3,8 +3,6 @@ mod plain;
|
|||||||
mod tls;
|
mod tls;
|
||||||
mod unix;
|
mod unix;
|
||||||
|
|
||||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
|
||||||
use std::os::fd::FromRawFd;
|
|
||||||
use std::{
|
use std::{
|
||||||
net::{SocketAddr, TcpListener},
|
net::{SocketAddr, TcpListener},
|
||||||
sync::{Arc, atomic::Ordering},
|
sync::{Arc, atomic::Ordering},
|
||||||
@@ -34,55 +32,32 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
|||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut listen_addrs: Vec<SocketAddr> = vec![];
|
let systemd_listeners: Vec<_> = systemd_listeners().collect();
|
||||||
|
let systemd_listeners_is_empty = systemd_listeners.is_empty();
|
||||||
|
let (listeners, addrs): (Vec<_>, Vec<_>) = config
|
||||||
|
.get_bind_addrs()
|
||||||
|
.into_iter()
|
||||||
|
.filter(|_| systemd_listeners_is_empty)
|
||||||
|
.map(|addr| {
|
||||||
|
let listener = TcpListener::bind(addr)
|
||||||
|
.expect("Failed to bind configured TcpListener to {addr:?}");
|
||||||
|
|
||||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
(listener, addr)
|
||||||
let mut listeners: Vec<TcpListener> = if let Ok(fds) = sd_notify::listen_fds() {
|
})
|
||||||
fds.filter_map(|fd| {
|
.chain(systemd_listeners)
|
||||||
// SAFETY: systemd should already take care of providing
|
.inspect(|(listener, _)| {
|
||||||
// the correct TCP socket, so we just use it via raw fd
|
|
||||||
let listener = unsafe { TcpListener::from_raw_fd(fd) };
|
|
||||||
listener
|
listener
|
||||||
.set_nonblocking(true)
|
.set_nonblocking(true)
|
||||||
.expect("Failed to set non-blocking");
|
.expect("Failed to set non-blocking");
|
||||||
if let Ok(addr) = listener.local_addr() {
|
|
||||||
listen_addrs.push(addr);
|
|
||||||
Some(listener)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.collect()
|
.unzip();
|
||||||
} else {
|
|
||||||
vec![]
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(not(all(feature = "systemd", target_os = "linux")))]
|
|
||||||
let mut listeners: Vec<TcpListener> = vec![];
|
|
||||||
|
|
||||||
let addrs = config.get_bind_addrs();
|
|
||||||
|
|
||||||
if !addrs.is_empty() && listeners.is_empty() {
|
|
||||||
listeners = addrs
|
|
||||||
.clone()
|
|
||||||
.into_iter()
|
|
||||||
.map(|addr| {
|
|
||||||
let listener = TcpListener::bind(addr).expect("Failed to bind {addr:?}");
|
|
||||||
listener
|
|
||||||
.set_nonblocking(true)
|
|
||||||
.expect("Failed to set non-blocking");
|
|
||||||
listen_addrs.push(addr);
|
|
||||||
listener
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
|
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
|
||||||
if config.tls.certs.is_some() {
|
if config.tls.certs.is_some() {
|
||||||
#[cfg(feature = "direct_tls")]
|
#[cfg(feature = "direct_tls")]
|
||||||
{
|
{
|
||||||
services.globals.init_rustls_provider()?;
|
services.globals.init_rustls_provider()?;
|
||||||
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners)
|
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs)
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,7 +67,7 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
|||||||
"tuwunel was not built with direct TLS support (\"direct_tls\")"
|
"tuwunel was not built with direct TLS support (\"direct_tls\")"
|
||||||
));
|
));
|
||||||
} else {
|
} else {
|
||||||
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners);
|
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(!join_set.is_empty(), "at least one listener should be installed");
|
assert!(!join_set.is_empty(), "at least one listener should be installed");
|
||||||
@@ -114,10 +89,34 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
|||||||
.requests_panic
|
.requests_panic
|
||||||
.load(Ordering::Acquire),
|
.load(Ordering::Acquire),
|
||||||
handle_active,
|
handle_active,
|
||||||
"Stopped listening on {listen_addrs:?}",
|
"Stopped listening on {addrs:?}",
|
||||||
);
|
);
|
||||||
|
|
||||||
debug_assert_eq!(0, handle_active, "active request handles still pending");
|
debug_assert_eq!(0, handle_active, "active request handles still pending");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
||||||
|
fn systemd_listeners() -> impl Iterator<Item = (TcpListener, SocketAddr)> {
|
||||||
|
sd_notify::listen_fds()
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.filter_map(|fd| {
|
||||||
|
use std::os::fd::FromRawFd;
|
||||||
|
|
||||||
|
debug_assert!(fd >= 3, "fdno probably not a listener socket");
|
||||||
|
// SAFETY: systemd should already take care of providing
|
||||||
|
// the correct TCP socket, so we just use it via raw fd
|
||||||
|
let listener = unsafe { TcpListener::from_raw_fd(fd) };
|
||||||
|
|
||||||
|
let Ok(addr) = listener.local_addr() else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
Some((listener, addr))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(any(not(feature = "systemd"), not(target_os = "linux")))]
|
||||||
|
fn systemd_listeners() -> impl Iterator<Item = (TcpListener, SocketAddr)> { std::iter::empty() }
|
||||||
|
|||||||
@@ -6,26 +6,29 @@ use std::{
|
|||||||
use axum::Router;
|
use axum::Router;
|
||||||
use axum_server::Handle;
|
use axum_server::Handle;
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
use tuwunel_core::{Server, info};
|
use tuwunel_core::{Result, Server, info};
|
||||||
|
|
||||||
pub(super) fn serve(
|
pub(super) fn serve(
|
||||||
server: &Arc<Server>,
|
server: &Arc<Server>,
|
||||||
router: &Router,
|
router: &Router,
|
||||||
handle: &Handle<SocketAddr>,
|
handle: &Handle<SocketAddr>,
|
||||||
join_set: &mut JoinSet<Result<(), std::io::Error>>,
|
join_set: &mut JoinSet<Result<(), std::io::Error>>,
|
||||||
|
listeners: &[TcpListener],
|
||||||
addrs: &[SocketAddr],
|
addrs: &[SocketAddr],
|
||||||
listeners: &Vec<TcpListener>,
|
) -> Result {
|
||||||
) {
|
|
||||||
let router = router
|
let router = router
|
||||||
.clone()
|
.clone()
|
||||||
.into_make_service_with_connect_info::<SocketAddr>();
|
.into_make_service_with_connect_info::<SocketAddr>();
|
||||||
|
|
||||||
for listener in listeners {
|
for listener in listeners {
|
||||||
let acceptor = axum_server::from_tcp(listener.try_clone().unwrap())
|
let acceptor = axum_server::from_tcp(listener.try_clone()?)?
|
||||||
.unwrap()
|
|
||||||
.handle(handle.clone())
|
.handle(handle.clone())
|
||||||
.serve(router.clone());
|
.serve(router.clone());
|
||||||
|
|
||||||
join_set.spawn_on(acceptor, server.runtime());
|
join_set.spawn_on(acceptor, server.runtime());
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Listening on {addrs:?}");
|
info!("Listening on {addrs:?}");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ pub(super) async fn serve(
|
|||||||
app: &Router,
|
app: &Router,
|
||||||
handle: &Handle<SocketAddr>,
|
handle: &Handle<SocketAddr>,
|
||||||
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
|
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
|
||||||
|
listeners: &[TcpListener],
|
||||||
addrs: &[SocketAddr],
|
addrs: &[SocketAddr],
|
||||||
listeners: &Vec<TcpListener>,
|
|
||||||
) -> Result {
|
) -> Result {
|
||||||
let tls = &server.config.tls;
|
let tls = &server.config.tls;
|
||||||
let certs = tls.certs.as_ref().unwrap();
|
let certs = tls.certs.as_ref().unwrap();
|
||||||
@@ -35,17 +35,15 @@ pub(super) async fn serve(
|
|||||||
.into_make_service_with_connect_info::<SocketAddr>();
|
.into_make_service_with_connect_info::<SocketAddr>();
|
||||||
if tls.dual_protocol {
|
if tls.dual_protocol {
|
||||||
for listener in listeners {
|
for listener in listeners {
|
||||||
join_set.spawn_on(
|
let acceptor = axum_server_dual_protocol::from_tcp_dual_protocol(
|
||||||
axum_server_dual_protocol::from_tcp_dual_protocol(
|
listener.try_clone()?,
|
||||||
listener.try_clone().unwrap(),
|
conf.clone(),
|
||||||
conf.clone(),
|
)?
|
||||||
)
|
.set_upgrade(false)
|
||||||
.unwrap()
|
.handle(handle.clone())
|
||||||
.set_upgrade(false)
|
.serve(app.clone());
|
||||||
.handle(handle.clone())
|
|
||||||
.serve(app.clone()),
|
join_set.spawn_on(acceptor, server.runtime());
|
||||||
server.runtime(),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!(
|
||||||
@@ -54,13 +52,11 @@ pub(super) async fn serve(
|
|||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
for listener in listeners {
|
for listener in listeners {
|
||||||
join_set.spawn_on(
|
let acceptor = axum_server::from_tcp_rustls(listener.try_clone()?, conf.clone())?
|
||||||
axum_server::from_tcp_rustls(listener.try_clone().unwrap(), conf.clone())
|
.handle(handle.clone())
|
||||||
.unwrap()
|
.serve(app.clone());
|
||||||
.handle(handle.clone())
|
|
||||||
.serve(app.clone()),
|
join_set.spawn_on(acceptor, server.runtime());
|
||||||
server.runtime(),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Listening on {addrs:?} with TLS certificate {certs}");
|
info!("Listening on {addrs:?} with TLS certificate {certs}");
|
||||||
|
|||||||
Reference in New Issue
Block a user