Add support for systemd socket activation
Co-authored-by: Jason Volk <jason@zemos.net> Signed-off-by: Vasiliy Stelmachenok <ventureo@cachyos.org>
This commit is contained in:
@@ -3,7 +3,12 @@ mod plain;
|
||||
mod tls;
|
||||
mod unix;
|
||||
|
||||
use std::sync::{Arc, atomic::Ordering};
|
||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
||||
use std::os::fd::FromRawFd;
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::{Arc, atomic::Ordering},
|
||||
};
|
||||
|
||||
use tokio::task::JoinSet;
|
||||
use tuwunel_core::{Result, debug_info};
|
||||
@@ -29,14 +34,56 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
||||
.await?;
|
||||
}
|
||||
|
||||
let mut listen_addrs: Vec<SocketAddr> = vec![];
|
||||
|
||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
||||
let mut listeners: Vec<TcpListener> = if let Ok(fds) = sd_notify::listen_fds() {
|
||||
fds.filter_map(|fd| {
|
||||
// SAFETY: systemd should already take care of providing
|
||||
// the correct TCP socket, so we just use it via raw fd
|
||||
let listener = unsafe { TcpListener::from_raw_fd(fd) };
|
||||
listener
|
||||
.set_nonblocking(true)
|
||||
.expect("Failed to set non-blocking");
|
||||
if let Ok(addr) = listener.local_addr() {
|
||||
listen_addrs.push(addr);
|
||||
Some(listener)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
#[cfg(not(all(feature = "systemd", target_os = "linux")))]
|
||||
let mut listeners: Vec<TcpListener> = vec![];
|
||||
|
||||
let addrs = config.get_bind_addrs();
|
||||
if !addrs.is_empty() {
|
||||
|
||||
if !addrs.is_empty() && listeners.is_empty() {
|
||||
listeners = addrs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|addr| {
|
||||
let listener = TcpListener::bind(addr).expect("Failed to bind {addr:?}");
|
||||
listener
|
||||
.set_nonblocking(true)
|
||||
.expect("Failed to set non-blocking");
|
||||
listen_addrs.push(addr);
|
||||
listener
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
|
||||
if config.tls.certs.is_some() {
|
||||
#[cfg(feature = "direct_tls")]
|
||||
{
|
||||
services.globals.init_rustls_provider()?;
|
||||
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?;
|
||||
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners)
|
||||
.await?;
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "direct_tls"))]
|
||||
@@ -45,8 +92,7 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
||||
"tuwunel was not built with direct TLS support (\"direct_tls\")"
|
||||
));
|
||||
} else {
|
||||
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs);
|
||||
}
|
||||
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listen_addrs, &listeners);
|
||||
}
|
||||
|
||||
assert!(!join_set.is_empty(), "at least one listener should be installed");
|
||||
@@ -68,7 +114,7 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
||||
.requests_panic
|
||||
.load(Ordering::Acquire),
|
||||
handle_active,
|
||||
"Stopped listening on {addrs:?}",
|
||||
"Stopped listening on {listen_addrs:?}",
|
||||
);
|
||||
|
||||
debug_assert_eq!(0, handle_active, "active request handles still pending");
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::Handle;
|
||||
@@ -11,12 +14,14 @@ pub(super) fn serve(
|
||||
handle: &Handle<SocketAddr>,
|
||||
join_set: &mut JoinSet<Result<(), std::io::Error>>,
|
||||
addrs: &[SocketAddr],
|
||||
listeners: &Vec<TcpListener>,
|
||||
) {
|
||||
let router = router
|
||||
.clone()
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
for addr in addrs {
|
||||
let acceptor = axum_server::bind(*addr)
|
||||
for listener in listeners {
|
||||
let acceptor = axum_server::from_tcp(listener.try_clone().unwrap())
|
||||
.unwrap()
|
||||
.handle(handle.clone())
|
||||
.serve(router.clone());
|
||||
join_set.spawn_on(acceptor, server.runtime());
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::Handle;
|
||||
@@ -12,6 +15,7 @@ pub(super) async fn serve(
|
||||
handle: &Handle<SocketAddr>,
|
||||
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
|
||||
addrs: &[SocketAddr],
|
||||
listeners: &Vec<TcpListener>,
|
||||
) -> Result {
|
||||
let tls = &server.config.tls;
|
||||
let certs = tls.certs.as_ref().unwrap();
|
||||
@@ -30,9 +34,13 @@ pub(super) async fn serve(
|
||||
.clone()
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
if tls.dual_protocol {
|
||||
for addr in addrs {
|
||||
for listener in listeners {
|
||||
join_set.spawn_on(
|
||||
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
|
||||
axum_server_dual_protocol::from_tcp_dual_protocol(
|
||||
listener.try_clone().unwrap(),
|
||||
conf.clone(),
|
||||
)
|
||||
.unwrap()
|
||||
.set_upgrade(false)
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
@@ -45,9 +53,10 @@ pub(super) async fn serve(
|
||||
(HTTP) connections too (insecure!)",
|
||||
);
|
||||
} else {
|
||||
for addr in addrs {
|
||||
for listener in listeners {
|
||||
join_set.spawn_on(
|
||||
axum_server::bind_rustls(*addr, conf.clone())
|
||||
axum_server::from_tcp_rustls(listener.try_clone().unwrap(), conf.clone())
|
||||
.unwrap()
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
server.runtime(),
|
||||
|
||||
Reference in New Issue
Block a user