Flatten conditional branches; eliminate unwrap().

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-03-10 01:09:36 +00:00
parent 2a1d34bee1
commit aa847e4844
3 changed files with 64 additions and 66 deletions

View File

@@ -3,8 +3,6 @@ mod plain;
mod tls; mod tls;
mod unix; mod unix;
#[cfg(all(feature = "systemd", target_os = "linux"))]
use std::os::fd::FromRawFd;
use std::{ use std::{
net::{SocketAddr, TcpListener}, net::{SocketAddr, TcpListener},
sync::{Arc, atomic::Ordering}, sync::{Arc, atomic::Ordering},
@@ -34,55 +32,32 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
.await?; .await?;
} }
let mut listen_addrs: Vec<SocketAddr> = vec![]; let systemd_listeners: Vec<_> = systemd_listeners().collect();
let systemd_listeners_is_empty = systemd_listeners.is_empty();
let (listeners, addrs): (Vec<_>, Vec<_>) = config
.get_bind_addrs()
.into_iter()
.filter(|_| systemd_listeners_is_empty)
.map(|addr| {
let listener = TcpListener::bind(addr)
.expect("Failed to bind configured TcpListener to {addr:?}");
#[cfg(all(feature = "systemd", target_os = "linux"))] (listener, addr)
let mut listeners: Vec<TcpListener> = if let Ok(fds) = sd_notify::listen_fds() { })
fds.filter_map(|fd| { .chain(systemd_listeners)
// SAFETY: systemd should already take care of providing .inspect(|(listener, _)| {
// the correct TCP socket, so we just use it via raw fd
let listener = unsafe { TcpListener::from_raw_fd(fd) };
listener listener
.set_nonblocking(true) .set_nonblocking(true)
.expect("Failed to set non-blocking"); .expect("Failed to set non-blocking");
if let Ok(addr) = listener.local_addr() {
listen_addrs.push(addr);
Some(listener)
} else {
None
}
}) })
.collect() .unzip();
} else {
vec![]
};
#[cfg(not(all(feature = "systemd", target_os = "linux")))]
let mut listeners: Vec<TcpListener> = vec![];
let addrs = config.get_bind_addrs();
if !addrs.is_empty() && listeners.is_empty() {
listeners = addrs
.clone()
.into_iter()
.map(|addr| {
let listener = TcpListener::bind(addr).expect("Failed to bind {addr:?}");
listener
.set_nonblocking(true)
.expect("Failed to set non-blocking");
listen_addrs.push(addr);
listener
})
.collect();
}
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))] #[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
if config.tls.certs.is_some() { if config.tls.certs.is_some() {
#[cfg(feature = "direct_tls")] #[cfg(feature = "direct_tls")]
{ {
services.globals.init_rustls_provider()?; services.globals.init_rustls_provider()?;
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners) tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs)
.await?; .await?;
} }
@@ -92,7 +67,7 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
"tuwunel was not built with direct TLS support (\"direct_tls\")" "tuwunel was not built with direct TLS support (\"direct_tls\")"
)); ));
} else { } else {
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners); plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs)?;
} }
assert!(!join_set.is_empty(), "at least one listener should be installed"); assert!(!join_set.is_empty(), "at least one listener should be installed");
@@ -114,10 +89,34 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
.requests_panic .requests_panic
.load(Ordering::Acquire), .load(Ordering::Acquire),
handle_active, handle_active,
"Stopped listening on {listen_addrs:?}", "Stopped listening on {addrs:?}",
); );
debug_assert_eq!(0, handle_active, "active request handles still pending"); debug_assert_eq!(0, handle_active, "active request handles still pending");
Ok(()) Ok(())
} }
#[cfg(all(feature = "systemd", target_os = "linux"))]
fn systemd_listeners() -> impl Iterator<Item = (TcpListener, SocketAddr)> {
sd_notify::listen_fds()
.into_iter()
.flatten()
.filter_map(|fd| {
use std::os::fd::FromRawFd;
debug_assert!(fd >= 3, "fdno probably not a listener socket");
// SAFETY: systemd should already take care of providing
// the correct TCP socket, so we just use it via raw fd
let listener = unsafe { TcpListener::from_raw_fd(fd) };
let Ok(addr) = listener.local_addr() else {
return None;
};
Some((listener, addr))
})
}
#[cfg(any(not(feature = "systemd"), not(target_os = "linux")))]
fn systemd_listeners() -> impl Iterator<Item = (TcpListener, SocketAddr)> { std::iter::empty() }

