Update axum-server to 0.8; switch to axum unix listener.

This commit is contained in:
dasha_uwu
2026-02-03 02:30:50 +05:00
committed by Jason Volk
parent 87faf818ff
commit bd5203b406
13 changed files with 231 additions and 390 deletions

View File

@@ -115,7 +115,6 @@ futures.workspace = true
http.workspace = true
http-body-util.workspace = true
hyper.workspace = true
hyper-util.workspace = true
log.workspace = true
ruma.workspace = true
rustls.workspace = true

27
src/router/handle.rs Normal file
View File

@@ -0,0 +1,27 @@
use std::time::Duration;
use axum_server::Handle;
/// Aggregates one `axum_server::Handle` per listener transport (TCP/IP and,
/// on unix targets, a unix-domain socket) so a single object can drive
/// graceful shutdown of every listener at once.
#[derive(Clone)]
pub(crate) struct ServerHandle {
/// Handle controlling the IP (plain and TLS) listeners.
pub(crate) handle_ip: Handle<std::net::SocketAddr>,
/// Handle controlling the unix-domain-socket listener; compiled only on unix.
#[cfg(unix)]
pub(crate) handle_unix: Handle<std::os::unix::net::SocketAddr>,
}
impl ServerHandle {
/// Construct fresh handles with no listeners attached yet; listeners
/// register themselves against these when they are bound.
pub(crate) fn new() -> Self {
Self {
handle_ip: Handle::<std::net::SocketAddr>::new(),
#[cfg(unix)]
handle_unix: Handle::<std::os::unix::net::SocketAddr>::new(),
}
}
/// Signal graceful shutdown to every transport. Per axum-server semantics,
/// `None` waits indefinitely for in-flight connections to finish, while
/// `Some(d)` forcibly closes remaining connections after `d` elapses.
pub(crate) fn graceful_shutdown(&self, duration: Option<Duration>) {
self.handle_ip.graceful_shutdown(duration);
#[cfg(unix)]
self.handle_unix.graceful_shutdown(duration);
}
}

View File

@@ -1,6 +1,7 @@
#![type_length_limit = "32768"] //TODO: reduce me
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91
mod handle;
mod layers;
mod request;
mod router;

View File

