diff --git a/Cargo.lock b/Cargo.lock index e677b1a7..7d742002 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2451,9 +2451,9 @@ checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8" [[package]] name = "libc" -version = "0.2.182" +version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" [[package]] name = "libfuzzer-sys" @@ -3487,9 +3487,9 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.13" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" dependencies = [ "aws-lc-rs", "bytes", @@ -6442,18 +6442,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.40" +version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.40" +version = "0.8.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f" dependencies = [ "proc-macro2", "quote", diff --git a/src/router/serve.rs b/src/router/serve.rs index 9f4d6981..63b569f0 100644 --- a/src/router/serve.rs +++ b/src/router/serve.rs @@ -3,7 +3,10 @@ mod plain; mod tls; mod unix; -use std::sync::{Arc, atomic::Ordering}; +use std::{ + net::{SocketAddr, TcpListener}, + sync::{Arc, atomic::Ordering}, +}; use tokio::task::JoinSet; use tuwunel_core::{Result, debug_info}; @@ -29,24 +32,42 @@ pub(super) async fn serve(services: Arc, handle: ServerHandle) -> Resu .await?; } - let addrs = config.get_bind_addrs(); - if !addrs.is_empty() { - #[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))] - if config.tls.certs.is_some() { - #[cfg(feature = "direct_tls")] - { - services.globals.init_rustls_provider()?; - tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?; - } + let systemd_listeners: Vec<_> = systemd_listeners().collect(); + let systemd_listeners_is_empty = systemd_listeners.is_empty(); + let (listeners, addrs): (Vec<_>, Vec<_>) = config + .get_bind_addrs() + .into_iter() + .filter(|_| systemd_listeners_is_empty) + .map(|addr| { + let listener = TcpListener::bind(addr) + .expect("Failed to bind configured TcpListener to {addr:?}"); - #[cfg(not(feature = "direct_tls"))] - return tuwunel_core::Err!(Config( - "tls", - "tuwunel was not built with direct TLS support (\"direct_tls\")" - )); - } else { - plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs); + (listener, addr) + }) + .chain(systemd_listeners) + .inspect(|(listener, _)| { + listener + .set_nonblocking(true) + .expect("Failed to set non-blocking"); + }) + .unzip(); + + #[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))] + if config.tls.certs.is_some() { + #[cfg(feature = "direct_tls")] + { + services.globals.init_rustls_provider()?; + tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs) + .await?; } + + #[cfg(not(feature = "direct_tls"))] + return tuwunel_core::Err!(Config( + "tls", + "tuwunel was not built with direct TLS support (\"direct_tls\")" + )); + } else { + plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs)?; } assert!(!join_set.is_empty(), "at least one listener should be installed"); @@ -75,3 +96,27 @@ pub(super) async fn serve(services: Arc, handle: ServerHandle) -> Resu Ok(()) } + +#[cfg(all(feature = "systemd", target_os = "linux"))] +fn systemd_listeners() -> impl Iterator { + sd_notify::listen_fds() + .into_iter() + .flatten() + .filter_map(|fd| { + use std::os::fd::FromRawFd; + + debug_assert!(fd >= 3, "fdno probably not a listener socket"); + // SAFETY: systemd should already take care of providing + // the correct TCP socket, so we just use it via raw fd + let listener = unsafe { TcpListener::from_raw_fd(fd) }; + + let Ok(addr) = listener.local_addr() else { + return None; + }; + + Some((listener, addr)) + }) +} + +#[cfg(any(not(feature = "systemd"), not(target_os = "linux")))] +fn systemd_listeners() -> impl Iterator { std::iter::empty() } diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index 4ee8fcf1..91ffc10b 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -1,26 +1,34 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{SocketAddr, TcpListener}, + sync::Arc, +}; use axum::Router; use axum_server::Handle; use tokio::task::JoinSet; -use tuwunel_core::{Server, info}; +use tuwunel_core::{Result, Server, info}; pub(super) fn serve( server: &Arc, router: &Router, handle: &Handle, join_set: &mut JoinSet>, + listeners: &[TcpListener], addrs: &[SocketAddr], -) { +) -> Result { let router = router .clone() .into_make_service_with_connect_info::(); - for addr in addrs { - let acceptor = axum_server::bind(*addr) + + for listener in listeners { + let acceptor = axum_server::from_tcp(listener.try_clone()?)? .handle(handle.clone()) .serve(router.clone()); + join_set.spawn_on(acceptor, server.runtime()); } info!("Listening on {addrs:?}"); + + Ok(()) } diff --git a/src/router/serve/tls.rs b/src/router/serve/tls.rs index f57bbc00..8e118cfe 100644 --- a/src/router/serve/tls.rs +++ b/src/router/serve/tls.rs @@ -1,4 +1,7 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{SocketAddr, TcpListener}, + sync::Arc, +}; use axum::Router; use axum_server::Handle; @@ -11,6 +14,7 @@ pub(super) async fn serve( app: &Router, handle: &Handle, join_set: &mut JoinSet>, + listeners: &[TcpListener], addrs: &[SocketAddr], ) -> Result { let tls = &server.config.tls; @@ -43,14 +47,16 @@ pub(super) async fn serve( .into_make_service_with_connect_info::(); if tls.dual_protocol { - for addr in addrs { - join_set.spawn_on( - axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) - .set_upgrade(false) - .handle(handle.clone()) - .serve(app.clone()), - server.runtime(), - ); + for listener in listeners { + let acceptor = axum_server_dual_protocol::from_tcp_dual_protocol( + listener.try_clone()?, + conf.clone(), + )? + .set_upgrade(false) + .handle(handle.clone()) + .serve(app.clone()); + + join_set.spawn_on(acceptor, server.runtime()); } warn!( @@ -58,13 +64,12 @@ pub(super) async fn serve( (HTTP) connections too (insecure!)", ); } else { - for addr in addrs { - join_set.spawn_on( - axum_server::bind_rustls(*addr, conf.clone()) - .handle(handle.clone()) - .serve(app.clone()), - server.runtime(), - ); + for listener in listeners { + let acceptor = axum_server::from_tcp_rustls(listener.try_clone()?, conf.clone())? + .handle(handle.clone()) + .serve(app.clone()); + + join_set.spawn_on(acceptor, server.runtime()); } info!("Listening on {addrs:?} with TLS certificate {certs}");