Update axum-server to 0.8; switch to axum unix listener.

This commit is contained in:
dasha_uwu
2026-02-03 02:30:50 +05:00
committed by Jason Volk
parent 87faf818ff
commit bd5203b406
13 changed files with 231 additions and 390 deletions

21
Cargo.lock generated
View File

@@ -402,12 +402,13 @@ dependencies = [
[[package]]
name = "axum-server"
version = "0.7.3"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1ab4a3ec9ea8a657c72d99a03a824af695bd0fb5ec639ccbd9cd3543b41a5f9"
checksum = "b1df331683d982a0b9492b38127151e6453639cd34926eb9c07d4cd8c6d22bfc"
dependencies = [
"arc-swap",
"bytes",
"either",
"fs-err",
"http",
"http-body",
@@ -415,7 +416,6 @@ dependencies = [
"hyper-util",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"rustls-pki-types",
"tokio",
"tokio-rustls",
@@ -424,9 +424,8 @@ dependencies = [
[[package]]
name = "axum-server-dual-protocol"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2164551db024e87f20316d164eab9f5ad342d8188b08051ceb15ca92a60ea7b7"
version = "0.8.0"
source = "git+https://github.com/matrix-construct/axum-server-dual-protocol?rev=76c782fa6f129f83ffdca59e903093e798d4a82f#76c782fa6f129f83ffdca59e903093e798d4a82f"
dependencies = [
"axum-server",
"bytes",
@@ -4049,15 +4048,6 @@ dependencies = [
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "rustls-pki-types"
version = "1.14.0"
@@ -5393,7 +5383,6 @@ dependencies = [
"http",
"http-body-util",
"hyper",
"hyper-util",
"log",
"ruma",
"rustls",

View File

@@ -82,11 +82,12 @@ features = [
]
[workspace.dependencies.axum-server]
version = "0.7"
version = "0.8"
default-features = false
[workspace.dependencies.axum-server-dual-protocol]
version = "0.7"
git = "https://github.com/matrix-construct/axum-server-dual-protocol"
rev = "76c782fa6f129f83ffdca59e903093e798d4a82f"
default-features = false
[workspace.dependencies.base64]
@@ -201,15 +202,6 @@ features = [
"http2",
]
[workspace.dependencies.hyper-util]
version = "0.1.20"
default-features = false
features = [
"server-auto",
"server-graceful",
"tokio",
]
[workspace.dependencies.image]
version = "0.25"
default-features = false

View File

@@ -1,11 +1,10 @@
use std::env::consts::OS;
use either::Either;
use figment::Figment;
use itertools::Itertools;
use super::{DEPRECATED_KEYS, IdentityProvider};
use crate::{Config, Err, Result, Server, debug, debug_info, debug_warn, error, warn};
use crate::{Config, Err, Result, Server, debug, debug_info, error, warn};
/// Performs check() with additional checks specific to reloading old config
/// with new config.
@@ -24,9 +23,8 @@ pub fn reload(old: &Config, new: &Config) -> Result {
}
pub fn check(config: &Config) -> Result {
if cfg!(debug_assertions) {
#[cfg(debug_assertions)]
warn!("Note: tuwunel was built without optimisations (i.e. debug build)");
}
warn_deprecated(config);
warn_unknown_key(config)?;
@@ -38,14 +36,18 @@ pub fn check(config: &Config) -> Result {
));
}
if cfg!(all(feature = "hardened_malloc", feature = "jemalloc", not(target_env = "msvc"))) {
#[cfg(all(
feature = "hardened_malloc",
feature = "jemalloc",
not(target_env = "msvc")
))]
debug_warn!(
"hardened_malloc and jemalloc compile-time features are both enabled, this causes \
jemalloc to be used."
);
}
if cfg!(not(unix)) && config.unix_socket_path.is_some() {
#[cfg(not(unix))]
if config.unix_socket_path.is_some() {
return Err!(Config(
"unix_socket_path",
"UNIX socket support is only available on *nix platforms. Please remove \
@@ -53,13 +55,21 @@ pub fn check(config: &Config) -> Result {
));
}
if config.unix_socket_path.is_none() && config.get_bind_hosts().is_empty() {
if config.unix_socket_path.is_none() {
if config.get_bind_hosts().is_empty() {
return Err!(Config("address", "No TCP addresses were specified to listen on"));
}
if config.unix_socket_path.is_none() && config.get_bind_ports().is_empty() {
if config.get_bind_ports().is_empty() {
return Err!(Config("port", "No ports were specified to listen on"));
}
}
let certs_set = config.tls.certs.is_some();
let key_set = config.tls.key.is_some();
if certs_set ^ key_set {
return Err!(Config("tls", "tls.certs and tls.key must either both be set or unset"));
}
if !config.listening {
warn!("Configuration item `listening` is set to `false`. Cannot hear anyone.");
@@ -115,7 +125,8 @@ pub fn check(config: &Config) -> Result {
}
// yeah, unless the user built a debug build hopefully for local testing only
if cfg!(not(debug_assertions)) && config.server_name == "your.server.name" {
#[cfg(not(debug_assertions))]
if config.server_name == "your.server.name" {
return Err!(Config(
"server_name",
"You must specify a valid server name for production usage of tuwunel."
@@ -412,18 +423,3 @@ fn warn_unknown_key(config: &Config) -> Result {
Ok(())
}
}
/// Checks the presence of the `address` and `unix_socket_path` keys in the
/// raw_config, exiting the process if both keys were detected.
pub(super) fn is_dual_listening(raw_config: &Figment) -> Result {
let contains_address = raw_config.contains("address");
let contains_unix_socket = raw_config.contains("unix_socket_path");
if contains_address && contains_unix_socket {
return Err!(
"TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify \
only one option."
);
}
Ok(())
}

View File

@@ -105,8 +105,8 @@ pub struct Config {
/// "::1"]
///
/// default: ["127.0.0.1", "::1"]
#[serde(default = "default_address")]
address: ListeningAddr,
#[serde(default)]
address: Option<ListeningAddr>,
/// The port(s) tuwunel will listen on.
///
@@ -128,9 +128,6 @@ pub struct Config {
/// The UNIX socket tuwunel will listen on.
///
/// tuwunel cannot listen on both an IP address and a UNIX socket. If
/// listening on a UNIX socket, you MUST remove/comment the `address` key.
///
/// Remember to make sure that your reverse proxy has access to this socket
/// file, either by adding your reverse proxy to the 'tuwunel' group or
/// granting world R/W permissions with `unix_socket_perms` (666 minimum).
@@ -3010,18 +3007,24 @@ impl Config {
.extract::<Self>()
.map_err(|e| err!("There was a problem with your configuration file: {e}"))?;
// don't start if we're listening on both UNIX sockets and TCP at same time
check::is_dual_listening(raw_config)?;
Ok(config)
}
pub fn get_unix_socket_perms(&self) -> Result<u32> {
let octal_perms = self.unix_socket_perms.to_string();
let socket_perms = u32::from_str_radix(&octal_perms, 8).map_err(|_| {
err!(Config("unix_socket_perms", "failed to convert octal permissions"))
})?;
Ok(socket_perms)
}
#[must_use]
pub fn get_bind_addrs(&self) -> Vec<SocketAddr> {
let mut addrs = Vec::with_capacity(
self.get_bind_hosts()
.len()
.saturating_add(self.get_bind_ports().len()),
.saturating_mul(self.get_bind_ports().len()),
);
for host in &self.get_bind_hosts() {
for port in &self.get_bind_ports() {
@@ -3033,10 +3036,16 @@ impl Config {
}
fn get_bind_hosts(&self) -> Vec<IpAddr> {
match &self.address.addrs {
if let Some(address) = &self.address {
match &address.addrs {
| Left(addr) => vec![*addr],
| Right(addrs) => addrs.clone(),
}
} else if self.unix_socket_path.is_some() {
vec![]
} else {
vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]
}
}
fn get_bind_ports(&self) -> Vec<u16> {
@@ -3056,12 +3065,6 @@ fn default_server_name() -> OwnedServerName { ruma::owned_server_name!("localhos
fn default_database_path() -> PathBuf { "/var/lib/tuwunel".to_owned().into() }
fn default_address() -> ListeningAddr {
ListeningAddr {
addrs: Right(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]),
}
}
fn default_port() -> ListeningPort { ListeningPort { ports: Left(8008) } }
fn default_unix_socket_perms() -> u32 { 660 }

View File

@@ -115,7 +115,6 @@ futures.workspace = true
http.workspace = true
http-body-util.workspace = true
hyper.workspace = true
hyper-util.workspace = true
log.workspace = true
ruma.workspace = true
rustls.workspace = true

27
src/router/handle.rs Normal file
View File

@@ -0,0 +1,27 @@
use std::time::Duration;
use axum_server::Handle;
#[derive(Clone)]
pub(crate) struct ServerHandle {
pub(crate) handle_ip: Handle<std::net::SocketAddr>,
#[cfg(unix)]
pub(crate) handle_unix: Handle<std::os::unix::net::SocketAddr>,
}
impl ServerHandle {
pub(crate) fn new() -> Self {
Self {
handle_ip: Handle::<std::net::SocketAddr>::new(),
#[cfg(unix)]
handle_unix: Handle::<std::os::unix::net::SocketAddr>::new(),
}
}
pub(crate) fn graceful_shutdown(&self, duration: Option<Duration>) {
self.handle_ip.graceful_shutdown(duration);
#[cfg(unix)]
self.handle_unix.graceful_shutdown(duration);
}
}

View File

@@ -1,6 +1,7 @@
#![type_length_limit = "32768"] //TODO: reduce me
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91
mod handle;
mod layers;
mod request;
mod router;

View File

@@ -3,15 +3,11 @@ use std::{
time::Duration,
};
use axum_server::Handle as ServerHandle;
use tokio::{
sync::broadcast::{self, Sender},
task::JoinHandle,
};
use futures::FutureExt;
use tuwunel_core::{Error, Result, Server, debug, debug_error, debug_info, error, info};
use tuwunel_service::Services;
use crate::serve;
use crate::{handle::ServerHandle, serve};
/// Main loop base
#[tracing::instrument(skip_all)]
@@ -24,20 +20,30 @@ pub(crate) async fn run(services: Arc<Services>) -> Result {
// Setup shutdown/signal handling
let handle = ServerHandle::new();
let (tx, _) = broadcast::channel::<()>(1);
let sigs = server
.runtime()
.spawn(signal(server.clone(), tx.clone(), handle.clone()));
.spawn(signal(server.clone(), handle.clone()));
let mut listener =
let mut listener = if services.config.listening {
let future = serve::serve(services.clone(), handle);
server
.runtime()
.spawn(serve::serve(services.clone(), handle.clone(), tx.subscribe()));
.spawn(future)
.map(|res| res.map_err(Error::from).unwrap_or_else(Err))
.boxed()
} else {
let server = server.clone();
async move {
server.until_shutdown().await;
Ok(())
}
.boxed()
};
// Focal point
debug!("Running");
let res = tokio::select! {
res = &mut listener => res.map_err(Error::from).unwrap_or_else(Err),
res = &mut listener => res,
res = services.poll() => handle_services_poll(server, res, listener).await,
};
@@ -104,16 +110,12 @@ pub(crate) async fn stop(services: Arc<Services>) -> Result {
}
#[tracing::instrument(skip_all)]
async fn signal(server: Arc<Server>, tx: Sender<()>, handle: axum_server::Handle) {
async fn signal(server: Arc<Server>, handle: ServerHandle) {
server.until_shutdown().await;
handle_shutdown(&server, &tx, &handle);
}
fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_server::Handle) {
if let Err(e) = tx.send(()) {
error!("failed sending shutdown transaction to channel: {e}");
handle_shutdown(&server, &handle);
}
fn handle_shutdown(server: &Arc<Server>, handle: &ServerHandle) {
let timeout = server.config.client_shutdown_timeout;
let timeout = Duration::from_secs(timeout);
debug!(
@@ -128,7 +130,7 @@ fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_server::
async fn handle_services_poll(
server: &Arc<Server>,
result: Result,
listener: JoinHandle<Result>,
listener: impl Future<Output = Result>,
) -> Result {
debug!("Service manager finished: {result:?}");

View File

@@ -3,39 +3,41 @@ mod plain;
mod tls;
mod unix;
use std::sync::Arc;
use std::sync::{Arc, atomic::Ordering};
use axum_server::Handle as ServerHandle;
use tokio::sync::broadcast;
use tuwunel_core::{Result, err};
use tokio::task::JoinSet;
use tuwunel_core::{Result, debug_info};
use tuwunel_service::Services;
use super::layers;
use crate::handle::ServerHandle;
/// Serve clients
pub(super) async fn serve(
services: Arc<Services>,
handle: ServerHandle,
mut shutdown: broadcast::Receiver<()>,
) -> Result {
pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Result {
let server = &services.server;
let config = &server.config;
if !config.listening {
return shutdown
.recv()
.await
.map_err(|e| err!(error!("channel error: {e}")));
let (app, _guard) = layers::build(&services)?;
let mut join_set = JoinSet::new();
#[cfg(unix)]
if let Some(unix_socket) = &config.unix_socket_path {
let socket_perms = config.get_unix_socket_perms()?;
unix::serve(server, &app, &handle.handle_unix, &mut join_set, unix_socket, socket_perms)
.await?;
}
let addrs = config.get_bind_addrs();
let (app, _guard) = layers::build(&services)?;
if cfg!(unix) && config.unix_socket_path.is_some() {
unix::serve(server, app, shutdown).await
} else if config.tls.certs.is_some() {
if !addrs.is_empty() {
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
if config.tls.certs.is_some() {
#[cfg(feature = "direct_tls")]
{
services.globals.init_rustls_provider()?;
#[cfg(feature = "direct_tls")]
return tls::serve(server, app, handle, addrs).await;
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?;
}
#[cfg(not(feature = "direct_tls"))]
return tuwunel_core::Err!(Config(
@@ -43,6 +45,33 @@ pub(super) async fn serve(
"tuwunel was not built with direct TLS support (\"direct_tls\")"
));
} else {
plain::serve(server, app, handle, addrs).await
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs);
}
}
assert!(!join_set.is_empty(), "at least one listener should be installed");
join_set.join_all().await;
let handle_active = server
.metrics
.requests_handle_active
.load(Ordering::Acquire);
debug_info!(
handle_finished = server
.metrics
.requests_handle_finished
.load(Ordering::Acquire),
panics = server
.metrics
.requests_panic
.load(Ordering::Acquire),
handle_active,
"Stopped listening on {addrs:?}",
);
debug_assert_eq!(0, handle_active, "active request handles still pending");
Ok(())
}

View File

@@ -1,50 +1,26 @@
use std::{
net::SocketAddr,
sync::{Arc, atomic::Ordering},
};
use std::{net::SocketAddr, sync::Arc};
use axum::Router;
use axum_server::{Handle as ServerHandle, bind};
use axum_server::Handle;
use tokio::task::JoinSet;
use tuwunel_core::{Result, Server, debug_info, info};
use tuwunel_core::{Server, info};
pub(super) async fn serve(
pub(super) fn serve(
server: &Arc<Server>,
router: Router,
handle: ServerHandle,
addrs: Vec<SocketAddr>,
) -> Result {
let mut join_set = JoinSet::new();
let router = router.into_make_service_with_connect_info::<SocketAddr>();
for addr in &addrs {
let bound = bind(*addr);
let handler = bound.handle(handle.clone());
let acceptor = handler.serve(router.clone());
router: &Router,
handle: &Handle<SocketAddr>,
join_set: &mut JoinSet<Result<(), std::io::Error>>,
addrs: &[SocketAddr],
) {
let router = router
.clone()
.into_make_service_with_connect_info::<SocketAddr>();
for addr in addrs {
let acceptor = axum_server::bind(*addr)
.handle(handle.clone())
.serve(router.clone());
join_set.spawn_on(acceptor, server.runtime());
}
info!("Listening on {addrs:?}");
while join_set.join_next().await.is_some() {}
let handle_active = server
.metrics
.requests_handle_active
.load(Ordering::Acquire);
debug_info!(
handle_finished = server
.metrics
.requests_handle_finished
.load(Ordering::Acquire),
panics = server
.metrics
.requests_panic
.load(Ordering::Acquire),
handle_active,
"Stopped listening on {addrs:?}",
);
debug_assert_eq!(0, handle_active, "active request handles still pending");
Ok(())
}

View File

@@ -1,31 +1,21 @@
use std::{
net::SocketAddr,
sync::{Arc, atomic::Ordering},
};
use std::{net::SocketAddr, sync::Arc};
use axum::Router;
use axum_server::Handle as ServerHandle;
use axum_server_dual_protocol::{
ServerExt,
axum_server::{bind_rustls, tls_rustls::RustlsConfig},
};
use axum_server::Handle;
use axum_server_dual_protocol::{ServerExt, axum_server::tls_rustls::RustlsConfig};
use tokio::task::JoinSet;
use tuwunel_core::{Result, Server, debug, debug_info, err, info, warn};
use tuwunel_core::{Result, Server, debug, err, info, warn};
pub(super) async fn serve(
server: &Arc<Server>,
app: Router,
handle: ServerHandle,
addrs: Vec<SocketAddr>,
app: &Router,
handle: &Handle<SocketAddr>,
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
addrs: &[SocketAddr],
) -> Result {
let tls = &server.config.tls;
let certs = tls.certs.as_ref().ok_or_else(|| {
err!(Config("tls.certs", "Missing required value in tls config section"))
})?;
let key = tls
.key
.as_ref()
.ok_or_else(|| err!(Config("tls.key", "Missing required value in tls config section")))?;
let certs = tls.certs.as_ref().unwrap();
let key = tls.key.as_ref().unwrap();
info!(
"Note: It is strongly recommended that you use a reverse proxy instead of running \
@@ -36,10 +26,11 @@ pub(super) async fn serve(
.await
.map_err(|e| err!(Config("tls", "Failed to load certificates or key: {e}")))?;
let mut join_set = JoinSet::new();
let app = app.into_make_service_with_connect_info::<SocketAddr>();
let app = app
.clone()
.into_make_service_with_connect_info::<SocketAddr>();
if tls.dual_protocol {
for addr in &addrs {
for addr in addrs {
join_set.spawn_on(
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
.set_upgrade(false)
@@ -48,47 +39,23 @@ pub(super) async fn serve(
server.runtime(),
);
}
} else {
for addr in &addrs {
join_set.spawn_on(
bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
server.runtime(),
);
}
}
if tls.dual_protocol {
warn!(
"Listening on {addrs:?} with TLS certificate {certs} and supporting plain text \
(HTTP) connections too (insecure!)",
);
} else {
for addr in addrs {
join_set.spawn_on(
axum_server::bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
server.runtime(),
);
}
info!("Listening on {addrs:?} with TLS certificate {certs}");
}
while join_set.join_next().await.is_some() {}
let handle_active = server
.metrics
.requests_handle_active
.load(Ordering::Acquire);
debug_info!(
handle_finished = server
.metrics
.requests_handle_finished
.load(Ordering::Acquire),
panics = server
.metrics
.requests_panic
.load(Ordering::Acquire),
handle_active,
"Stopped listening on {addrs:?}",
);
debug_assert_eq!(0, handle_active, "active request handles still pending");
Ok(())
}

View File

@@ -1,185 +1,48 @@
#![cfg(unix)]
use std::{
net::{self, IpAddr, Ipv4Addr},
os::fd::AsRawFd,
path::Path,
sync::{Arc, atomic::Ordering},
};
use axum::{
Router,
extract::{Request, connect_info::IntoMakeServiceWithConnectInfo},
};
use hyper::{body::Incoming, service::service_fn};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server,
};
use tokio::{
fs,
net::{UnixListener, UnixStream, unix::SocketAddr},
sync::broadcast::{self},
task::JoinSet,
time::{Duration, sleep},
};
use tower::{Service, ServiceExt};
use tuwunel_core::{
Err, Result, Server, debug, debug_error, info, result::UnwrapInfallible, trace, warn,
net::SocketAddr,
os::unix::{fs::PermissionsExt, net::UnixListener},
path::Path,
sync::Arc,
};
type MakeService = IntoMakeServiceWithConnectInfo<Router, net::SocketAddr>;
const NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
const FINI_POLL_INTERVAL: Duration = Duration::from_millis(750);
use axum::{Extension, Router, extract::ConnectInfo};
use axum_server::Handle;
use tokio::task::JoinSet;
use tuwunel_core::{Result, Server, info, warn};
#[tracing::instrument(skip_all, level = "debug")]
pub(super) async fn serve(
server: &Arc<Server>,
app: Router,
mut shutdown: broadcast::Receiver<()>,
router: &Router,
handle: &Handle<std::os::unix::net::SocketAddr>,
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
path: &Path,
socket_perms: u32,
) -> Result {
let mut tasks = JoinSet::<()>::new();
let executor = TokioExecutor::new();
let app = app.into_make_service_with_connect_info::<net::SocketAddr>();
let builder = server::conn::auto::Builder::new(executor);
let listener = init(server).await?;
while server.running() {
let app = app.clone();
let builder = builder.clone();
tokio::select! {
_sig = shutdown.recv() => break,
conn = listener.accept() => match conn {
Ok(conn) => accept(server, &listener, &mut tasks, app, builder, conn).await,
Err(err) => debug_error!(?listener, "accept error: {err}"),
},
}
}
fini(server, listener, tasks).await;
Ok(())
}
#[tracing::instrument(
level = "trace",
skip_all,
fields(
?listener,
socket = ?conn.0,
),
)]
async fn accept(
server: &Arc<Server>,
listener: &UnixListener,
tasks: &mut JoinSet<()>,
app: MakeService,
builder: server::conn::auto::Builder<TokioExecutor>,
conn: (UnixStream, SocketAddr),
) {
let (socket, _) = conn;
let server_ = server.clone();
let task = async move { accepted(server_, builder, socket, app).await };
_ = tasks.spawn_on(task, server.runtime());
while tasks.try_join_next().is_some() {}
}
#[tracing::instrument(
level = "trace",
skip_all,
fields(
fd = %socket.as_raw_fd(),
path = ?socket.local_addr(),
),
)]
async fn accepted(
server: Arc<Server>,
builder: server::conn::auto::Builder<TokioExecutor>,
socket: UnixStream,
mut app: MakeService,
) {
let socket = TokioIo::new(socket);
let called = app.call(NULL_ADDR).await.unwrap_infallible();
let service = move |req: Request<Incoming>| called.clone().oneshot(req);
let handler = service_fn(service);
trace!(?socket, ?handler, "serving connection");
// bug on darwin causes all results to be errors. do not unwrap this
tokio::select! {
() = server.until_shutdown() => (),
_ = builder.serve_connection(socket, handler) => (),
};
}
async fn init(server: &Arc<Server>) -> Result<UnixListener> {
use std::os::unix::fs::PermissionsExt;
let config = &server.config;
let path = config
.unix_socket_path
.as_ref()
.expect("failed to extract configured unix socket path");
if path.exists() {
warn!("Removing existing UNIX socket {:#?} (unclean shutdown?)...", path.display());
fs::remove_file(&path)
.await
.map_err(|e| warn!("Failed to remove existing UNIX socket: {e}"))
.unwrap();
warn!("Removing existing UNIX socket {path:?} (unclean shutdown?)...");
fs::remove_file(path)?;
}
let dir = path.parent().unwrap_or_else(|| Path::new("/"));
if let Err(e) = fs::create_dir_all(dir).await {
return Err!("Failed to create {dir:?} for socket {path:?}: {e}");
}
let unix_listener = UnixListener::bind(path)?;
unix_listener.set_nonblocking(true)?;
let listener = UnixListener::bind(path);
if let Err(e) = listener {
return Err!("Failed to bind listener {path:?}: {e}");
}
let perms = fs::Permissions::from_mode(socket_perms);
fs::set_permissions(path, perms)?;
let socket_perms = config.unix_socket_perms.to_string();
let octal_perms =
u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions");
let perms = std::fs::Permissions::from_mode(octal_perms);
if let Err(e) = fs::set_permissions(&path, perms).await {
return Err!("Failed to set socket {path:?} permissions: {e}");
}
let router = router
.clone()
.layer(Extension(ConnectInfo("0.0.0.0".parse::<SocketAddr>())))
.into_make_service();
let acceptor = axum_server::from_unix(unix_listener)?
.handle(handle.clone())
.serve(router);
join_set.spawn_on(acceptor, server.runtime());
info!("Listening at {path:?}");
Ok(listener.unwrap())
}
async fn fini(server: &Arc<Server>, listener: UnixListener, mut tasks: JoinSet<()>) {
let local = listener.local_addr();
debug!("Closing listener at {local:?} ...");
drop(listener);
debug!("Waiting for requests to finish...");
while server
.metrics
.requests_handle_active
.load(Ordering::Acquire)
.gt(&0)
{
tokio::select! {
task = tasks.join_next() => if task.is_none() { break; },
() = sleep(FINI_POLL_INTERVAL) => {},
}
}
debug!("Shutting down...");
tasks.shutdown().await;
if let Ok(local) = local
&& let Some(path) = local.as_pathname()
{
debug!(?path, "Removing unix socket file.");
if let Err(e) = fs::remove_file(path).await {
warn!(?path, "Failed to remove UNIX socket file: {e}");
}
}
Ok(())
}

View File

@@ -74,9 +74,6 @@
# The UNIX socket tuwunel will listen on.
#
# tuwunel cannot listen on both an IP address and a UNIX socket. If
# listening on a UNIX socket, you MUST remove/comment the `address` key.
#
# Remember to make sure that your reverse proxy has access to this socket
# file, either by adding your reverse proxy to the 'tuwunel' group or
# granting world R/W permissions with `unix_socket_perms` (666 minimum).