Update axum-server to 0.8; switch to axum unix listener.
This commit is contained in:
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -402,12 +402,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "axum-server"
|
||||
version = "0.7.3"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1ab4a3ec9ea8a657c72d99a03a824af695bd0fb5ec639ccbd9cd3543b41a5f9"
|
||||
checksum = "b1df331683d982a0b9492b38127151e6453639cd34926eb9c07d4cd8c6d22bfc"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"bytes",
|
||||
"either",
|
||||
"fs-err",
|
||||
"http",
|
||||
"http-body",
|
||||
@@ -415,7 +416,6 @@ dependencies = [
|
||||
"hyper-util",
|
||||
"pin-project-lite",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
@@ -424,9 +424,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "axum-server-dual-protocol"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2164551db024e87f20316d164eab9f5ad342d8188b08051ceb15ca92a60ea7b7"
|
||||
version = "0.8.0"
|
||||
source = "git+https://github.com/matrix-construct/axum-server-dual-protocol?rev=76c782fa6f129f83ffdca59e903093e798d4a82f#76c782fa6f129f83ffdca59e903093e798d4a82f"
|
||||
dependencies = [
|
||||
"axum-server",
|
||||
"bytes",
|
||||
@@ -4049,15 +4048,6 @@ dependencies = [
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.14.0"
|
||||
@@ -5393,7 +5383,6 @@ dependencies = [
|
||||
"http",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"ruma",
|
||||
"rustls",
|
||||
|
||||
14
Cargo.toml
14
Cargo.toml
@@ -82,11 +82,12 @@ features = [
|
||||
]
|
||||
|
||||
[workspace.dependencies.axum-server]
|
||||
version = "0.7"
|
||||
version = "0.8"
|
||||
default-features = false
|
||||
|
||||
[workspace.dependencies.axum-server-dual-protocol]
|
||||
version = "0.7"
|
||||
git = "https://github.com/matrix-construct/axum-server-dual-protocol"
|
||||
rev = "76c782fa6f129f83ffdca59e903093e798d4a82f"
|
||||
default-features = false
|
||||
|
||||
[workspace.dependencies.base64]
|
||||
@@ -201,15 +202,6 @@ features = [
|
||||
"http2",
|
||||
]
|
||||
|
||||
[workspace.dependencies.hyper-util]
|
||||
version = "0.1.20"
|
||||
default-features = false
|
||||
features = [
|
||||
"server-auto",
|
||||
"server-graceful",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[workspace.dependencies.image]
|
||||
version = "0.25"
|
||||
default-features = false
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
use std::env::consts::OS;
|
||||
|
||||
use either::Either;
|
||||
use figment::Figment;
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::{DEPRECATED_KEYS, IdentityProvider};
|
||||
use crate::{Config, Err, Result, Server, debug, debug_info, debug_warn, error, warn};
|
||||
use crate::{Config, Err, Result, Server, debug, debug_info, error, warn};
|
||||
|
||||
/// Performs check() with additional checks specific to reloading old config
|
||||
/// with new config.
|
||||
@@ -24,9 +23,8 @@ pub fn reload(old: &Config, new: &Config) -> Result {
|
||||
}
|
||||
|
||||
pub fn check(config: &Config) -> Result {
|
||||
if cfg!(debug_assertions) {
|
||||
#[cfg(debug_assertions)]
|
||||
warn!("Note: tuwunel was built without optimisations (i.e. debug build)");
|
||||
}
|
||||
|
||||
warn_deprecated(config);
|
||||
warn_unknown_key(config)?;
|
||||
@@ -38,14 +36,18 @@ pub fn check(config: &Config) -> Result {
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(all(feature = "hardened_malloc", feature = "jemalloc", not(target_env = "msvc"))) {
|
||||
#[cfg(all(
|
||||
feature = "hardened_malloc",
|
||||
feature = "jemalloc",
|
||||
not(target_env = "msvc")
|
||||
))]
|
||||
debug_warn!(
|
||||
"hardened_malloc and jemalloc compile-time features are both enabled, this causes \
|
||||
jemalloc to be used."
|
||||
);
|
||||
}
|
||||
|
||||
if cfg!(not(unix)) && config.unix_socket_path.is_some() {
|
||||
#[cfg(not(unix))]
|
||||
if config.unix_socket_path.is_some() {
|
||||
return Err!(Config(
|
||||
"unix_socket_path",
|
||||
"UNIX socket support is only available on *nix platforms. Please remove \
|
||||
@@ -53,13 +55,21 @@ pub fn check(config: &Config) -> Result {
|
||||
));
|
||||
}
|
||||
|
||||
if config.unix_socket_path.is_none() && config.get_bind_hosts().is_empty() {
|
||||
if config.unix_socket_path.is_none() {
|
||||
if config.get_bind_hosts().is_empty() {
|
||||
return Err!(Config("address", "No TCP addresses were specified to listen on"));
|
||||
}
|
||||
|
||||
if config.unix_socket_path.is_none() && config.get_bind_ports().is_empty() {
|
||||
if config.get_bind_ports().is_empty() {
|
||||
return Err!(Config("port", "No ports were specified to listen on"));
|
||||
}
|
||||
}
|
||||
|
||||
let certs_set = config.tls.certs.is_some();
|
||||
let key_set = config.tls.key.is_some();
|
||||
if certs_set ^ key_set {
|
||||
return Err!(Config("tls", "tls.certs and tls.key must either both be set or unset"));
|
||||
}
|
||||
|
||||
if !config.listening {
|
||||
warn!("Configuration item `listening` is set to `false`. Cannot hear anyone.");
|
||||
@@ -115,7 +125,8 @@ pub fn check(config: &Config) -> Result {
|
||||
}
|
||||
|
||||
// yeah, unless the user built a debug build hopefully for local testing only
|
||||
if cfg!(not(debug_assertions)) && config.server_name == "your.server.name" {
|
||||
#[cfg(not(debug_assertions))]
|
||||
if config.server_name == "your.server.name" {
|
||||
return Err!(Config(
|
||||
"server_name",
|
||||
"You must specify a valid server name for production usage of tuwunel."
|
||||
@@ -412,18 +423,3 @@ fn warn_unknown_key(config: &Config) -> Result {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks the presence of the `address` and `unix_socket_path` keys in the
|
||||
/// raw_config, exiting the process if both keys were detected.
|
||||
pub(super) fn is_dual_listening(raw_config: &Figment) -> Result {
|
||||
let contains_address = raw_config.contains("address");
|
||||
let contains_unix_socket = raw_config.contains("unix_socket_path");
|
||||
if contains_address && contains_unix_socket {
|
||||
return Err!(
|
||||
"TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify \
|
||||
only one option."
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -105,8 +105,8 @@ pub struct Config {
|
||||
/// "::1"]
|
||||
///
|
||||
/// default: ["127.0.0.1", "::1"]
|
||||
#[serde(default = "default_address")]
|
||||
address: ListeningAddr,
|
||||
#[serde(default)]
|
||||
address: Option<ListeningAddr>,
|
||||
|
||||
/// The port(s) tuwunel will listen on.
|
||||
///
|
||||
@@ -128,9 +128,6 @@ pub struct Config {
|
||||
|
||||
/// The UNIX socket tuwunel will listen on.
|
||||
///
|
||||
/// tuwunel cannot listen on both an IP address and a UNIX socket. If
|
||||
/// listening on a UNIX socket, you MUST remove/comment the `address` key.
|
||||
///
|
||||
/// Remember to make sure that your reverse proxy has access to this socket
|
||||
/// file, either by adding your reverse proxy to the 'tuwunel' group or
|
||||
/// granting world R/W permissions with `unix_socket_perms` (666 minimum).
|
||||
@@ -3010,18 +3007,24 @@ impl Config {
|
||||
.extract::<Self>()
|
||||
.map_err(|e| err!("There was a problem with your configuration file: {e}"))?;
|
||||
|
||||
// don't start if we're listening on both UNIX sockets and TCP at same time
|
||||
check::is_dual_listening(raw_config)?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn get_unix_socket_perms(&self) -> Result<u32> {
|
||||
let octal_perms = self.unix_socket_perms.to_string();
|
||||
let socket_perms = u32::from_str_radix(&octal_perms, 8).map_err(|_| {
|
||||
err!(Config("unix_socket_perms", "failed to convert octal permissions"))
|
||||
})?;
|
||||
|
||||
Ok(socket_perms)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_bind_addrs(&self) -> Vec<SocketAddr> {
|
||||
let mut addrs = Vec::with_capacity(
|
||||
self.get_bind_hosts()
|
||||
.len()
|
||||
.saturating_add(self.get_bind_ports().len()),
|
||||
.saturating_mul(self.get_bind_ports().len()),
|
||||
);
|
||||
for host in &self.get_bind_hosts() {
|
||||
for port in &self.get_bind_ports() {
|
||||
@@ -3033,10 +3036,16 @@ impl Config {
|
||||
}
|
||||
|
||||
fn get_bind_hosts(&self) -> Vec<IpAddr> {
|
||||
match &self.address.addrs {
|
||||
if let Some(address) = &self.address {
|
||||
match &address.addrs {
|
||||
| Left(addr) => vec![*addr],
|
||||
| Right(addrs) => addrs.clone(),
|
||||
}
|
||||
} else if self.unix_socket_path.is_some() {
|
||||
vec![]
|
||||
} else {
|
||||
vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]
|
||||
}
|
||||
}
|
||||
|
||||
fn get_bind_ports(&self) -> Vec<u16> {
|
||||
@@ -3056,12 +3065,6 @@ fn default_server_name() -> OwnedServerName { ruma::owned_server_name!("localhos
|
||||
|
||||
fn default_database_path() -> PathBuf { "/var/lib/tuwunel".to_owned().into() }
|
||||
|
||||
fn default_address() -> ListeningAddr {
|
||||
ListeningAddr {
|
||||
addrs: Right(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]),
|
||||
}
|
||||
}
|
||||
|
||||
fn default_port() -> ListeningPort { ListeningPort { ports: Left(8008) } }
|
||||
|
||||
fn default_unix_socket_perms() -> u32 { 660 }
|
||||
|
||||
@@ -115,7 +115,6 @@ futures.workspace = true
|
||||
http.workspace = true
|
||||
http-body-util.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper-util.workspace = true
|
||||
log.workspace = true
|
||||
ruma.workspace = true
|
||||
rustls.workspace = true
|
||||
|
||||
27
src/router/handle.rs
Normal file
27
src/router/handle.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use axum_server::Handle;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ServerHandle {
|
||||
pub(crate) handle_ip: Handle<std::net::SocketAddr>,
|
||||
#[cfg(unix)]
|
||||
pub(crate) handle_unix: Handle<std::os::unix::net::SocketAddr>,
|
||||
}
|
||||
|
||||
impl ServerHandle {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
handle_ip: Handle::<std::net::SocketAddr>::new(),
|
||||
#[cfg(unix)]
|
||||
handle_unix: Handle::<std::os::unix::net::SocketAddr>::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn graceful_shutdown(&self, duration: Option<Duration>) {
|
||||
self.handle_ip.graceful_shutdown(duration);
|
||||
|
||||
#[cfg(unix)]
|
||||
self.handle_unix.graceful_shutdown(duration);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
#![type_length_limit = "32768"] //TODO: reduce me
|
||||
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91
|
||||
|
||||
mod handle;
|
||||
mod layers;
|
||||
mod request;
|
||||
mod router;
|
||||
|
||||
@@ -3,15 +3,11 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use axum_server::Handle as ServerHandle;
|
||||
use tokio::{
|
||||
sync::broadcast::{self, Sender},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use futures::FutureExt;
|
||||
use tuwunel_core::{Error, Result, Server, debug, debug_error, debug_info, error, info};
|
||||
use tuwunel_service::Services;
|
||||
|
||||
use crate::serve;
|
||||
use crate::{handle::ServerHandle, serve};
|
||||
|
||||
/// Main loop base
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -24,20 +20,30 @@ pub(crate) async fn run(services: Arc<Services>) -> Result {
|
||||
|
||||
// Setup shutdown/signal handling
|
||||
let handle = ServerHandle::new();
|
||||
let (tx, _) = broadcast::channel::<()>(1);
|
||||
let sigs = server
|
||||
.runtime()
|
||||
.spawn(signal(server.clone(), tx.clone(), handle.clone()));
|
||||
.spawn(signal(server.clone(), handle.clone()));
|
||||
|
||||
let mut listener =
|
||||
let mut listener = if services.config.listening {
|
||||
let future = serve::serve(services.clone(), handle);
|
||||
server
|
||||
.runtime()
|
||||
.spawn(serve::serve(services.clone(), handle.clone(), tx.subscribe()));
|
||||
.spawn(future)
|
||||
.map(|res| res.map_err(Error::from).unwrap_or_else(Err))
|
||||
.boxed()
|
||||
} else {
|
||||
let server = server.clone();
|
||||
async move {
|
||||
server.until_shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
};
|
||||
|
||||
// Focal point
|
||||
debug!("Running");
|
||||
let res = tokio::select! {
|
||||
res = &mut listener => res.map_err(Error::from).unwrap_or_else(Err),
|
||||
res = &mut listener => res,
|
||||
res = services.poll() => handle_services_poll(server, res, listener).await,
|
||||
};
|
||||
|
||||
@@ -104,16 +110,12 @@ pub(crate) async fn stop(services: Arc<Services>) -> Result {
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn signal(server: Arc<Server>, tx: Sender<()>, handle: axum_server::Handle) {
|
||||
async fn signal(server: Arc<Server>, handle: ServerHandle) {
|
||||
server.until_shutdown().await;
|
||||
handle_shutdown(&server, &tx, &handle);
|
||||
}
|
||||
|
||||
fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_server::Handle) {
|
||||
if let Err(e) = tx.send(()) {
|
||||
error!("failed sending shutdown transaction to channel: {e}");
|
||||
handle_shutdown(&server, &handle);
|
||||
}
|
||||
|
||||
fn handle_shutdown(server: &Arc<Server>, handle: &ServerHandle) {
|
||||
let timeout = server.config.client_shutdown_timeout;
|
||||
let timeout = Duration::from_secs(timeout);
|
||||
debug!(
|
||||
@@ -128,7 +130,7 @@ fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_server::
|
||||
async fn handle_services_poll(
|
||||
server: &Arc<Server>,
|
||||
result: Result,
|
||||
listener: JoinHandle<Result>,
|
||||
listener: impl Future<Output = Result>,
|
||||
) -> Result {
|
||||
debug!("Service manager finished: {result:?}");
|
||||
|
||||
|
||||
@@ -3,39 +3,41 @@ mod plain;
|
||||
mod tls;
|
||||
mod unix;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, atomic::Ordering};
|
||||
|
||||
use axum_server::Handle as ServerHandle;
|
||||
use tokio::sync::broadcast;
|
||||
use tuwunel_core::{Result, err};
|
||||
use tokio::task::JoinSet;
|
||||
use tuwunel_core::{Result, debug_info};
|
||||
use tuwunel_service::Services;
|
||||
|
||||
use super::layers;
|
||||
use crate::handle::ServerHandle;
|
||||
|
||||
/// Serve clients
|
||||
pub(super) async fn serve(
|
||||
services: Arc<Services>,
|
||||
handle: ServerHandle,
|
||||
mut shutdown: broadcast::Receiver<()>,
|
||||
) -> Result {
|
||||
pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Result {
|
||||
let server = &services.server;
|
||||
let config = &server.config;
|
||||
if !config.listening {
|
||||
return shutdown
|
||||
.recv()
|
||||
.await
|
||||
.map_err(|e| err!(error!("channel error: {e}")));
|
||||
|
||||
let (app, _guard) = layers::build(&services)?;
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
#[cfg(unix)]
|
||||
if let Some(unix_socket) = &config.unix_socket_path {
|
||||
let socket_perms = config.get_unix_socket_perms()?;
|
||||
|
||||
unix::serve(server, &app, &handle.handle_unix, &mut join_set, unix_socket, socket_perms)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let addrs = config.get_bind_addrs();
|
||||
let (app, _guard) = layers::build(&services)?;
|
||||
if cfg!(unix) && config.unix_socket_path.is_some() {
|
||||
unix::serve(server, app, shutdown).await
|
||||
} else if config.tls.certs.is_some() {
|
||||
if !addrs.is_empty() {
|
||||
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
|
||||
if config.tls.certs.is_some() {
|
||||
#[cfg(feature = "direct_tls")]
|
||||
{
|
||||
services.globals.init_rustls_provider()?;
|
||||
#[cfg(feature = "direct_tls")]
|
||||
return tls::serve(server, app, handle, addrs).await;
|
||||
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?;
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "direct_tls"))]
|
||||
return tuwunel_core::Err!(Config(
|
||||
@@ -43,6 +45,33 @@ pub(super) async fn serve(
|
||||
"tuwunel was not built with direct TLS support (\"direct_tls\")"
|
||||
));
|
||||
} else {
|
||||
plain::serve(server, app, handle, addrs).await
|
||||
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs);
|
||||
}
|
||||
}
|
||||
|
||||
assert!(!join_set.is_empty(), "at least one listener should be installed");
|
||||
|
||||
join_set.join_all().await;
|
||||
|
||||
let handle_active = server
|
||||
.metrics
|
||||
.requests_handle_active
|
||||
.load(Ordering::Acquire);
|
||||
|
||||
debug_info!(
|
||||
handle_finished = server
|
||||
.metrics
|
||||
.requests_handle_finished
|
||||
.load(Ordering::Acquire),
|
||||
panics = server
|
||||
.metrics
|
||||
.requests_panic
|
||||
.load(Ordering::Acquire),
|
||||
handle_active,
|
||||
"Stopped listening on {addrs:?}",
|
||||
);
|
||||
|
||||
debug_assert_eq!(0, handle_active, "active request handles still pending");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,50 +1,26 @@
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{Arc, atomic::Ordering},
|
||||
};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::{Handle as ServerHandle, bind};
|
||||
use axum_server::Handle;
|
||||
use tokio::task::JoinSet;
|
||||
use tuwunel_core::{Result, Server, debug_info, info};
|
||||
use tuwunel_core::{Server, info};
|
||||
|
||||
pub(super) async fn serve(
|
||||
pub(super) fn serve(
|
||||
server: &Arc<Server>,
|
||||
router: Router,
|
||||
handle: ServerHandle,
|
||||
addrs: Vec<SocketAddr>,
|
||||
) -> Result {
|
||||
let mut join_set = JoinSet::new();
|
||||
let router = router.into_make_service_with_connect_info::<SocketAddr>();
|
||||
for addr in &addrs {
|
||||
let bound = bind(*addr);
|
||||
let handler = bound.handle(handle.clone());
|
||||
let acceptor = handler.serve(router.clone());
|
||||
router: &Router,
|
||||
handle: &Handle<SocketAddr>,
|
||||
join_set: &mut JoinSet<Result<(), std::io::Error>>,
|
||||
addrs: &[SocketAddr],
|
||||
) {
|
||||
let router = router
|
||||
.clone()
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
for addr in addrs {
|
||||
let acceptor = axum_server::bind(*addr)
|
||||
.handle(handle.clone())
|
||||
.serve(router.clone());
|
||||
join_set.spawn_on(acceptor, server.runtime());
|
||||
}
|
||||
|
||||
info!("Listening on {addrs:?}");
|
||||
while join_set.join_next().await.is_some() {}
|
||||
|
||||
let handle_active = server
|
||||
.metrics
|
||||
.requests_handle_active
|
||||
.load(Ordering::Acquire);
|
||||
|
||||
debug_info!(
|
||||
handle_finished = server
|
||||
.metrics
|
||||
.requests_handle_finished
|
||||
.load(Ordering::Acquire),
|
||||
panics = server
|
||||
.metrics
|
||||
.requests_panic
|
||||
.load(Ordering::Acquire),
|
||||
handle_active,
|
||||
"Stopped listening on {addrs:?}",
|
||||
);
|
||||
|
||||
debug_assert_eq!(0, handle_active, "active request handles still pending");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,31 +1,21 @@
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{Arc, atomic::Ordering},
|
||||
};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::Handle as ServerHandle;
|
||||
use axum_server_dual_protocol::{
|
||||
ServerExt,
|
||||
axum_server::{bind_rustls, tls_rustls::RustlsConfig},
|
||||
};
|
||||
use axum_server::Handle;
|
||||
use axum_server_dual_protocol::{ServerExt, axum_server::tls_rustls::RustlsConfig};
|
||||
use tokio::task::JoinSet;
|
||||
use tuwunel_core::{Result, Server, debug, debug_info, err, info, warn};
|
||||
use tuwunel_core::{Result, Server, debug, err, info, warn};
|
||||
|
||||
pub(super) async fn serve(
|
||||
server: &Arc<Server>,
|
||||
app: Router,
|
||||
handle: ServerHandle,
|
||||
addrs: Vec<SocketAddr>,
|
||||
app: &Router,
|
||||
handle: &Handle<SocketAddr>,
|
||||
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
|
||||
addrs: &[SocketAddr],
|
||||
) -> Result {
|
||||
let tls = &server.config.tls;
|
||||
let certs = tls.certs.as_ref().ok_or_else(|| {
|
||||
err!(Config("tls.certs", "Missing required value in tls config section"))
|
||||
})?;
|
||||
let key = tls
|
||||
.key
|
||||
.as_ref()
|
||||
.ok_or_else(|| err!(Config("tls.key", "Missing required value in tls config section")))?;
|
||||
let certs = tls.certs.as_ref().unwrap();
|
||||
let key = tls.key.as_ref().unwrap();
|
||||
|
||||
info!(
|
||||
"Note: It is strongly recommended that you use a reverse proxy instead of running \
|
||||
@@ -36,10 +26,11 @@ pub(super) async fn serve(
|
||||
.await
|
||||
.map_err(|e| err!(Config("tls", "Failed to load certificates or key: {e}")))?;
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
let app = app.into_make_service_with_connect_info::<SocketAddr>();
|
||||
let app = app
|
||||
.clone()
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
if tls.dual_protocol {
|
||||
for addr in &addrs {
|
||||
for addr in addrs {
|
||||
join_set.spawn_on(
|
||||
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
|
||||
.set_upgrade(false)
|
||||
@@ -48,47 +39,23 @@ pub(super) async fn serve(
|
||||
server.runtime(),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
for addr in &addrs {
|
||||
join_set.spawn_on(
|
||||
bind_rustls(*addr, conf.clone())
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
server.runtime(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if tls.dual_protocol {
|
||||
warn!(
|
||||
"Listening on {addrs:?} with TLS certificate {certs} and supporting plain text \
|
||||
(HTTP) connections too (insecure!)",
|
||||
);
|
||||
} else {
|
||||
for addr in addrs {
|
||||
join_set.spawn_on(
|
||||
axum_server::bind_rustls(*addr, conf.clone())
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
server.runtime(),
|
||||
);
|
||||
}
|
||||
|
||||
info!("Listening on {addrs:?} with TLS certificate {certs}");
|
||||
}
|
||||
|
||||
while join_set.join_next().await.is_some() {}
|
||||
|
||||
let handle_active = server
|
||||
.metrics
|
||||
.requests_handle_active
|
||||
.load(Ordering::Acquire);
|
||||
|
||||
debug_info!(
|
||||
handle_finished = server
|
||||
.metrics
|
||||
.requests_handle_finished
|
||||
.load(Ordering::Acquire),
|
||||
panics = server
|
||||
.metrics
|
||||
.requests_panic
|
||||
.load(Ordering::Acquire),
|
||||
handle_active,
|
||||
"Stopped listening on {addrs:?}",
|
||||
);
|
||||
|
||||
debug_assert_eq!(0, handle_active, "active request handles still pending");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,185 +1,48 @@
|
||||
#![cfg(unix)]
|
||||
|
||||
use std::{
|
||||
net::{self, IpAddr, Ipv4Addr},
|
||||
os::fd::AsRawFd,
|
||||
path::Path,
|
||||
sync::{Arc, atomic::Ordering},
|
||||
};
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
extract::{Request, connect_info::IntoMakeServiceWithConnectInfo},
|
||||
};
|
||||
use hyper::{body::Incoming, service::service_fn};
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
server,
|
||||
};
|
||||
use tokio::{
|
||||
fs,
|
||||
net::{UnixListener, UnixStream, unix::SocketAddr},
|
||||
sync::broadcast::{self},
|
||||
task::JoinSet,
|
||||
time::{Duration, sleep},
|
||||
};
|
||||
use tower::{Service, ServiceExt};
|
||||
use tuwunel_core::{
|
||||
Err, Result, Server, debug, debug_error, info, result::UnwrapInfallible, trace, warn,
|
||||
net::SocketAddr,
|
||||
os::unix::{fs::PermissionsExt, net::UnixListener},
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
type MakeService = IntoMakeServiceWithConnectInfo<Router, net::SocketAddr>;
|
||||
|
||||
const NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
|
||||
const FINI_POLL_INTERVAL: Duration = Duration::from_millis(750);
|
||||
use axum::{Extension, Router, extract::ConnectInfo};
|
||||
use axum_server::Handle;
|
||||
use tokio::task::JoinSet;
|
||||
use tuwunel_core::{Result, Server, info, warn};
|
||||
|
||||
#[tracing::instrument(skip_all, level = "debug")]
|
||||
pub(super) async fn serve(
|
||||
server: &Arc<Server>,
|
||||
app: Router,
|
||||
mut shutdown: broadcast::Receiver<()>,
|
||||
router: &Router,
|
||||
handle: &Handle<std::os::unix::net::SocketAddr>,
|
||||
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
|
||||
path: &Path,
|
||||
socket_perms: u32,
|
||||
) -> Result {
|
||||
let mut tasks = JoinSet::<()>::new();
|
||||
let executor = TokioExecutor::new();
|
||||
let app = app.into_make_service_with_connect_info::<net::SocketAddr>();
|
||||
let builder = server::conn::auto::Builder::new(executor);
|
||||
let listener = init(server).await?;
|
||||
while server.running() {
|
||||
let app = app.clone();
|
||||
let builder = builder.clone();
|
||||
tokio::select! {
|
||||
_sig = shutdown.recv() => break,
|
||||
conn = listener.accept() => match conn {
|
||||
Ok(conn) => accept(server, &listener, &mut tasks, app, builder, conn).await,
|
||||
Err(err) => debug_error!(?listener, "accept error: {err}"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fini(server, listener, tasks).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
?listener,
|
||||
socket = ?conn.0,
|
||||
),
|
||||
)]
|
||||
async fn accept(
|
||||
server: &Arc<Server>,
|
||||
listener: &UnixListener,
|
||||
tasks: &mut JoinSet<()>,
|
||||
app: MakeService,
|
||||
builder: server::conn::auto::Builder<TokioExecutor>,
|
||||
conn: (UnixStream, SocketAddr),
|
||||
) {
|
||||
let (socket, _) = conn;
|
||||
let server_ = server.clone();
|
||||
let task = async move { accepted(server_, builder, socket, app).await };
|
||||
|
||||
_ = tasks.spawn_on(task, server.runtime());
|
||||
while tasks.try_join_next().is_some() {}
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
fd = %socket.as_raw_fd(),
|
||||
path = ?socket.local_addr(),
|
||||
),
|
||||
)]
|
||||
async fn accepted(
|
||||
server: Arc<Server>,
|
||||
builder: server::conn::auto::Builder<TokioExecutor>,
|
||||
socket: UnixStream,
|
||||
mut app: MakeService,
|
||||
) {
|
||||
let socket = TokioIo::new(socket);
|
||||
let called = app.call(NULL_ADDR).await.unwrap_infallible();
|
||||
let service = move |req: Request<Incoming>| called.clone().oneshot(req);
|
||||
let handler = service_fn(service);
|
||||
trace!(?socket, ?handler, "serving connection");
|
||||
|
||||
// bug on darwin causes all results to be errors. do not unwrap this
|
||||
tokio::select! {
|
||||
() = server.until_shutdown() => (),
|
||||
_ = builder.serve_connection(socket, handler) => (),
|
||||
};
|
||||
}
|
||||
|
||||
async fn init(server: &Arc<Server>) -> Result<UnixListener> {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let config = &server.config;
|
||||
let path = config
|
||||
.unix_socket_path
|
||||
.as_ref()
|
||||
.expect("failed to extract configured unix socket path");
|
||||
|
||||
if path.exists() {
|
||||
warn!("Removing existing UNIX socket {:#?} (unclean shutdown?)...", path.display());
|
||||
fs::remove_file(&path)
|
||||
.await
|
||||
.map_err(|e| warn!("Failed to remove existing UNIX socket: {e}"))
|
||||
.unwrap();
|
||||
warn!("Removing existing UNIX socket {path:?} (unclean shutdown?)...");
|
||||
fs::remove_file(path)?;
|
||||
}
|
||||
|
||||
let dir = path.parent().unwrap_or_else(|| Path::new("/"));
|
||||
if let Err(e) = fs::create_dir_all(dir).await {
|
||||
return Err!("Failed to create {dir:?} for socket {path:?}: {e}");
|
||||
}
|
||||
let unix_listener = UnixListener::bind(path)?;
|
||||
unix_listener.set_nonblocking(true)?;
|
||||
|
||||
let listener = UnixListener::bind(path);
|
||||
if let Err(e) = listener {
|
||||
return Err!("Failed to bind listener {path:?}: {e}");
|
||||
}
|
||||
let perms = fs::Permissions::from_mode(socket_perms);
|
||||
fs::set_permissions(path, perms)?;
|
||||
|
||||
let socket_perms = config.unix_socket_perms.to_string();
|
||||
let octal_perms =
|
||||
u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions");
|
||||
let perms = std::fs::Permissions::from_mode(octal_perms);
|
||||
if let Err(e) = fs::set_permissions(&path, perms).await {
|
||||
return Err!("Failed to set socket {path:?} permissions: {e}");
|
||||
}
|
||||
let router = router
|
||||
.clone()
|
||||
.layer(Extension(ConnectInfo("0.0.0.0".parse::<SocketAddr>())))
|
||||
.into_make_service();
|
||||
let acceptor = axum_server::from_unix(unix_listener)?
|
||||
.handle(handle.clone())
|
||||
.serve(router);
|
||||
join_set.spawn_on(acceptor, server.runtime());
|
||||
|
||||
info!("Listening at {path:?}");
|
||||
|
||||
Ok(listener.unwrap())
|
||||
}
|
||||
|
||||
async fn fini(server: &Arc<Server>, listener: UnixListener, mut tasks: JoinSet<()>) {
|
||||
let local = listener.local_addr();
|
||||
|
||||
debug!("Closing listener at {local:?} ...");
|
||||
drop(listener);
|
||||
|
||||
debug!("Waiting for requests to finish...");
|
||||
while server
|
||||
.metrics
|
||||
.requests_handle_active
|
||||
.load(Ordering::Acquire)
|
||||
.gt(&0)
|
||||
{
|
||||
tokio::select! {
|
||||
task = tasks.join_next() => if task.is_none() { break; },
|
||||
() = sleep(FINI_POLL_INTERVAL) => {},
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Shutting down...");
|
||||
tasks.shutdown().await;
|
||||
|
||||
if let Ok(local) = local
|
||||
&& let Some(path) = local.as_pathname()
|
||||
{
|
||||
debug!(?path, "Removing unix socket file.");
|
||||
if let Err(e) = fs::remove_file(path).await {
|
||||
warn!(?path, "Failed to remove UNIX socket file: {e}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -74,9 +74,6 @@
|
||||
|
||||
# The UNIX socket tuwunel will listen on.
|
||||
#
|
||||
# tuwunel cannot listen on both an IP address and a UNIX socket. If
|
||||
# listening on a UNIX socket, you MUST remove/comment the `address` key.
|
||||
#
|
||||
# Remember to make sure that your reverse proxy has access to this socket
|
||||
# file, either by adding your reverse proxy to the 'tuwunel' group or
|
||||
# granting world R/W permissions with `unix_socket_perms` (666 minimum).
|
||||
|
||||
Reference in New Issue
Block a user