From bd5203b406488b6e34c770ba0ed743209a845401 Mon Sep 17 00:00:00 2001 From: dasha_uwu Date: Tue, 3 Feb 2026 02:30:50 +0500 Subject: [PATCH] Update axum-server to 0.8; switch to axum unix listener. --- Cargo.lock | 21 +---- Cargo.toml | 14 +-- src/core/config/check.rs | 60 ++++++------ src/core/config/mod.rs | 39 ++++---- src/router/Cargo.toml | 1 - src/router/handle.rs | 27 ++++++ src/router/mod.rs | 1 + src/router/run.rs | 40 ++++---- src/router/serve/mod.rs | 87 +++++++++++------ src/router/serve/plain.rs | 56 ++++------- src/router/serve/tls.rs | 79 +++++----------- src/router/serve/unix.rs | 193 ++++++-------------------------------- tuwunel-example.toml | 3 - 13 files changed, 231 insertions(+), 390 deletions(-) create mode 100644 src/router/handle.rs diff --git a/Cargo.lock b/Cargo.lock index 3ee7fcc8..cc8b0c94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -402,12 +402,13 @@ dependencies = [ [[package]] name = "axum-server" -version = "0.7.3" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ab4a3ec9ea8a657c72d99a03a824af695bd0fb5ec639ccbd9cd3543b41a5f9" +checksum = "b1df331683d982a0b9492b38127151e6453639cd34926eb9c07d4cd8c6d22bfc" dependencies = [ "arc-swap", "bytes", + "either", "fs-err", "http", "http-body", @@ -415,7 +416,6 @@ dependencies = [ "hyper-util", "pin-project-lite", "rustls", - "rustls-pemfile", "rustls-pki-types", "tokio", "tokio-rustls", @@ -424,9 +424,8 @@ dependencies = [ [[package]] name = "axum-server-dual-protocol" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2164551db024e87f20316d164eab9f5ad342d8188b08051ceb15ca92a60ea7b7" +version = "0.8.0" +source = "git+https://github.com/matrix-construct/axum-server-dual-protocol?rev=76c782fa6f129f83ffdca59e903093e798d4a82f#76c782fa6f129f83ffdca59e903093e798d4a82f" dependencies = [ "axum-server", "bytes", @@ -4049,15 +4048,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" version = "1.14.0" @@ -5393,7 +5383,6 @@ dependencies = [ "http", "http-body-util", "hyper", - "hyper-util", "log", "ruma", "rustls", diff --git a/Cargo.toml b/Cargo.toml index bc602ecb..fbd60f68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,11 +82,12 @@ features = [ ] [workspace.dependencies.axum-server] -version = "0.7" +version = "0.8" default-features = false [workspace.dependencies.axum-server-dual-protocol] -version = "0.7" +git = "https://github.com/matrix-construct/axum-server-dual-protocol" +rev = "76c782fa6f129f83ffdca59e903093e798d4a82f" default-features = false [workspace.dependencies.base64] @@ -201,15 +202,6 @@ features = [ "http2", ] -[workspace.dependencies.hyper-util] -version = "0.1.20" -default-features = false -features = [ - "server-auto", - "server-graceful", - "tokio", -] - [workspace.dependencies.image] version = "0.25" default-features = false diff --git a/src/core/config/check.rs b/src/core/config/check.rs index 098b613e..13f59500 100644 --- a/src/core/config/check.rs +++ b/src/core/config/check.rs @@ -1,11 +1,10 @@ use std::env::consts::OS; use either::Either; -use figment::Figment; use itertools::Itertools; use super::{DEPRECATED_KEYS, IdentityProvider}; -use crate::{Config, Err, Result, Server, debug, debug_info, debug_warn, error, warn}; +use crate::{Config, Err, Result, Server, debug, debug_info, error, warn}; /// Performs check() with additional checks specific to reloading old config /// with new config. @@ -24,9 +23,8 @@ pub fn reload(old: &Config, new: &Config) -> Result { } pub fn check(config: &Config) -> Result { - if cfg!(debug_assertions) { - warn!("Note: tuwunel was built without optimisations (i.e. debug build)"); - } + #[cfg(debug_assertions)] + warn!("Note: tuwunel was built without optimisations (i.e. debug build)"); warn_deprecated(config); warn_unknown_key(config)?; @@ -38,14 +36,18 @@ pub fn check(config: &Config) -> Result { )); } - if cfg!(all(feature = "hardened_malloc", feature = "jemalloc", not(target_env = "msvc"))) { - debug_warn!( - "hardened_malloc and jemalloc compile-time features are both enabled, this causes \ - jemalloc to be used." - ); - } + #[cfg(all( + feature = "hardened_malloc", + feature = "jemalloc", + not(target_env = "msvc") + ))] + debug_warn!( + "hardened_malloc and jemalloc compile-time features are both enabled, this causes \ + jemalloc to be used." + ); - if cfg!(not(unix)) && config.unix_socket_path.is_some() { + #[cfg(not(unix))] + if config.unix_socket_path.is_some() { return Err!(Config( "unix_socket_path", "UNIX socket support is only available on *nix platforms. Please remove \ @@ -53,12 +55,20 @@ pub fn check(config: &Config) -> Result { )); } - if config.unix_socket_path.is_none() && config.get_bind_hosts().is_empty() { - return Err!(Config("address", "No TCP addresses were specified to listen on")); + if config.unix_socket_path.is_none() { + if config.get_bind_hosts().is_empty() { + return Err!(Config("address", "No TCP addresses were specified to listen on")); + } + + if config.get_bind_ports().is_empty() { + return Err!(Config("port", "No ports were specified to listen on")); + } } - if config.unix_socket_path.is_none() && config.get_bind_ports().is_empty() { - return Err!(Config("port", "No ports were specified to listen on")); + let certs_set = config.tls.certs.is_some(); + let key_set = config.tls.key.is_some(); + if certs_set ^ key_set { + return Err!(Config("tls", "tls.certs and tls.key must either both be set or unset")); } if !config.listening { @@ -115,7 +125,8 @@ pub fn check(config: &Config) -> Result { } // yeah, unless the user built a debug build hopefully for local testing only - if cfg!(not(debug_assertions)) && config.server_name == "your.server.name" { + #[cfg(not(debug_assertions))] + if config.server_name == "your.server.name" { return Err!(Config( "server_name", "You must specify a valid server name for production usage of tuwunel." @@ -412,18 +423,3 @@ fn warn_unknown_key(config: &Config) -> Result { Ok(()) } } - -/// Checks the presence of the `address` and `unix_socket_path` keys in the -/// raw_config, exiting the process if both keys were detected. -pub(super) fn is_dual_listening(raw_config: &Figment) -> Result { - let contains_address = raw_config.contains("address"); - let contains_unix_socket = raw_config.contains("unix_socket_path"); - if contains_address && contains_unix_socket { - return Err!( - "TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify \ - only one option." - ); - } - - Ok(()) -} diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 612b0fdf..e2cb4cec 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -105,8 +105,8 @@ pub struct Config { /// "::1"] /// /// default: ["127.0.0.1", "::1"] - #[serde(default = "default_address")] - address: ListeningAddr, + #[serde(default)] + address: Option, /// The port(s) tuwunel will listen on. /// @@ -128,9 +128,6 @@ pub struct Config { /// The UNIX socket tuwunel will listen on. /// - /// tuwunel cannot listen on both an IP address and a UNIX socket. If - /// listening on a UNIX socket, you MUST remove/comment the `address` key. - /// /// Remember to make sure that your reverse proxy has access to this socket /// file, either by adding your reverse proxy to the 'tuwunel' group or /// granting world R/W permissions with `unix_socket_perms` (666 minimum). @@ -3010,18 +3007,24 @@ impl Config { .extract::() .map_err(|e| err!("There was a problem with your configuration file: {e}"))?; - // don't start if we're listening on both UNIX sockets and TCP at same time - check::is_dual_listening(raw_config)?; - Ok(config) } + pub fn get_unix_socket_perms(&self) -> Result { + let octal_perms = self.unix_socket_perms.to_string(); + let socket_perms = u32::from_str_radix(&octal_perms, 8).map_err(|_| { + err!(Config("unix_socket_perms", "failed to convert octal permissions")) + })?; + + Ok(socket_perms) + } + #[must_use] pub fn get_bind_addrs(&self) -> Vec { let mut addrs = Vec::with_capacity( self.get_bind_hosts() .len() - .saturating_add(self.get_bind_ports().len()), + .saturating_mul(self.get_bind_ports().len()), ); for host in &self.get_bind_hosts() { for port in &self.get_bind_ports() { @@ -3033,9 +3036,15 @@ impl Config { } fn get_bind_hosts(&self) -> Vec { - match &self.address.addrs { - | Left(addr) => vec![*addr], - | Right(addrs) => addrs.clone(), + if let Some(address) = &self.address { + match &address.addrs { + | Left(addr) => vec![*addr], + | Right(addrs) => addrs.clone(), + } + } else if self.unix_socket_path.is_some() { + vec![] + } else { + vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()] } } @@ -3056,12 +3065,6 @@ fn default_server_name() -> OwnedServerName { ruma::owned_server_name!("localhos fn default_database_path() -> PathBuf { "/var/lib/tuwunel".to_owned().into() } -fn default_address() -> ListeningAddr { - ListeningAddr { - addrs: Right(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]), - } -} - fn default_port() -> ListeningPort { ListeningPort { ports: Left(8008) } } fn default_unix_socket_perms() -> u32 { 660 } diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml index d0de5743..1876b79f 100644 --- a/src/router/Cargo.toml +++ b/src/router/Cargo.toml @@ -115,7 +115,6 @@ futures.workspace = true http.workspace = true http-body-util.workspace = true hyper.workspace = true -hyper-util.workspace = true log.workspace = true ruma.workspace = true rustls.workspace = true diff --git a/src/router/handle.rs b/src/router/handle.rs new file mode 100644 index 00000000..64a96c88 --- /dev/null +++ b/src/router/handle.rs @@ -0,0 +1,27 @@ +use std::time::Duration; + +use axum_server::Handle; + +#[derive(Clone)] +pub(crate) struct ServerHandle { + pub(crate) handle_ip: Handle, + #[cfg(unix)] + pub(crate) handle_unix: Handle, +} + +impl ServerHandle { + pub(crate) fn new() -> Self { + Self { + handle_ip: Handle::::new(), + #[cfg(unix)] + handle_unix: Handle::::new(), + } + } + + pub(crate) fn graceful_shutdown(&self, duration: Option) { + self.handle_ip.graceful_shutdown(duration); + + #[cfg(unix)] + self.handle_unix.graceful_shutdown(duration); + } +} diff --git a/src/router/mod.rs b/src/router/mod.rs index 0f9a9958..336c1fda 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,6 +1,7 @@ #![type_length_limit = "32768"] //TODO: reduce me #![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91 +mod handle; mod layers; mod request; mod router; diff --git a/src/router/run.rs b/src/router/run.rs index f76ed642..57f1c048 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -3,15 +3,11 @@ use std::{ time::Duration, }; -use axum_server::Handle as ServerHandle; -use tokio::{ - sync::broadcast::{self, Sender}, - task::JoinHandle, -}; +use futures::FutureExt; use tuwunel_core::{Error, Result, Server, debug, debug_error, debug_info, error, info}; use tuwunel_service::Services; -use crate::serve; +use crate::{handle::ServerHandle, serve}; /// Main loop base #[tracing::instrument(skip_all)] @@ -24,20 +20,30 @@ pub(crate) async fn run(services: Arc) -> Result { // Setup shutdown/signal handling let handle = ServerHandle::new(); - let (tx, _) = broadcast::channel::<()>(1); let sigs = server .runtime() - .spawn(signal(server.clone(), tx.clone(), handle.clone())); + .spawn(signal(server.clone(), handle.clone())); - let mut listener = + let mut listener = if services.config.listening { + let future = serve::serve(services.clone(), handle); server .runtime() - .spawn(serve::serve(services.clone(), handle.clone(), tx.subscribe())); + .spawn(future) + .map(|res| res.map_err(Error::from).unwrap_or_else(Err)) + .boxed() + } else { + let server = server.clone(); + async move { + server.until_shutdown().await; + Ok(()) + } + .boxed() + }; // Focal point debug!("Running"); let res = tokio::select! { - res = &mut listener => res.map_err(Error::from).unwrap_or_else(Err), + res = &mut listener => res, res = services.poll() => handle_services_poll(server, res, listener).await, }; @@ -104,16 +110,12 @@ pub(crate) async fn stop(services: Arc) -> Result { } #[tracing::instrument(skip_all)] -async fn signal(server: Arc, tx: Sender<()>, handle: axum_server::Handle) { +async fn signal(server: Arc, handle: ServerHandle) { server.until_shutdown().await; - handle_shutdown(&server, &tx, &handle); + handle_shutdown(&server, &handle); } -fn handle_shutdown(server: &Arc, tx: &Sender<()>, handle: &axum_server::Handle) { - if let Err(e) = tx.send(()) { - error!("failed sending shutdown transaction to channel: {e}"); - } - +fn handle_shutdown(server: &Arc, handle: &ServerHandle) { let timeout = server.config.client_shutdown_timeout; let timeout = Duration::from_secs(timeout); debug!( @@ -128,7 +130,7 @@ fn handle_shutdown(server: &Arc, tx: &Sender<()>, handle: &axum_server:: async fn handle_services_poll( server: &Arc, result: Result, - listener: JoinHandle, + listener: impl Future, ) -> Result { debug!("Service manager finished: {result:?}"); diff --git a/src/router/serve/mod.rs b/src/router/serve/mod.rs index 49ebe583..9f4d6981 100644 --- a/src/router/serve/mod.rs +++ b/src/router/serve/mod.rs @@ -3,46 +3,75 @@ mod plain; mod tls; mod unix; -use std::sync::Arc; +use std::sync::{Arc, atomic::Ordering}; -use axum_server::Handle as ServerHandle; -use tokio::sync::broadcast; -use tuwunel_core::{Result, err}; +use tokio::task::JoinSet; +use tuwunel_core::{Result, debug_info}; use tuwunel_service::Services; use super::layers; +use crate::handle::ServerHandle; /// Serve clients -pub(super) async fn serve( - services: Arc, - handle: ServerHandle, - mut shutdown: broadcast::Receiver<()>, -) -> Result { +pub(super) async fn serve(services: Arc, handle: ServerHandle) -> Result { let server = &services.server; let config = &server.config; - if !config.listening { - return shutdown - .recv() - .await - .map_err(|e| err!(error!("channel error: {e}"))); + + let (app, _guard) = layers::build(&services)?; + + let mut join_set = JoinSet::new(); + + #[cfg(unix)] + if let Some(unix_socket) = &config.unix_socket_path { + let socket_perms = config.get_unix_socket_perms()?; + + unix::serve(server, &app, &handle.handle_unix, &mut join_set, unix_socket, socket_perms) + .await?; } let addrs = config.get_bind_addrs(); - let (app, _guard) = layers::build(&services)?; - if cfg!(unix) && config.unix_socket_path.is_some() { - unix::serve(server, app, shutdown).await - } else if config.tls.certs.is_some() { - #[cfg(feature = "direct_tls")] - services.globals.init_rustls_provider()?; - #[cfg(feature = "direct_tls")] - return tls::serve(server, app, handle, addrs).await; + if !addrs.is_empty() { + #[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))] + if config.tls.certs.is_some() { + #[cfg(feature = "direct_tls")] + { + services.globals.init_rustls_provider()?; + tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?; + } - #[cfg(not(feature = "direct_tls"))] - return tuwunel_core::Err!(Config( - "tls", - "tuwunel was not built with direct TLS support (\"direct_tls\")" - )); - } else { - plain::serve(server, app, handle, addrs).await + #[cfg(not(feature = "direct_tls"))] + return tuwunel_core::Err!(Config( + "tls", + "tuwunel was not built with direct TLS support (\"direct_tls\")" + )); + } else { + plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs); + } } + + assert!(!join_set.is_empty(), "at least one listener should be installed"); + + join_set.join_all().await; + + let handle_active = server + .metrics + .requests_handle_active + .load(Ordering::Acquire); + + debug_info!( + handle_finished = server + .metrics + .requests_handle_finished + .load(Ordering::Acquire), + panics = server + .metrics + .requests_panic + .load(Ordering::Acquire), + handle_active, + "Stopped listening on {addrs:?}", + ); + + debug_assert_eq!(0, handle_active, "active request handles still pending"); + + Ok(()) } diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index 96ed60ae..4ee8fcf1 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -1,50 +1,26 @@ -use std::{ - net::SocketAddr, - sync::{Arc, atomic::Ordering}, -}; +use std::{net::SocketAddr, sync::Arc}; use axum::Router; -use axum_server::{Handle as ServerHandle, bind}; +use axum_server::Handle; use tokio::task::JoinSet; -use tuwunel_core::{Result, Server, debug_info, info}; +use tuwunel_core::{Server, info}; -pub(super) async fn serve( +pub(super) fn serve( server: &Arc, - router: Router, - handle: ServerHandle, - addrs: Vec, -) -> Result { - let mut join_set = JoinSet::new(); - let router = router.into_make_service_with_connect_info::(); - for addr in &addrs { - let bound = bind(*addr); - let handler = bound.handle(handle.clone()); - let acceptor = handler.serve(router.clone()); + router: &Router, + handle: &Handle, + join_set: &mut JoinSet>, + addrs: &[SocketAddr], +) { + let router = router + .clone() + .into_make_service_with_connect_info::(); + for addr in addrs { + let acceptor = axum_server::bind(*addr) + .handle(handle.clone()) + .serve(router.clone()); join_set.spawn_on(acceptor, server.runtime()); } info!("Listening on {addrs:?}"); - while join_set.join_next().await.is_some() {} - - let handle_active = server - .metrics - .requests_handle_active - .load(Ordering::Acquire); - - debug_info!( - handle_finished = server - .metrics - .requests_handle_finished - .load(Ordering::Acquire), - panics = server - .metrics - .requests_panic - .load(Ordering::Acquire), - handle_active, - "Stopped listening on {addrs:?}", - ); - - debug_assert_eq!(0, handle_active, "active request handles still pending"); - - Ok(()) } diff --git a/src/router/serve/tls.rs b/src/router/serve/tls.rs index d1c2c986..09ecce31 100644 --- a/src/router/serve/tls.rs +++ b/src/router/serve/tls.rs @@ -1,31 +1,21 @@ -use std::{ - net::SocketAddr, - sync::{Arc, atomic::Ordering}, -}; +use std::{net::SocketAddr, sync::Arc}; use axum::Router; -use axum_server::Handle as ServerHandle; -use axum_server_dual_protocol::{ - ServerExt, - axum_server::{bind_rustls, tls_rustls::RustlsConfig}, -}; +use axum_server::Handle; +use axum_server_dual_protocol::{ServerExt, axum_server::tls_rustls::RustlsConfig}; use tokio::task::JoinSet; -use tuwunel_core::{Result, Server, debug, debug_info, err, info, warn}; +use tuwunel_core::{Result, Server, debug, err, info, warn}; pub(super) async fn serve( server: &Arc, - app: Router, - handle: ServerHandle, - addrs: Vec, + app: &Router, + handle: &Handle, + join_set: &mut JoinSet>, + addrs: &[SocketAddr], ) -> Result { let tls = &server.config.tls; - let certs = tls.certs.as_ref().ok_or_else(|| { - err!(Config("tls.certs", "Missing required value in tls config section")) - })?; - let key = tls - .key - .as_ref() - .ok_or_else(|| err!(Config("tls.key", "Missing required value in tls config section")))?; + let certs = tls.certs.as_ref().unwrap(); + let key = tls.key.as_ref().unwrap(); info!( "Note: It is strongly recommended that you use a reverse proxy instead of running \ @@ -36,10 +26,11 @@ pub(super) async fn serve( .await .map_err(|e| err!(Config("tls", "Failed to load certificates or key: {e}")))?; - let mut join_set = JoinSet::new(); - let app = app.into_make_service_with_connect_info::(); + let app = app + .clone() + .into_make_service_with_connect_info::(); if tls.dual_protocol { - for addr in &addrs { + for addr in addrs { join_set.spawn_on( axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) .set_upgrade(false) @@ -48,47 +39,23 @@ pub(super) async fn serve( server.runtime(), ); } - } else { - for addr in &addrs { - join_set.spawn_on( - bind_rustls(*addr, conf.clone()) - .handle(handle.clone()) - .serve(app.clone()), - server.runtime(), - ); - } - } - if tls.dual_protocol { warn!( "Listening on {addrs:?} with TLS certificate {certs} and supporting plain text \ (HTTP) connections too (insecure!)", ); } else { + for addr in addrs { + join_set.spawn_on( + axum_server::bind_rustls(*addr, conf.clone()) + .handle(handle.clone()) + .serve(app.clone()), + server.runtime(), + ); + } + info!("Listening on {addrs:?} with TLS certificate {certs}"); } - while join_set.join_next().await.is_some() {} - - let handle_active = server - .metrics - .requests_handle_active - .load(Ordering::Acquire); - - debug_info!( - handle_finished = server - .metrics - .requests_handle_finished - .load(Ordering::Acquire), - panics = server - .metrics - .requests_panic - .load(Ordering::Acquire), - handle_active, - "Stopped listening on {addrs:?}", - ); - - debug_assert_eq!(0, handle_active, "active request handles still pending"); - Ok(()) } diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index f6e12783..e4c489f4 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -1,185 +1,48 @@ #![cfg(unix)] use std::{ - net::{self, IpAddr, Ipv4Addr}, - os::fd::AsRawFd, - path::Path, - sync::{Arc, atomic::Ordering}, -}; - -use axum::{ - Router, - extract::{Request, connect_info::IntoMakeServiceWithConnectInfo}, -}; -use hyper::{body::Incoming, service::service_fn}; -use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server, -}; -use tokio::{ fs, - net::{UnixListener, UnixStream, unix::SocketAddr}, - sync::broadcast::{self}, - task::JoinSet, - time::{Duration, sleep}, -}; -use tower::{Service, ServiceExt}; -use tuwunel_core::{ - Err, Result, Server, debug, debug_error, info, result::UnwrapInfallible, trace, warn, + net::SocketAddr, + os::unix::{fs::PermissionsExt, net::UnixListener}, + path::Path, + sync::Arc, }; -type MakeService = IntoMakeServiceWithConnectInfo; - -const NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); -const FINI_POLL_INTERVAL: Duration = Duration::from_millis(750); +use axum::{Extension, Router, extract::ConnectInfo}; +use axum_server::Handle; +use tokio::task::JoinSet; +use tuwunel_core::{Result, Server, info, warn}; #[tracing::instrument(skip_all, level = "debug")] pub(super) async fn serve( server: &Arc, - app: Router, - mut shutdown: broadcast::Receiver<()>, + router: &Router, + handle: &Handle, + join_set: &mut JoinSet>, + path: &Path, + socket_perms: u32, ) -> Result { - let mut tasks = JoinSet::<()>::new(); - let executor = TokioExecutor::new(); - let app = app.into_make_service_with_connect_info::(); - let builder = server::conn::auto::Builder::new(executor); - let listener = init(server).await?; - while server.running() { - let app = app.clone(); - let builder = builder.clone(); - tokio::select! { - _sig = shutdown.recv() => break, - conn = listener.accept() => match conn { - Ok(conn) => accept(server, &listener, &mut tasks, app, builder, conn).await, - Err(err) => debug_error!(?listener, "accept error: {err}"), - }, - } - } - - fini(server, listener, tasks).await; - - Ok(()) -} - -#[tracing::instrument( - level = "trace", - skip_all, - fields( - ?listener, - socket = ?conn.0, - ), -)] -async fn accept( - server: &Arc, - listener: &UnixListener, - tasks: &mut JoinSet<()>, - app: MakeService, - builder: server::conn::auto::Builder, - conn: (UnixStream, SocketAddr), -) { - let (socket, _) = conn; - let server_ = server.clone(); - let task = async move { accepted(server_, builder, socket, app).await }; - - _ = tasks.spawn_on(task, server.runtime()); - while tasks.try_join_next().is_some() {} -} - -#[tracing::instrument( - level = "trace", - skip_all, - fields( - fd = %socket.as_raw_fd(), - path = ?socket.local_addr(), - ), -)] -async fn accepted( - server: Arc, - builder: server::conn::auto::Builder, - socket: UnixStream, - mut app: MakeService, -) { - let socket = TokioIo::new(socket); - let called = app.call(NULL_ADDR).await.unwrap_infallible(); - let service = move |req: Request| called.clone().oneshot(req); - let handler = service_fn(service); - trace!(?socket, ?handler, "serving connection"); - - // bug on darwin causes all results to be errors. do not unwrap this - tokio::select! { - () = server.until_shutdown() => (), - _ = builder.serve_connection(socket, handler) => (), - }; -} - -async fn init(server: &Arc) -> Result { - use std::os::unix::fs::PermissionsExt; - - let config = &server.config; - let path = config - .unix_socket_path - .as_ref() - .expect("failed to extract configured unix socket path"); - if path.exists() { - warn!("Removing existing UNIX socket {:#?} (unclean shutdown?)...", path.display()); - fs::remove_file(&path) - .await - .map_err(|e| warn!("Failed to remove existing UNIX socket: {e}")) - .unwrap(); + warn!("Removing existing UNIX socket {path:?} (unclean shutdown?)..."); + fs::remove_file(path)?; } - let dir = path.parent().unwrap_or_else(|| Path::new("/")); - if let Err(e) = fs::create_dir_all(dir).await { - return Err!("Failed to create {dir:?} for socket {path:?}: {e}"); - } + let unix_listener = UnixListener::bind(path)?; + unix_listener.set_nonblocking(true)?; - let listener = UnixListener::bind(path); - if let Err(e) = listener { - return Err!("Failed to bind listener {path:?}: {e}"); - } + let perms = fs::Permissions::from_mode(socket_perms); + fs::set_permissions(path, perms)?; - let socket_perms = config.unix_socket_perms.to_string(); - let octal_perms = - u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions"); - let perms = std::fs::Permissions::from_mode(octal_perms); - if let Err(e) = fs::set_permissions(&path, perms).await { - return Err!("Failed to set socket {path:?} permissions: {e}"); - } + let router = router + .clone() + .layer(Extension(ConnectInfo("0.0.0.0".parse::()))) + .into_make_service(); + let acceptor = axum_server::from_unix(unix_listener)? + .handle(handle.clone()) + .serve(router); + join_set.spawn_on(acceptor, server.runtime()); info!("Listening at {path:?}"); - Ok(listener.unwrap()) -} - -async fn fini(server: &Arc, listener: UnixListener, mut tasks: JoinSet<()>) { - let local = listener.local_addr(); - - debug!("Closing listener at {local:?} ..."); - drop(listener); - - debug!("Waiting for requests to finish..."); - while server - .metrics - .requests_handle_active - .load(Ordering::Acquire) - .gt(&0) - { - tokio::select! { - task = tasks.join_next() => if task.is_none() { break; }, - () = sleep(FINI_POLL_INTERVAL) => {}, - } - } - - debug!("Shutting down..."); - tasks.shutdown().await; - - if let Ok(local) = local - && let Some(path) = local.as_pathname() - { - debug!(?path, "Removing unix socket file."); - if let Err(e) = fs::remove_file(path).await { - warn!(?path, "Failed to remove UNIX socket file: {e}"); - } - } + Ok(()) } diff --git a/tuwunel-example.toml b/tuwunel-example.toml index 6cffba53..ebce64df 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -74,9 +74,6 @@ # The UNIX socket tuwunel will listen on. # -# tuwunel cannot listen on both an IP address and a UNIX socket. If -# listening on a UNIX socket, you MUST remove/comment the `address` key. -# # Remember to make sure that your reverse proxy has access to this socket # file, either by adding your reverse proxy to the 'tuwunel' group or # granting world R/W permissions with `unix_socket_perms` (666 minimum).