diff --git a/src/router/serve.rs b/src/router/serve.rs index 9f4d6981..2bd1240c 100644 --- a/src/router/serve.rs +++ b/src/router/serve.rs @@ -3,7 +3,12 @@ mod plain; mod tls; mod unix; -use std::sync::{Arc, atomic::Ordering}; +#[cfg(all(feature = "systemd", target_os = "linux"))] +use std::os::fd::FromRawFd; +use std::{ + net::{SocketAddr, TcpListener}, + sync::{Arc, atomic::Ordering}, +}; use tokio::task::JoinSet; use tuwunel_core::{Result, debug_info}; @@ -29,24 +34,65 @@ pub(super) async fn serve(services: Arc, handle: ServerHandle) -> Resu .await?; } - let addrs = config.get_bind_addrs(); - if !addrs.is_empty() { - #[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))] - if config.tls.certs.is_some() { - #[cfg(feature = "direct_tls")] - { - services.globals.init_rustls_provider()?; - tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?; - } + let mut listen_addrs: Vec = vec![]; - #[cfg(not(feature = "direct_tls"))] - return tuwunel_core::Err!(Config( - "tls", - "tuwunel was not built with direct TLS support (\"direct_tls\")" - )); - } else { - plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs); + #[cfg(all(feature = "systemd", target_os = "linux"))] + let mut listeners: Vec = if let Ok(fds) = sd_notify::listen_fds() { + fds.filter_map(|fd| { + // SAFETY: systemd should already take care of providing + // the correct TCP socket, so we just use it via raw fd + let listener = unsafe { TcpListener::from_raw_fd(fd) }; + listener + .set_nonblocking(true) + .expect("Failed to set non-blocking"); + if let Ok(addr) = listener.local_addr() { + listen_addrs.push(addr); + Some(listener) + } else { + None + } + }) + .collect() + } else { + vec![] + }; + + #[cfg(not(all(feature = "systemd", target_os = "linux")))] + let mut listeners: Vec = vec![]; + + let addrs = config.get_bind_addrs(); + + if !addrs.is_empty() && listeners.is_empty() { + listeners = addrs + .clone() + .into_iter() + .map(|addr| { + let listener = TcpListener::bind(addr).expect("Failed to bind {addr:?}"); + listener + .set_nonblocking(true) + .expect("Failed to set non-blocking"); + listen_addrs.push(addr); + listener + }) + .collect(); + } + + #[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))] + if config.tls.certs.is_some() { + #[cfg(feature = "direct_tls")] + { + services.globals.init_rustls_provider()?; + tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners) + .await?; } + + #[cfg(not(feature = "direct_tls"))] + return tuwunel_core::Err!(Config( + "tls", + "tuwunel was not built with direct TLS support (\"direct_tls\")" + )); + } else { + plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners); } assert!(!join_set.is_empty(), "at least one listener should be installed"); @@ -68,7 +114,7 @@ pub(super) async fn serve(services: Arc, handle: ServerHandle) -> Resu .requests_panic .load(Ordering::Acquire), handle_active, - "Stopped listening on {addrs:?}", + "Stopped listening on {listen_addrs:?}", ); debug_assert_eq!(0, handle_active, "active request handles still pending"); diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index 4ee8fcf1..9826d0f5 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -1,4 +1,7 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{SocketAddr, TcpListener}, + sync::Arc, +}; use axum::Router; use axum_server::Handle; @@ -11,12 +14,14 @@ pub(super) fn serve( handle: &Handle, join_set: &mut JoinSet>, addrs: &[SocketAddr], + listeners: &Vec, ) { let router = router .clone() .into_make_service_with_connect_info::(); - for addr in addrs { - let acceptor = axum_server::bind(*addr) + for listener in listeners { + let acceptor = axum_server::from_tcp(listener.try_clone().unwrap()) + .unwrap() .handle(handle.clone()) .serve(router.clone()); join_set.spawn_on(acceptor, server.runtime()); diff --git a/src/router/serve/tls.rs b/src/router/serve/tls.rs index 09ecce31..d11e7d5e 100644 --- a/src/router/serve/tls.rs +++ b/src/router/serve/tls.rs @@ -1,4 +1,7 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{SocketAddr, TcpListener}, + sync::Arc, +}; use axum::Router; use axum_server::Handle; @@ -12,6 +15,7 @@ pub(super) async fn serve( handle: &Handle, join_set: &mut JoinSet>, addrs: &[SocketAddr], + listeners: &Vec, ) -> Result { let tls = &server.config.tls; let certs = tls.certs.as_ref().unwrap(); @@ -30,12 +34,16 @@ pub(super) async fn serve( .clone() .into_make_service_with_connect_info::(); if tls.dual_protocol { - for addr in addrs { + for listener in listeners { join_set.spawn_on( - axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) - .set_upgrade(false) - .handle(handle.clone()) - .serve(app.clone()), + axum_server_dual_protocol::from_tcp_dual_protocol( + listener.try_clone().unwrap(), + conf.clone(), + ) + .unwrap() + .set_upgrade(false) + .handle(handle.clone()) + .serve(app.clone()), server.runtime(), ); } @@ -45,9 +53,10 @@ pub(super) async fn serve( (HTTP) connections too (insecure!)", ); } else { - for addr in addrs { + for listener in listeners { join_set.spawn_on( - axum_server::bind_rustls(*addr, conf.clone()) + axum_server::from_tcp_rustls(listener.try_clone().unwrap(), conf.clone()) + .unwrap() .handle(handle.clone()) .serve(app.clone()), server.runtime(),