View File

@@ -6,26 +6,29 @@ use std::{
use axum::Router; use axum::Router;
use axum_server::Handle; use axum_server::Handle;
use tokio::task::JoinSet; use tokio::task::JoinSet;
use tuwunel_core::{Server, info}; use tuwunel_core::{Result, Server, info};
pub(super) fn serve( pub(super) fn serve(
server: &Arc<Server>, server: &Arc<Server>,
router: &Router, router: &Router,
handle: &Handle<SocketAddr>, handle: &Handle<SocketAddr>,
join_set: &mut JoinSet<Result<(), std::io::Error>>, join_set: &mut JoinSet<Result<(), std::io::Error>>,
listeners: &[TcpListener],
addrs: &[SocketAddr], addrs: &[SocketAddr],
listeners: &Vec<TcpListener>, ) -> Result {
) {
let router = router let router = router
.clone() .clone()
.into_make_service_with_connect_info::<SocketAddr>(); .into_make_service_with_connect_info::<SocketAddr>();
for listener in listeners { for listener in listeners {
let acceptor = axum_server::from_tcp(listener.try_clone().unwrap()) let acceptor = axum_server::from_tcp(listener.try_clone()?)?
.unwrap()
.handle(handle.clone()) .handle(handle.clone())
.serve(router.clone()); .serve(router.clone());
join_set.spawn_on(acceptor, server.runtime()); join_set.spawn_on(acceptor, server.runtime());
} }
info!("Listening on {addrs:?}"); info!("Listening on {addrs:?}");
Ok(())
} }

View File

@@ -14,8 +14,8 @@ pub(super) async fn serve(
app: &Router, app: &Router,
handle: &Handle<SocketAddr>, handle: &Handle<SocketAddr>,
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>, join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
listeners: &[TcpListener],
addrs: &[SocketAddr], addrs: &[SocketAddr],
listeners: &Vec<TcpListener>,
) -> Result { ) -> Result {
let tls = &server.config.tls; let tls = &server.config.tls;
let certs = tls.certs.as_ref().unwrap(); let certs = tls.certs.as_ref().unwrap();
@@ -35,17 +35,15 @@ pub(super) async fn serve(
.into_make_service_with_connect_info::<SocketAddr>(); .into_make_service_with_connect_info::<SocketAddr>();
if tls.dual_protocol { if tls.dual_protocol {
for listener in listeners { for listener in listeners {
join_set.spawn_on( let acceptor = axum_server_dual_protocol::from_tcp_dual_protocol(
axum_server_dual_protocol::from_tcp_dual_protocol( listener.try_clone()?,
listener.try_clone().unwrap(), conf.clone(),
conf.clone(), )?
) .set_upgrade(false)
.unwrap() .handle(handle.clone())
.set_upgrade(false) .serve(app.clone());
.handle(handle.clone())
.serve(app.clone()), join_set.spawn_on(acceptor, server.runtime());
server.runtime(),
);
} }
warn!( warn!(
@@ -54,13 +52,11 @@ pub(super) async fn serve(
); );
} else { } else {
for listener in listeners { for listener in listeners {
join_set.spawn_on( let acceptor = axum_server::from_tcp_rustls(listener.try_clone()?, conf.clone())?
axum_server::from_tcp_rustls(listener.try_clone().unwrap(), conf.clone()) .handle(handle.clone())
.unwrap() .serve(app.clone());
.handle(handle.clone())
.serve(app.clone()), join_set.spawn_on(acceptor, server.runtime());
server.runtime(),
);
} }
info!("Listening on {addrs:?} with TLS certificate {certs}"); info!("Listening on {addrs:?} with TLS certificate {certs}");