@@ -3,15 +3,11 @@ use std::{
time::Duration,
};
use axum_server::Handle as ServerHandle;
use tokio::{
sync::broadcast::{self, Sender},
task::JoinHandle,
};
use futures::FutureExt;
use tuwunel_core::{Error, Result, Server, debug, debug_error, debug_info, error, info};
use tuwunel_service::Services;
use crate::serve;
use crate::{handle::ServerHandle, serve};
/// Main loop base
#[tracing::instrument(skip_all)]
@@ -24,20 +20,30 @@ pub(crate) async fn run(services: Arc<Services>) -> Result {
// Setup shutdown/signal handling
let handle = ServerHandle::new();
let (tx, _) = broadcast::channel::<()>(1);
let sigs = server
.runtime()
.spawn(signal(server.clone(), tx.clone(), handle.clone()));
.spawn(signal(server.clone(), handle.clone()));
let mut listener =
let mut listener = if services.config.listening {
let future = serve::serve(services.clone(), handle);
server
.runtime()
.spawn(serve::serve(services.clone(), handle.clone(), tx.subscribe()));
.spawn(future)
.map(|res| res.map_err(Error::from).unwrap_or_else(Err))
.boxed()
} else {
let server = server.clone();
async move {
server.until_shutdown().await;
Ok(())
}
.boxed()
};
// Focal point
debug!("Running");
let res = tokio::select! {
res = &mut listener => res.map_err(Error::from).unwrap_or_else(Err),
res = &mut listener => res,
res = services.poll() => handle_services_poll(server, res, listener).await,
};
@@ -104,16 +110,12 @@ pub(crate) async fn stop(services: Arc<Services>) -> Result {
}
#[tracing::instrument(skip_all)]
async fn signal(server: Arc<Server>, tx: Sender<()>, handle: axum_server::Handle) {
async fn signal(server: Arc<Server>, handle: ServerHandle) {
server.until_shutdown().await;
handle_shutdown(&server, &tx, &handle);
handle_shutdown(&server, &handle);
}
fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_server::Handle) {
if let Err(e) = tx.send(()) {
error!("failed sending shutdown transaction to channel: {e}");
}
fn handle_shutdown(server: &Arc<Server>, handle: &ServerHandle) {
let timeout = server.config.client_shutdown_timeout;
let timeout = Duration::from_secs(timeout);
debug!(
@@ -128,7 +130,7 @@ fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_server::
async fn handle_services_poll(
server: &Arc<Server>,
result: Result,
listener: JoinHandle<Result>,
listener: impl Future<Output = Result>,
) -> Result {
debug!("Service manager finished: {result:?}");

View File

@@ -3,46 +3,75 @@ mod plain;
mod tls;
mod unix;
use std::sync::Arc;
use std::sync::{Arc, atomic::Ordering};
use axum_server::Handle as ServerHandle;
use tokio::sync::broadcast;
use tuwunel_core::{Result, err};
use tokio::task::JoinSet;
use tuwunel_core::{Result, debug_info};
use tuwunel_service::Services;
use super::layers;
use crate::handle::ServerHandle;
/// Serve clients
pub(super) async fn serve(
services: Arc<Services>,
handle: ServerHandle,
mut shutdown: broadcast::Receiver<()>,
) -> Result {
pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Result {
let server = &services.server;
let config = &server.config;
if !config.listening {
return shutdown
.recv()
.await
.map_err(|e| err!(error!("channel error: {e}")));
let (app, _guard) = layers::build(&services)?;
let mut join_set = JoinSet::new();
#[cfg(unix)]
if let Some(unix_socket) = &config.unix_socket_path {
let socket_perms = config.get_unix_socket_perms()?;
unix::serve(server, &app, &handle.handle_unix, &mut join_set, unix_socket, socket_perms)
.await?;
}
let addrs = config.get_bind_addrs();
let (app, _guard) = layers::build(&services)?;
if cfg!(unix) && config.unix_socket_path.is_some() {
unix::serve(server, app, shutdown).await
} else if config.tls.certs.is_some() {
#[cfg(feature = "direct_tls")]
services.globals.init_rustls_provider()?;
#[cfg(feature = "direct_tls")]
return tls::serve(server, app, handle, addrs).await;
if !addrs.is_empty() {
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
if config.tls.certs.is_some() {
#[cfg(feature = "direct_tls")]
{
services.globals.init_rustls_provider()?;
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?;
}
#[cfg(not(feature = "direct_tls"))]
return tuwunel_core::Err!(Config(
"tls",
"tuwunel was not built with direct TLS support (\"direct_tls\")"
));
} else {
plain::serve(server, app, handle, addrs).await
#[cfg(not(feature = "direct_tls"))]
return tuwunel_core::Err!(Config(
"tls",
"tuwunel was not built with direct TLS support (\"direct_tls\")"
));
} else {
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs);
}
}
assert!(!join_set.is_empty(), "at least one listener should be installed");
join_set.join_all().await;
let handle_active = server
.metrics
.requests_handle_active
.load(Ordering::Acquire);
debug_info!(
handle_finished = server
.metrics
.requests_handle_finished
.load(Ordering::Acquire),
panics = server
.metrics
.requests_panic
.load(Ordering::Acquire),
handle_active,
"Stopped listening on {addrs:?}",
);
debug_assert_eq!(0, handle_active, "active request handles still pending");
Ok(())
}

View File

@@ -1,50 +1,26 @@
use std::{
net::SocketAddr,
sync::{Arc, atomic::Ordering},
};
use std::{net::SocketAddr, sync::Arc};
use axum::Router;
use axum_server::{Handle as ServerHandle, bind};
use axum_server::Handle;
use tokio::task::JoinSet;
use tuwunel_core::{Result, Server, debug_info, info};
use tuwunel_core::{Server, info};
pub(super) async fn serve(
pub(super) fn serve(
server: &Arc<Server>,
router: Router,
handle: ServerHandle,
addrs: Vec<SocketAddr>,
) -> Result {
let mut join_set = JoinSet::new();
let router = router.into_make_service_with_connect_info::<SocketAddr>();
for addr in &addrs {
let bound = bind(*addr);
let handler = bound.handle(handle.clone());
let acceptor = handler.serve(router.clone());
router: &Router,
handle: &Handle<SocketAddr>,
join_set: &mut JoinSet<Result<(), std::io::Error>>,
addrs: &[SocketAddr],
) {
let router = router
.clone()
.into_make_service_with_connect_info::<SocketAddr>();
for addr in addrs {
let acceptor = axum_server::bind(*addr)
.handle(handle.clone())
.serve(router.clone());
join_set.spawn_on(acceptor, server.runtime());
}
info!("Listening on {addrs:?}");
while join_set.join_next().await.is_some() {}
let handle_active = server
.metrics
.requests_handle_active
.load(Ordering::Acquire);
debug_info!(
handle_finished = server
.metrics
.requests_handle_finished
.load(Ordering::Acquire),
panics = server
.metrics
.requests_panic
.load(Ordering::Acquire),
handle_active,
"Stopped listening on {addrs:?}",
);
debug_assert_eq!(0, handle_active, "active request handles still pending");
Ok(())
}

View File

@@ -1,31 +1,21 @@
use std::{
net::SocketAddr,
sync::{Arc, atomic::Ordering},
};
use std::{net::SocketAddr, sync::Arc};
use axum::Router;
use axum_server::Handle as ServerHandle;
use axum_server_dual_protocol::{
ServerExt,
axum_server::{bind_rustls, tls_rustls::RustlsConfig},
};
use axum_server::Handle;
use axum_server_dual_protocol::{ServerExt, axum_server::tls_rustls::RustlsConfig};
use tokio::task::JoinSet;
use tuwunel_core::{Result, Server, debug, debug_info, err, info, warn};
use tuwunel_core::{Result, Server, debug, err, info, warn};
pub(super) async fn serve(
server: &Arc<Server>,
app: Router,
handle: ServerHandle,
addrs: Vec<SocketAddr>,
app: &Router,
handle: &Handle<SocketAddr>,
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
addrs: &[SocketAddr],
) -> Result {
let tls = &server.config.tls;
let certs = tls.certs.as_ref().ok_or_else(|| {
err!(Config("tls.certs", "Missing required value in tls config section"))
})?;
let key = tls
.key
.as_ref()
.ok_or_else(|| err!(Config("tls.key", "Missing required value in tls config section")))?;
let certs = tls.certs.as_ref().unwrap();
let key = tls.key.as_ref().unwrap();
info!(
"Note: It is strongly recommended that you use a reverse proxy instead of running \
@@ -36,10 +26,11 @@ pub(super) async fn serve(
.await
.map_err(|e| err!(Config("tls", "Failed to load certificates or key: {e}")))?;
let mut join_set = JoinSet::new();
let app = app.into_make_service_with_connect_info::<SocketAddr>();
let app = app
.clone()
.into_make_service_with_connect_info::<SocketAddr>();
if tls.dual_protocol {
for addr in &addrs {
for addr in addrs {
join_set.spawn_on(
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
.set_upgrade(false)
@@ -48,47 +39,23 @@ pub(super) async fn serve(
server.runtime(),
);
}
} else {
for addr in &addrs {
join_set.spawn_on(
bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
server.runtime(),
);
}
}
if tls.dual_protocol {
warn!(
"Listening on {addrs:?} with TLS certificate {certs} and supporting plain text \
(HTTP) connections too (insecure!)",
);
} else {
for addr in addrs {
join_set.spawn_on(
axum_server::bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
server.runtime(),
);
}
info!("Listening on {addrs:?} with TLS certificate {certs}");
}
while join_set.join_next().await.is_some() {}
let handle_active = server
.metrics
.requests_handle_active
.load(Ordering::Acquire);
debug_info!(
handle_finished = server
.metrics
.requests_handle_finished
.load(Ordering::Acquire),
panics = server
.metrics
.requests_panic
.load(Ordering::Acquire),
handle_active,
"Stopped listening on {addrs:?}",
);
debug_assert_eq!(0, handle_active, "active request handles still pending");
Ok(())
}

View File

@@ -1,185 +1,48 @@
#![cfg(unix)]
use std::{
net::{self, IpAddr, Ipv4Addr},
os::fd::AsRawFd,
path::Path,
sync::{Arc, atomic::Ordering},
};
use axum::{
Router,
extract::{Request, connect_info::IntoMakeServiceWithConnectInfo},
};
use hyper::{body::Incoming, service::service_fn};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server,
};
use tokio::{
fs,
net::{UnixListener, UnixStream, unix::SocketAddr},
sync::broadcast::{self},
task::JoinSet,
time::{Duration, sleep},
};
use tower::{Service, ServiceExt};
use tuwunel_core::{
Err, Result, Server, debug, debug_error, info, result::UnwrapInfallible, trace, warn,
net::SocketAddr,
os::unix::{fs::PermissionsExt, net::UnixListener},
path::Path,
sync::Arc,
};
type MakeService = IntoMakeServiceWithConnectInfo<Router, net::SocketAddr>;
const NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
const FINI_POLL_INTERVAL: Duration = Duration::from_millis(750);
use axum::{Extension, Router, extract::ConnectInfo};
use axum_server::Handle;
use tokio::task::JoinSet;
use tuwunel_core::{Result, Server, info, warn};
#[tracing::instrument(skip_all, level = "debug")]
pub(super) async fn serve(
server: &Arc<Server>,
app: Router,
mut shutdown: broadcast::Receiver<()>,
router: &Router,
handle: &Handle<std::os::unix::net::SocketAddr>,
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
path: &Path,
socket_perms: u32,
) -> Result {
let mut tasks = JoinSet::<()>::new();
let executor = TokioExecutor::new();
let app = app.into_make_service_with_connect_info::<net::SocketAddr>();
let builder = server::conn::auto::Builder::new(executor);
let listener = init(server).await?;
while server.running() {
let app = app.clone();
let builder = builder.clone();
tokio::select! {
_sig = shutdown.recv() => break,
conn = listener.accept() => match conn {
Ok(conn) => accept(server, &listener, &mut tasks, app, builder, conn).await,
Err(err) => debug_error!(?listener, "accept error: {err}"),
},
}
}
fini(server, listener, tasks).await;
Ok(())
}
#[tracing::instrument(
level = "trace",
skip_all,
fields(
?listener,
socket = ?conn.0,
),
)]
async fn accept(
server: &Arc<Server>,
listener: &UnixListener,
tasks: &mut JoinSet<()>,
app: MakeService,
builder: server::conn::auto::Builder<TokioExecutor>,
conn: (UnixStream, SocketAddr),
) {
let (socket, _) = conn;
let server_ = server.clone();
let task = async move { accepted(server_, builder, socket, app).await };
_ = tasks.spawn_on(task, server.runtime());
while tasks.try_join_next().is_some() {}
}
#[tracing::instrument(
level = "trace",
skip_all,
fields(
fd = %socket.as_raw_fd(),
path = ?socket.local_addr(),
),
)]
async fn accepted(
server: Arc<Server>,
builder: server::conn::auto::Builder<TokioExecutor>,
socket: UnixStream,
mut app: MakeService,
) {
let socket = TokioIo::new(socket);
let called = app.call(NULL_ADDR).await.unwrap_infallible();
let service = move |req: Request<Incoming>| called.clone().oneshot(req);
let handler = service_fn(service);
trace!(?socket, ?handler, "serving connection");
// bug on darwin causes all results to be errors. do not unwrap this
tokio::select! {
() = server.until_shutdown() => (),
_ = builder.serve_connection(socket, handler) => (),
};
}
async fn init(server: &Arc<Server>) -> Result<UnixListener> {
use std::os::unix::fs::PermissionsExt;
let config = &server.config;
let path = config
.unix_socket_path
.as_ref()
.expect("failed to extract configured unix socket path");
if path.exists() {
warn!("Removing existing UNIX socket {:#?} (unclean shutdown?)...", path.display());
fs::remove_file(&path)
.await
.map_err(|e| warn!("Failed to remove existing UNIX socket: {e}"))
.unwrap();
warn!("Removing existing UNIX socket {path:?} (unclean shutdown?)...");
fs::remove_file(path)?;
}
let dir = path.parent().unwrap_or_else(|| Path::new("/"));
if let Err(e) = fs::create_dir_all(dir).await {
return Err!("Failed to create {dir:?} for socket {path:?}: {e}");
}
let unix_listener = UnixListener::bind(path)?;
unix_listener.set_nonblocking(true)?;
let listener = UnixListener::bind(path);
if let Err(e) = listener {
return Err!("Failed to bind listener {path:?}: {e}");
}
let perms = fs::Permissions::from_mode(socket_perms);
fs::set_permissions(path, perms)?;
let socket_perms = config.unix_socket_perms.to_string();
let octal_perms =
u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions");
let perms = std::fs::Permissions::from_mode(octal_perms);
if let Err(e) = fs::set_permissions(&path, perms).await {
return Err!("Failed to set socket {path:?} permissions: {e}");
}
let router = router
.clone()
.layer(Extension(ConnectInfo("0.0.0.0".parse::<SocketAddr>())))
.into_make_service();
let acceptor = axum_server::from_unix(unix_listener)?
.handle(handle.clone())
.serve(router);
join_set.spawn_on(acceptor, server.runtime());
info!("Listening at {path:?}");
Ok(listener.unwrap())
}
async fn fini(server: &Arc<Server>, listener: UnixListener, mut tasks: JoinSet<()>) {
let local = listener.local_addr();
debug!("Closing listener at {local:?} ...");
drop(listener);
debug!("Waiting for requests to finish...");
while server
.metrics
.requests_handle_active
.load(Ordering::Acquire)
.gt(&0)
{
tokio::select! {
task = tasks.join_next() => if task.is_none() { break; },
() = sleep(FINI_POLL_INTERVAL) => {},
}
}
debug!("Shutting down...");
tasks.shutdown().await;
if let Ok(local) = local
&& let Some(path) = local.as_pathname()
{
debug!(?path, "Removing unix socket file.");
if let Err(e) = fs::remove_file(path).await {
warn!(?path, "Failed to remove UNIX socket file: {e}");
}
}
Ok(())
}