From 85d34bb0358cb9ae4536881cc4e4b14aeeaa7c16 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Tue, 7 Apr 2026 13:48:59 +0100 Subject: [PATCH] feat(net): add UDP transport for direct peer connections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DERP works for everything but adds relay latency. Add a parallel UDP transport so peers with reachable endpoints can talk directly: - wg/tunnel: track each peer's local boringtun index in PeerTunnel and expose find_peer_by_local_index / find_peer_by_endpoint lookups - daemon/lifecycle: bind a UdpSocket on 0.0.0.0:0 alongside DERP, run the recv loop on a clone of an Arc so send and recv can proceed concurrently - run_wg_loop: new udp_in_rx select arm. For inbound UDP we identify the source peer by parsing the WireGuard receiver_index out of the packet header (msg types 2/3/4) and falling back to source-address matching for type-1 handshake initiations - dispatch_encap: SendUdp now actually forwards via the UDP channel UDP failure is non-fatal — DERP can carry traffic alone if the bind fails or packets are dropped. --- sunbeam-net/src/daemon/lifecycle.rs | 174 +++++++++++++++++++++++++--- sunbeam-net/src/wg/tunnel.rs | 28 +++++ 2 files changed, 183 insertions(+), 19 deletions(-) diff --git a/sunbeam-net/src/daemon/lifecycle.rs b/sunbeam-net/src/daemon/lifecycle.rs index 470b9a62..cd82d4bb 100644 --- a/sunbeam-net/src/daemon/lifecycle.rs +++ b/sunbeam-net/src/daemon/lifecycle.rs @@ -155,6 +155,28 @@ async fn run_session( let (derp_out_tx, derp_out_rx) = mpsc::channel::<([u8; 32], Vec)>(256); let (derp_in_tx, derp_in_rx) = mpsc::channel::<([u8; 32], Vec)>(256); + // 9a. Bind a UDP socket for direct WireGuard transport. Failure here is + // non-fatal — DERP can carry traffic alone, just slower. + let (udp_out_tx, udp_out_rx) = + mpsc::channel::<(std::net::SocketAddr, Vec)>(256); + let (udp_in_tx, udp_in_rx) = + mpsc::channel::<(std::net::SocketAddr, Vec)>(256); + let _udp_task = match tokio::net::UdpSocket::bind("0.0.0.0:0").await { + Ok(socket) => { + let local = socket.local_addr().ok(); + tracing::info!("WG UDP socket bound on {local:?}"); + let socket = std::sync::Arc::new(socket); + let udp_cancel = cancel.clone(); + Some(tokio::spawn(async move { + run_udp_loop(socket, udp_out_rx, udp_in_tx, udp_cancel).await; + })) + } + Err(e) => { + tracing::warn!("UDP bind failed: {e}; continuing with DERP only"); + None + } + }; + let derp_endpoint = derp_map .as_ref() .and_then(pick_derp_node) @@ -193,6 +215,8 @@ async fn run_session( wg_to_engine_tx, derp_out_tx, derp_in_rx, + udp_out_tx, + udp_in_rx, wg_cancel, ) .await @@ -284,21 +308,23 @@ async fn run_proxy_listener( /// WireGuard encapsulation/decapsulation loop. /// /// Reads IP packets from the engine, encapsulates them through WireGuard, -/// and sends WG packets out via DERP. Receives WG packets from DERP, -/// decapsulates, and feeds IP packets back to the engine. +/// and sends WG packets out via UDP (preferred) or DERP relay (fallback). +/// Receives WG packets from either transport, decapsulates them, and feeds +/// the resulting IP packets back to the engine. +#[allow(clippy::too_many_arguments)] async fn run_wg_loop( mut tunnel: WgTunnel, mut from_engine: mpsc::Receiver>, to_engine: mpsc::Sender>, derp_out_tx: mpsc::Sender<([u8; 32], Vec)>, mut derp_in_rx: mpsc::Receiver<([u8; 32], Vec)>, + udp_out_tx: mpsc::Sender<(std::net::SocketAddr, Vec)>, + mut udp_in_rx: mpsc::Receiver<(std::net::SocketAddr, Vec)>, cancel: tokio_util::sync::CancellationToken, ) { let mut tick_interval = tokio::time::interval(Duration::from_millis(250)); tick_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); - // TODO: wire UDP socket for direct transport (currently DERP-only) - loop { tokio::select! { _ = cancel.cancelled() => return, @@ -307,7 +333,7 @@ async fn run_wg_loop( Some(packet) => { if let Some(dst_ip) = parse_dst_ip(&packet) { let action = tunnel.encapsulate(dst_ip, &packet); - dispatch_encap(action, &derp_out_tx).await; + dispatch_encap(action, &derp_out_tx, &udp_out_tx).await; } } None => return, // engine dropped @@ -317,25 +343,28 @@ async fn run_wg_loop( match incoming { Some((src_key, data)) => { let action = tunnel.decapsulate(&src_key, &data); - match action { - DecapAction::Packet(p) => { - if to_engine.send(p).await.is_err() { - return; - } - } - DecapAction::Response(r) => { - let _ = derp_out_tx.send((src_key, r)).await; - } - DecapAction::Nothing => {} - } + handle_decap(action, src_key, &to_engine, &derp_out_tx).await; } None => return, // DERP loop dropped } } + incoming = udp_in_rx.recv() => { + match incoming { + Some((src_addr, data)) => { + let Some(peer_key) = identify_udp_peer(&tunnel, src_addr, &data) else { + tracing::trace!("UDP packet from {src_addr}: no peer match"); + continue; + }; + let action = tunnel.decapsulate(&peer_key, &data); + handle_decap(action, peer_key, &to_engine, &derp_out_tx).await; + } + None => return, // UDP loop dropped + } + } _ = tick_interval.tick() => { let actions = tunnel.tick(); for ta in actions { - dispatch_encap(ta.action, &derp_out_tx).await; + dispatch_encap(ta.action, &derp_out_tx, &udp_out_tx).await; } } } @@ -346,11 +375,12 @@ async fn run_wg_loop( async fn dispatch_encap( action: crate::wg::tunnel::EncapAction, derp_out_tx: &mpsc::Sender<([u8; 32], Vec)>, + udp_out_tx: &mpsc::Sender<(std::net::SocketAddr, Vec)>, ) { match action { crate::wg::tunnel::EncapAction::SendUdp { endpoint, data } => { - tracing::trace!("WG → UDP {endpoint} ({} bytes) — UDP transport not implemented", data.len()); - // TODO: send via UDP socket + tracing::trace!("WG → UDP {endpoint} ({} bytes)", data.len()); + let _ = udp_out_tx.send((endpoint, data)).await; } crate::wg::tunnel::EncapAction::SendDerp { dest_key, data } => { tracing::trace!("WG → DERP ({} bytes)", data.len()); @@ -360,6 +390,112 @@ async fn dispatch_encap( } } +/// Handle a single decapsulation result regardless of which transport it +/// arrived on. Decrypted IP packets go to the engine; handshake responses +/// go back out via DERP (we don't know a UDP endpoint for response peers +/// at this layer — DERP is always a safe fallback). +async fn handle_decap( + action: DecapAction, + peer_key: [u8; 32], + to_engine: &mpsc::Sender>, + derp_out_tx: &mpsc::Sender<([u8; 32], Vec)>, +) { + match action { + DecapAction::Packet(p) => { + let _ = to_engine.send(p).await; + } + DecapAction::Response(r) => { + let _ = derp_out_tx.send((peer_key, r)).await; + } + DecapAction::Nothing => {} + } +} + +/// Identify which peer a UDP-delivered WireGuard packet belongs to. +/// +/// WireGuard message types 2 (HandshakeResponse), 3 (CookieReply), and 4 +/// (TransportData) all carry a `receiver_index` at bytes 4..8 — that's the +/// boringtun local index we assigned when adding the peer, so we can look +/// it up directly. Type 1 (HandshakeInitiation) doesn't carry it; for +/// those we fall back to matching the source address against advertised +/// peer endpoints. +fn identify_udp_peer( + tunnel: &WgTunnel, + src_addr: std::net::SocketAddr, + packet: &[u8], +) -> Option<[u8; 32]> { + if packet.len() < 8 { + return None; + } + let msg_type = packet[0]; + match msg_type { + 2 | 3 | 4 => { + let idx = u32::from_le_bytes([packet[4], packet[5], packet[6], packet[7]]); + tunnel + .find_peer_by_local_index(idx) + .or_else(|| tunnel.find_peer_by_endpoint(src_addr)) + } + _ => tunnel.find_peer_by_endpoint(src_addr), + } +} + +/// Bridge a UDP socket to the WG layer via mpsc channels. +async fn run_udp_loop( + socket: std::sync::Arc, + mut out_rx: mpsc::Receiver<(std::net::SocketAddr, Vec)>, + in_tx: mpsc::Sender<(std::net::SocketAddr, Vec)>, + cancel: tokio_util::sync::CancellationToken, +) { + // Spawn a dedicated recv task — UdpSocket supports concurrent send+recv + // when shared via Arc, but mixing both in one select! makes lifetimes + // awkward, so split them. + let recv_socket = socket.clone(); + let recv_cancel = cancel.clone(); + let recv_task = tokio::spawn(async move { + let mut buf = vec![0u8; 65535]; + loop { + tokio::select! { + _ = recv_cancel.cancelled() => return, + result = recv_socket.recv_from(&mut buf) => { + match result { + Ok((n, src)) => { + if in_tx.send((src, buf[..n].to_vec())).await.is_err() { + return; + } + } + Err(e) => { + tracing::warn!("UDP recv error: {e}"); + return; + } + } + } + } + } + }); + + loop { + tokio::select! { + _ = cancel.cancelled() => { + recv_task.abort(); + return; + } + outgoing = out_rx.recv() => { + match outgoing { + Some((dst, data)) => { + if let Err(e) = socket.send_to(&data, dst).await { + tracing::warn!("UDP send to {dst} failed: {e}"); + } + } + None => { + recv_task.abort(); + return; + } + } + } + } + } +} + /// DERP relay loop: bridges packets between WG layer and a DERP client. async fn run_derp_loop( mut client: DerpClient, diff --git a/sunbeam-net/src/wg/tunnel.rs b/sunbeam-net/src/wg/tunnel.rs index 85358309..52170f98 100644 --- a/sunbeam-net/src/wg/tunnel.rs +++ b/sunbeam-net/src/wg/tunnel.rs @@ -18,6 +18,10 @@ struct PeerTunnel { endpoint: Option, derp_region: Option, allowed_ips: Vec, + /// Local boringtun index assigned to this peer's tunnel. Used to route + /// inbound UDP packets back to the right peer via the receiver_index + /// field in WireGuard message types 2/3/4. + local_index: u32, } /// Result of encapsulating an outbound IP packet. @@ -99,6 +103,7 @@ impl WgTunnel { endpoint: parse_first_endpoint(&node.endpoints), derp_region: parse_derp_region(&node.derp), allowed_ips: parse_allowed_ips(&node.allowed_ips), + local_index: index, }); } } @@ -195,6 +200,29 @@ impl WgTunnel { actions } + /// Find a peer by the boringtun local index we assigned to it. + /// WireGuard message types 2/3/4 carry this in the receiver_index field. + pub fn find_peer_by_local_index(&self, idx: u32) -> Option<[u8; 32]> { + for (key, peer) in &self.peers { + if peer.local_index == idx { + return Some(*key); + } + } + None + } + + /// Find a peer whose advertised endpoint matches the given socket addr. + /// Used as a fallback for inbound UDP packets that don't carry our index + /// (i.e. type-1 handshake initiations from peers we already know about). + pub fn find_peer_by_endpoint(&self, addr: SocketAddr) -> Option<[u8; 32]> { + for (key, peer) in &self.peers { + if peer.endpoint == Some(addr) { + return Some(*key); + } + } + None + } + /// Find which peer owns a given IP address. fn find_peer_for_ip(&self, ip: IpAddr) -> Option<&[u8; 32]> { for (key, peer) in &self.peers {