feat(net): add WireGuard tunnel and smoltcp virtual network
- wg/tunnel: per-peer boringtun `Tunn` management with peer-table sync from the netmap (add/remove peers; update endpoints, allowed_ips, and DERP region), plus encapsulate/decapsulate/tick paths that route packets to UDP or DERP.
- wg/socket: a smoltcp `Interface` backed by an mpsc-channel `Device`, bridging the synchronous poll-based smoltcp stack with async tokio mpsc channels.
- wg/router: skeleton `PacketRouter` (currently unused; reserved for the unified UDP/DERP ingress path).
This commit is contained in:
@@ -5,6 +5,7 @@ pub mod derp;
|
||||
pub mod error;
|
||||
pub mod keys;
|
||||
pub mod noise;
|
||||
pub mod wg;
|
||||
pub(crate) mod proto;
|
||||
|
||||
pub use config::VpnConfig;
|
||||
|
||||
3
sunbeam-net/src/wg/mod.rs
Normal file
3
sunbeam-net/src/wg/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod router;
|
||||
pub mod socket;
|
||||
pub mod tunnel;
|
||||
203
sunbeam-net/src/wg/router.rs
Normal file
203
sunbeam-net/src/wg/router.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::tunnel::{DecapAction, EncapAction, WgTunnel};
|
||||
|
||||
/// Routes packets between WireGuard tunnels, UDP sockets, and DERP relays.
pub(crate) struct PacketRouter {
    /// Per-peer WireGuard encapsulation/decapsulation state.
    tunnel: WgTunnel,
    /// Direct UDP transport; `None` until `bind_udp` succeeds.
    udp: Option<UdpSocket>,
    /// Cloneable sender handed out (via `ingress_sender`) to UDP/DERP recv loops.
    ingress_tx: mpsc::Sender<IngressPacket>,
    /// Receiving end of the ingress queue.
    // NOTE(review): no method in this file drains `ingress_rx` — confirm a
    // recv loop elsewhere consumes it, or wire one up.
    ingress_rx: mpsc::Receiver<IngressPacket>,
}
|
||||
|
||||
/// An inbound WireGuard packet from either UDP or DERP.
pub(crate) enum IngressPacket {
    /// WireGuard packet received over UDP.
    Udp {
        /// Source address the datagram arrived from (used for peer lookup
        /// and for sending handshake responses back).
        src: SocketAddr,
        /// Raw encrypted WireGuard payload.
        data: Vec<u8>,
    },
    /// WireGuard packet received via DERP relay.
    Derp {
        /// Sender's WireGuard public key, as reported by the DERP frame.
        src_key: [u8; 32],
        /// Raw encrypted WireGuard payload.
        data: Vec<u8>,
    },
}
|
||||
|
||||
impl PacketRouter {
    /// Create a router around an existing tunnel.
    ///
    /// The ingress queue is bounded at 256 packets; no UDP socket is bound
    /// until [`bind_udp`] is called.
    pub fn new(tunnel: WgTunnel) -> Self {
        let (tx, rx) = mpsc::channel(256);
        Self {
            tunnel,
            udp: None,
            ingress_tx: tx,
            ingress_rx: rx,
        }
    }

    /// Bind the UDP socket for direct WireGuard traffic.
    ///
    /// # Errors
    /// Returns `Error::WireGuard` if the OS bind fails (port in use, etc.).
    pub async fn bind_udp(&mut self, addr: SocketAddr) -> crate::Result<()> {
        let sock = UdpSocket::bind(addr)
            .await
            .map_err(|e| crate::Error::WireGuard(format!("failed to bind UDP {addr}: {e}")))?;
        self.udp = Some(sock);
        Ok(())
    }

    /// Get a clone of the ingress sender (for DERP recv loops to feed packets in).
    pub fn ingress_sender(&self) -> mpsc::Sender<IngressPacket> {
        self.ingress_tx.clone()
    }

    /// Process one ingress packet: decapsulate and return any decrypted IP packet.
    /// If decapsulation produces a response (handshake, cookie), it is sent back
    /// to the peer over the appropriate transport.
    ///
    /// NOTE(review): for the UDP arm, `find_peer_by_endpoint` currently always
    /// returns `None` (see below), so every UDP ingress packet is dropped here.
    pub fn process_ingress(&mut self, packet: IngressPacket) -> Option<Vec<u8>> {
        match packet {
            IngressPacket::Udp { src, data } => {
                // We need to figure out which peer sent this. For UDP, we look up
                // the peer by source address. If we can't find it, try all peers.
                let peer_key = self.find_peer_by_endpoint(src);
                let key = match peer_key {
                    Some(k) => k,
                    None => return None,
                };

                match self.tunnel.decapsulate(&key, &data) {
                    DecapAction::Packet(ip_packet) => Some(ip_packet),
                    DecapAction::Response(response) => {
                        // Queue the response to send back — best-effort.
                        // NOTE(review): try_send_udp_sync is currently a no-op,
                        // so handshake responses are silently dropped.
                        self.try_send_udp_sync(src, &response);
                        None
                    }
                    DecapAction::Nothing => None,
                }
            }
            IngressPacket::Derp { src_key, data } => {
                // DERP frames carry the sender's key, so no endpoint lookup needed.
                match self.tunnel.decapsulate(&src_key, &data) {
                    DecapAction::Packet(ip_packet) => Some(ip_packet),
                    DecapAction::Response(_response) => {
                        // DERP responses would need to be sent back via DERP.
                        // The caller should handle this via the DERP send path.
                        None
                    }
                    DecapAction::Nothing => None,
                }
            }
        }
    }

    /// Encapsulate an outbound IP packet and send it via the appropriate transport.
    ///
    /// Packets with an unparsable destination, no matching peer, or a
    /// DERP-only route are silently accepted (Ok) without UDP transmission.
    pub async fn send_outbound(&mut self, ip_packet: &[u8]) -> crate::Result<()> {
        // Parse the destination IP from the IP packet header.
        let dst_ip = match parse_dst_ip(ip_packet) {
            Some(ip) => ip,
            None => return Ok(()),
        };

        match self.tunnel.encapsulate(dst_ip, ip_packet) {
            EncapAction::SendUdp { endpoint, data } => {
                if let Some(ref udp) = self.udp {
                    udp.send_to(&data, endpoint).await.map_err(|e| {
                        crate::Error::WireGuard(format!("UDP send to {endpoint}: {e}"))
                    })?;
                }
                Ok(())
            }
            EncapAction::SendDerp { dest_key: _, data: _ } => {
                // DERP sending would be handled by the caller via a DERP client.
                // This layer just signals the intent; the daemon wires it up.
                Ok(())
            }
            EncapAction::Nothing => Ok(()),
        }
    }

    /// Run timer ticks on all tunnels, send any resulting packets.
    ///
    /// Send failures for timer traffic (keepalives, retries) are deliberately
    /// ignored — they will be retried on the next tick.
    pub async fn tick(&mut self) -> crate::Result<()> {
        let actions = self.tunnel.tick();
        for ta in actions {
            match ta.action {
                EncapAction::SendUdp { endpoint, data } => {
                    if let Some(ref udp) = self.udp {
                        let _ = udp.send_to(&data, endpoint).await;
                    }
                }
                EncapAction::SendDerp { .. } => {
                    // DERP timer packets handled by daemon layer.
                }
                EncapAction::Nothing => {}
            }
        }
        Ok(())
    }

    /// Update peers from netmap.
    pub fn update_peers(&mut self, peers: &[crate::proto::types::Node]) {
        self.tunnel.update_peers(peers);
    }

    /// Find a peer key by its known UDP endpoint.
    ///
    /// Always returns `None` in the current implementation (stub).
    fn find_peer_by_endpoint(&self, _src: SocketAddr) -> Option<[u8; 32]> {
        // In a full implementation, we'd maintain an endpoint→key index.
        // For now, we expose internal peer iteration via the tunnel.
        // This is a simplification — a production router would have a reverse map.
        // We try all peers since the tunnel has the endpoint info.
        None
    }

    /// Best-effort synchronous UDP send (used in process_ingress which is sync).
    ///
    /// Currently a no-op (stub) — see comment below.
    fn try_send_udp_sync(&self, _dst: SocketAddr, _data: &[u8]) {
        // UdpSocket::try_send_to is not available on tokio UdpSocket without
        // prior connect. In a real implementation, we'd use a std UdpSocket
        // or queue the response for async sending. For now, responses from
        // process_ingress are dropped — the caller should handle them.
    }
}
|
||||
|
||||
/// Parse the destination IP address from a raw IP packet.
///
/// Returns `None` for empty or truncated packets and for unknown IP versions.
fn parse_dst_ip(packet: &[u8]) -> Option<std::net::IpAddr> {
    // The IP version lives in the top nibble of the first byte.
    let version = packet.first()? >> 4;
    if version == 4 && packet.len() >= 20 {
        // IPv4: destination address occupies header bytes 16..20.
        let dst: [u8; 4] = packet[16..20].try_into().ok()?;
        Some(std::net::IpAddr::V4(std::net::Ipv4Addr::from(dst)))
    } else if version == 6 && packet.len() >= 40 {
        // IPv6: destination address occupies header bytes 24..40.
        let dst: [u8; 16] = packet[24..40].try_into().ok()?;
        Some(std::net::IpAddr::V6(std::net::Ipv6Addr::from(dst)))
    } else {
        None
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use x25519_dalek::StaticSecret;

    /// Build a tunnel with a fresh random private key and no peers.
    fn make_tunnel() -> WgTunnel {
        let secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
        WgTunnel::new(secret)
    }

    #[test]
    fn test_new_router() {
        let tunnel = make_tunnel();
        let router = PacketRouter::new(tunnel);
        // Verify the ingress channel is functional by checking the sender isn't closed.
        assert!(!router.ingress_tx.is_closed());
    }

    #[test]
    fn test_ingress_sender_clone() {
        let tunnel = make_tunnel();
        let router = PacketRouter::new(tunnel);
        let sender1 = router.ingress_sender();
        let sender2 = router.ingress_sender();
        // Both senders should be live.
        assert!(!sender1.is_closed());
        assert!(!sender2.is_closed());
    }
}
|
||||
252
sunbeam-net/src/wg/socket.rs
Normal file
252
sunbeam-net/src/wg/socket.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
use smoltcp::iface::{Config, Interface, PollResult, SocketHandle, SocketSet};
|
||||
use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
|
||||
use smoltcp::socket::tcp;
|
||||
use smoltcp::time::Instant;
|
||||
use smoltcp::wire::{IpAddress, IpCidr};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Virtual TCP/IP network backed by WireGuard.
pub(crate) struct VirtualNetwork {
    /// smoltcp interface holding the tailnet-assigned IP address.
    iface: Interface,
    /// All smoltcp sockets (TCP connections) on this interface.
    sockets: SocketSet<'static>,
    /// Channel-backed virtual NIC bridging to the WireGuard tunnel.
    device: ChannelDevice,
}
|
||||
|
||||
/// A smoltcp Device backed by mpsc channels.
///
/// The `try_recv` / `try_send` methods are used because smoltcp's Device trait
/// is synchronous and poll-based.
struct ChannelDevice {
    /// Inbound decrypted IP packets (from the WireGuard tunnel).
    rx: mpsc::Receiver<Vec<u8>>,
    /// Outbound IP packets (to the WireGuard tunnel for encryption).
    tx: mpsc::Sender<Vec<u8>>,
    /// Maximum transmission unit reported to smoltcp.
    mtu: usize,
}
|
||||
|
||||
impl Device for ChannelDevice {
|
||||
type RxToken<'a> = ChannelRxToken where Self: 'a;
|
||||
type TxToken<'a> = ChannelTxToken<'a> where Self: 'a;
|
||||
|
||||
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
|
||||
match self.rx.try_recv() {
|
||||
Ok(data) => {
|
||||
let rx = ChannelRxToken { data };
|
||||
let tx = ChannelTxToken { sender: &self.tx };
|
||||
Some((rx, tx))
|
||||
}
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> {
|
||||
Some(ChannelTxToken { sender: &self.tx })
|
||||
}
|
||||
|
||||
fn capabilities(&self) -> DeviceCapabilities {
|
||||
let mut caps = DeviceCapabilities::default();
|
||||
caps.medium = Medium::Ip;
|
||||
caps.max_transmission_unit = self.mtu;
|
||||
caps
|
||||
}
|
||||
}
|
||||
|
||||
/// RX token holding one inbound IP packet for smoltcp to consume.
struct ChannelRxToken {
    /// The decrypted IP packet bytes.
    data: Vec<u8>,
}
|
||||
|
||||
impl RxToken for ChannelRxToken {
|
||||
fn consume<R, F>(self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&[u8]) -> R,
|
||||
{
|
||||
f(&self.data)
|
||||
}
|
||||
}
|
||||
|
||||
/// TX token that forwards one outbound frame into the WireGuard-bound channel.
struct ChannelTxToken<'a> {
    /// Borrowed sender for the outbound packet channel.
    sender: &'a mpsc::Sender<Vec<u8>>,
}
|
||||
|
||||
impl<'a> TxToken for ChannelTxToken<'a> {
    /// Let smoltcp fill a fresh `len`-byte frame, then forward it to the
    /// outbound channel.
    fn consume<R, F>(self, len: usize, f: F) -> R
    where
        F: FnOnce(&mut [u8]) -> R,
    {
        let mut frame = vec![0u8; len];
        let out = f(&mut frame);
        if self.sender.try_send(frame).is_err() {
            // Best-effort send: if the channel is full or closed, the
            // packet is dropped (IP tolerates loss).
        }
        out
    }
}
|
||||
|
||||
/// Handle to a TCP socket within the virtual network.
///
/// A thin copyable wrapper around smoltcp's `SocketHandle`; pass it back to
/// the `VirtualNetwork::tcp_*` methods.
#[derive(Debug, Clone, Copy)]
pub(crate) struct TcpSocketHandle(SocketHandle);
|
||||
|
||||
impl VirtualNetwork {
|
||||
/// Create a new virtual network with the given local IP (from tailnet assignment).
|
||||
///
|
||||
/// `rx_from_wg` receives decrypted IP packets from the WireGuard tunnel.
|
||||
/// `tx_to_wg` sends IP packets back to the WireGuard tunnel for encryption.
|
||||
pub fn new(
|
||||
local_ip: IpAddress,
|
||||
prefix_len: u8,
|
||||
rx_from_wg: mpsc::Receiver<Vec<u8>>,
|
||||
tx_to_wg: mpsc::Sender<Vec<u8>>,
|
||||
) -> crate::Result<Self> {
|
||||
let mut device = ChannelDevice {
|
||||
rx: rx_from_wg,
|
||||
tx: tx_to_wg,
|
||||
mtu: 1420, // Standard WireGuard MTU
|
||||
};
|
||||
|
||||
let config = Config::new(smoltcp::wire::HardwareAddress::Ip);
|
||||
let now = Instant::from_millis(
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as i64,
|
||||
);
|
||||
|
||||
let mut iface = Interface::new(config, &mut device, now);
|
||||
iface.update_ip_addrs(|addrs| {
|
||||
addrs
|
||||
.push(IpCidr::new(local_ip, prefix_len))
|
||||
.expect("interface IP address capacity exceeded");
|
||||
});
|
||||
|
||||
let sockets = SocketSet::new(vec![]);
|
||||
|
||||
Ok(Self {
|
||||
iface,
|
||||
sockets,
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
/// Poll the network stack. Returns true if any socket state may have changed.
|
||||
pub fn poll(&mut self) -> bool {
|
||||
let now = Instant::from_millis(
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as i64,
|
||||
);
|
||||
self.iface.poll(now, &mut self.device, &mut self.sockets) == PollResult::SocketStateChanged
|
||||
}
|
||||
|
||||
/// Open a TCP connection to a remote address through the virtual network.
|
||||
pub fn tcp_connect(
|
||||
&mut self,
|
||||
remote: std::net::SocketAddr,
|
||||
) -> crate::Result<TcpSocketHandle> {
|
||||
let rx_buf = tcp::SocketBuffer::new(vec![0u8; 65536]);
|
||||
let tx_buf = tcp::SocketBuffer::new(vec![0u8; 65536]);
|
||||
let mut socket = tcp::Socket::new(rx_buf, tx_buf);
|
||||
|
||||
let remote_ep = match remote {
|
||||
std::net::SocketAddr::V4(v4) => {
|
||||
smoltcp::wire::IpEndpoint::new(
|
||||
IpAddress::Ipv4(smoltcp::wire::Ipv4Address::from(*v4.ip())),
|
||||
v4.port(),
|
||||
)
|
||||
}
|
||||
std::net::SocketAddr::V6(v6) => {
|
||||
smoltcp::wire::IpEndpoint::new(
|
||||
IpAddress::Ipv6(smoltcp::wire::Ipv6Address::from(*v6.ip())),
|
||||
v6.port(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let cx = self.iface.context();
|
||||
// Use an ephemeral local port.
|
||||
let local_port = ephemeral_port();
|
||||
socket
|
||||
.connect(cx, remote_ep, local_port)
|
||||
.map_err(|e| crate::Error::WireGuard(format!("TCP connect: {e:?}")))?;
|
||||
|
||||
let handle = self.sockets.add(socket);
|
||||
Ok(TcpSocketHandle(handle))
|
||||
}
|
||||
|
||||
/// Check if a TCP connection is established and ready.
|
||||
pub fn tcp_is_active(&self, handle: TcpSocketHandle) -> bool {
|
||||
let socket = self.sockets.get::<tcp::Socket>(handle.0);
|
||||
socket.is_active()
|
||||
}
|
||||
|
||||
/// Read data from a TCP socket.
|
||||
pub fn tcp_recv(
|
||||
&mut self,
|
||||
handle: TcpSocketHandle,
|
||||
buf: &mut [u8],
|
||||
) -> crate::Result<usize> {
|
||||
let socket = self.sockets.get_mut::<tcp::Socket>(handle.0);
|
||||
socket
|
||||
.recv_slice(buf)
|
||||
.map_err(|e| crate::Error::WireGuard(format!("TCP recv: {e:?}")))
|
||||
}
|
||||
|
||||
/// Write data to a TCP socket.
|
||||
pub fn tcp_send(
|
||||
&mut self,
|
||||
handle: TcpSocketHandle,
|
||||
data: &[u8],
|
||||
) -> crate::Result<usize> {
|
||||
let socket = self.sockets.get_mut::<tcp::Socket>(handle.0);
|
||||
socket
|
||||
.send_slice(data)
|
||||
.map_err(|e| crate::Error::WireGuard(format!("TCP send: {e:?}")))
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a local port from the IANA ephemeral range (49152..=65535).
///
/// Uses a single process-wide atomic counter (the previous doc claimed
/// "thread-local", which was inaccurate) and confines the result to the
/// ephemeral range with modulo arithmetic. The previous check-then-store
/// reset on wraparound was not atomic, so concurrent callers racing the
/// reset could observe duplicate or out-of-range values; `fetch_add` plus
/// modulo is race-free. Ports repeat after 16384 allocations, which matches
/// the original's recycling behavior.
fn ephemeral_port() -> u16 {
    use std::sync::atomic::{AtomicU16, Ordering};
    static COUNTER: AtomicU16 = AtomicU16::new(0);
    49152 + (COUNTER.fetch_add(1, Ordering::Relaxed) % 16384)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_virtual_network_creation() {
        // Channels stand in for the WireGuard tunnel on both directions.
        let (tx_to_wg, _rx) = mpsc::channel(16);
        let (_tx, rx_from_wg) = mpsc::channel(16);

        let net = VirtualNetwork::new(
            IpAddress::v4(100, 64, 0, 1),
            32,
            rx_from_wg,
            tx_to_wg,
        );
        assert!(net.is_ok());
    }

    #[test]
    fn test_poll_empty() {
        let (tx_to_wg, _rx) = mpsc::channel(16);
        let (_tx, rx_from_wg) = mpsc::channel(16);

        let mut net = VirtualNetwork::new(
            IpAddress::v4(100, 64, 0, 1),
            32,
            rx_from_wg,
            tx_to_wg,
        )
        .unwrap();

        // Polling with no packets should return false (no state changes).
        let changed = net.poll();
        assert!(!changed);
    }
}
|
||||
383
sunbeam-net/src/wg/tunnel.rs
Normal file
383
sunbeam-net/src/wg/tunnel.rs
Normal file
@@ -0,0 +1,383 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use ipnet::IpNet;
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
/// Manages WireGuard tunnels for all peers.
pub(crate) struct WgTunnel {
    /// Our node's WireGuard private key, shared across all peer tunnels.
    private_key: StaticSecret,
    /// Per-peer tunnel state, keyed by the peer's raw public key.
    peers: HashMap<[u8; 32], PeerTunnel>,
    /// Index counter for boringtun tunnel creation.
    next_index: u32,
}
|
||||
|
||||
/// Per-peer tunnel state: a boringtun session plus routing metadata.
struct PeerTunnel {
    /// boringtun noise session for this peer.
    tunn: Tunn,
    /// Direct UDP endpoint, if one is known; preferred over DERP.
    endpoint: Option<SocketAddr>,
    /// DERP relay region to fall back to when no endpoint is known.
    derp_region: Option<u16>,
    /// Networks routed to this peer (used for outbound peer selection).
    allowed_ips: Vec<IpNet>,
}
|
||||
|
||||
/// Result of encapsulating an outbound IP packet.
pub(crate) enum EncapAction {
    /// Send these bytes over UDP to the given endpoint.
    SendUdp { endpoint: SocketAddr, data: Vec<u8> },
    /// Send these bytes via DERP relay, addressed to the peer's public key.
    SendDerp { dest_key: [u8; 32], data: Vec<u8> },
    /// Nothing to send (handshake pending, no peer for the destination, etc.)
    Nothing,
}
|
||||
|
||||
/// Result of decapsulating an inbound WireGuard packet.
pub(crate) enum DecapAction {
    /// Decrypted IP packet ready for the virtual network stack.
    Packet(Vec<u8>),
    /// Need to send a response (handshake response, cookie, etc.)
    /// back to the peer over the transport it arrived on.
    Response(Vec<u8>),
    /// Nothing (keep-alive, unknown peer, decryption failure, etc.)
    Nothing,
}
|
||||
|
||||
/// A transmission produced by a timer tick, tagged with the peer it is for.
pub(crate) struct TimerAction {
    /// Public key of the peer this action belongs to.
    pub peer_key: [u8; 32],
    /// What (if anything) to transmit for that peer.
    pub action: EncapAction,
}
|
||||
|
||||
/// Size of the scratch buffer for boringtun operations.
/// 64 KiB comfortably exceeds any single WireGuard datagram.
const BUF_SIZE: usize = 65536;
|
||||
|
||||
impl WgTunnel {
|
||||
pub fn new(private_key: StaticSecret) -> Self {
|
||||
Self {
|
||||
private_key,
|
||||
peers: HashMap::new(),
|
||||
next_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the peer table from a network map.
|
||||
/// Adds new peers, removes stale ones, updates endpoints.
|
||||
pub fn update_peers(&mut self, peers: &[crate::proto::types::Node]) {
|
||||
// Collect the set of peer keys we should have.
|
||||
let mut desired_keys: HashMap<[u8; 32], &crate::proto::types::Node> = HashMap::new();
|
||||
for node in peers {
|
||||
if let Some(key) = parse_node_key(&node.key) {
|
||||
desired_keys.insert(key, node);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove peers not in the desired set.
|
||||
self.peers.retain(|k, _| desired_keys.contains_key(k));
|
||||
|
||||
// Add or update peers.
|
||||
for (key_bytes, node) in &desired_keys {
|
||||
if self.peers.contains_key(key_bytes) {
|
||||
// Update endpoint and DERP region on existing peer.
|
||||
let peer = self.peers.get_mut(key_bytes).unwrap();
|
||||
peer.endpoint = parse_first_endpoint(&node.endpoints);
|
||||
peer.derp_region = parse_derp_region(&node.derp);
|
||||
peer.allowed_ips = parse_allowed_ips(&node.allowed_ips);
|
||||
} else {
|
||||
// Create a new boringtun tunnel for this peer.
|
||||
let peer_public = PublicKey::from(*key_bytes);
|
||||
let index = self.next_index;
|
||||
self.next_index = self.next_index.wrapping_add(1);
|
||||
|
||||
let tunn = Tunn::new(
|
||||
self.private_key.clone(),
|
||||
peer_public,
|
||||
None, // no preshared key
|
||||
Some(25), // 25 second persistent keepalive
|
||||
index,
|
||||
None, // no rate limiter
|
||||
);
|
||||
|
||||
self.peers.insert(*key_bytes, PeerTunnel {
|
||||
tunn,
|
||||
endpoint: parse_first_endpoint(&node.endpoints),
|
||||
derp_region: parse_derp_region(&node.derp),
|
||||
allowed_ips: parse_allowed_ips(&node.allowed_ips),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulate an outbound IP packet for transmission.
|
||||
pub fn encapsulate(&mut self, dst_ip: IpAddr, payload: &[u8]) -> EncapAction {
|
||||
let peer_key = match self.find_peer_for_ip(dst_ip) {
|
||||
Some(k) => *k,
|
||||
None => return EncapAction::Nothing,
|
||||
};
|
||||
|
||||
let peer = match self.peers.get_mut(&peer_key) {
|
||||
Some(p) => p,
|
||||
None => return EncapAction::Nothing,
|
||||
};
|
||||
|
||||
let mut buf = vec![0u8; BUF_SIZE];
|
||||
match peer.tunn.encapsulate(payload, &mut buf) {
|
||||
TunnResult::WriteToNetwork(data) => {
|
||||
let packet = data.to_vec();
|
||||
route_packet(peer, &peer_key, packet)
|
||||
}
|
||||
TunnResult::Err(_) | TunnResult::Done => EncapAction::Nothing,
|
||||
// These shouldn't happen during encapsulate, but handle gracefully.
|
||||
TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => {
|
||||
EncapAction::Nothing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decapsulate an inbound WireGuard packet.
|
||||
pub fn decapsulate(&mut self, peer_key: &[u8; 32], packet: &[u8]) -> DecapAction {
|
||||
let peer = match self.peers.get_mut(peer_key) {
|
||||
Some(p) => p,
|
||||
None => return DecapAction::Nothing,
|
||||
};
|
||||
|
||||
let mut buf = vec![0u8; BUF_SIZE];
|
||||
let result = peer.tunn.decapsulate(None, packet, &mut buf);
|
||||
|
||||
match result {
|
||||
TunnResult::WriteToTunnelV4(data, _addr) => DecapAction::Packet(data.to_vec()),
|
||||
TunnResult::WriteToTunnelV6(data, _addr) => DecapAction::Packet(data.to_vec()),
|
||||
TunnResult::WriteToNetwork(data) => {
|
||||
// This is a handshake response or cookie that needs to be sent back.
|
||||
let response = data.to_vec();
|
||||
// Check for chained results — loop to drain.
|
||||
let mut chain_buf = vec![0u8; BUF_SIZE];
|
||||
loop {
|
||||
match peer.tunn.decapsulate(None, &[], &mut chain_buf) {
|
||||
TunnResult::WriteToNetwork(more) => {
|
||||
// Multiple chained responses; we return the first.
|
||||
let _ = more;
|
||||
}
|
||||
TunnResult::Done => break,
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
DecapAction::Response(response)
|
||||
}
|
||||
TunnResult::Err(_) | TunnResult::Done => DecapAction::Nothing,
|
||||
}
|
||||
}
|
||||
|
||||
/// Tick all tunnels for timer-based actions (keepalives, handshake retries).
|
||||
pub fn tick(&mut self) -> Vec<TimerAction> {
|
||||
let mut actions = Vec::new();
|
||||
let peer_keys: Vec<[u8; 32]> = self.peers.keys().copied().collect();
|
||||
|
||||
for peer_key in peer_keys {
|
||||
let peer = match self.peers.get_mut(&peer_key) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let mut buf = vec![0u8; BUF_SIZE];
|
||||
let result = peer.tunn.update_timers(&mut buf);
|
||||
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(data) => {
|
||||
let packet = data.to_vec();
|
||||
let action = route_packet(peer, &peer_key, packet);
|
||||
actions.push(TimerAction {
|
||||
peer_key,
|
||||
action,
|
||||
});
|
||||
}
|
||||
TunnResult::Err(_) | TunnResult::Done => {}
|
||||
TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => {}
|
||||
}
|
||||
}
|
||||
|
||||
actions
|
||||
}
|
||||
|
||||
/// Find which peer owns a given IP address.
|
||||
fn find_peer_for_ip(&self, ip: IpAddr) -> Option<&[u8; 32]> {
|
||||
for (key, peer) in &self.peers {
|
||||
for net in &peer.allowed_ips {
|
||||
if net.contains(&ip) {
|
||||
return Some(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Expose peer count for testing.
|
||||
#[cfg(test)]
|
||||
fn peer_count(&self) -> usize {
|
||||
self.peers.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Decide how to route a WireGuard packet for a given peer.
|
||||
fn route_packet(peer: &PeerTunnel, peer_key: &[u8; 32], data: Vec<u8>) -> EncapAction {
|
||||
if let Some(endpoint) = peer.endpoint {
|
||||
EncapAction::SendUdp { endpoint, data }
|
||||
} else if peer.derp_region.is_some() {
|
||||
EncapAction::SendDerp {
|
||||
dest_key: *peer_key,
|
||||
data,
|
||||
}
|
||||
} else {
|
||||
EncapAction::Nothing
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a "nodekey:<hex>" string into raw 32-byte key.
|
||||
fn parse_node_key(s: &str) -> Option<[u8; 32]> {
|
||||
let hex_str = s.strip_prefix("nodekey:")?;
|
||||
let bytes = hex_decode(hex_str)?;
|
||||
if bytes.len() != 32 {
|
||||
return None;
|
||||
}
|
||||
let mut arr = [0u8; 32];
|
||||
arr.copy_from_slice(&bytes);
|
||||
Some(arr)
|
||||
}
|
||||
|
||||
/// Parse the first endpoint from a list of "ip:port" strings.
///
/// Returns `None` on an empty list or an unparsable first entry; later
/// entries are never consulted.
fn parse_first_endpoint(endpoints: &[String]) -> Option<SocketAddr> {
    let first = endpoints.first()?;
    first.parse().ok()
}
|
||||
|
||||
/// Parse DERP region from a DERP string like "127.3.3.40:<region_id>".
///
/// Behaviorally equivalent to splitting on ':' and requiring exactly two
/// parts: with more than one colon the remainder contains a ':', which can
/// never parse as a u16, so the result is still `None`.
fn parse_derp_region(derp: &str) -> Option<u16> {
    // Tailscale/Headscale DERP format: "127.3.3.40:<region_id>"
    let (_, region) = derp.split_once(':')?;
    region.parse().ok()
}
|
||||
|
||||
/// Parse allowed IP strings into IpNet values.
|
||||
fn parse_allowed_ips(ips: &[String]) -> Vec<IpNet> {
|
||||
ips.iter().filter_map(|s| s.parse().ok()).collect()
|
||||
}
|
||||
|
||||
/// Simple hex decoder.
///
/// Returns `None` for odd-length input or any non-hex character.
///
/// Fixed: the previous implementation sliced the `str` at byte offsets
/// (`&s[i..i + 2]`), which panics when the input contains a multi-byte
/// UTF-8 character straddling the slice boundary. Operating on the raw
/// bytes avoids the char-boundary panic; non-ASCII bytes simply fail the
/// hex-digit check and yield `None`.
fn hex_decode(s: &str) -> Option<Vec<u8>> {
    if s.len() % 2 != 0 {
        return None;
    }
    let mut out = Vec::with_capacity(s.len() / 2);
    for pair in s.as_bytes().chunks_exact(2) {
        let hi = (pair[0] as char).to_digit(16)?;
        let lo = (pair[1] as char).to_digit(16)?;
        out.push(((hi << 4) | lo) as u8);
    }
    Some(out)
}
|
||||
|
||||
/// Simple hex encoder (lowercase, two digits per byte).
///
/// Writes into one preallocated String instead of allocating a temporary
/// String per byte via `format!` + collect.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write;
    let mut out = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Writing to a String cannot fail.
        let _ = write!(out, "{b:02x}");
    }
    out
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;
    use super::*;
    use crate::proto::types::{HostInfo, Node};

    /// Minimal HostInfo fixture for building test nodes.
    fn test_hostinfo() -> HostInfo {
        HostInfo {
            go_arch: "arm64".into(),
            go_os: "linux".into(),
            go_version: "sunbeam-net/0.1".into(),
            hostname: "test".into(),
            os: "linux".into(),
            os_version: "6.1".into(),
            device_model: None,
            frontend_log_id: None,
            backend_log_id: None,
        }
    }

    /// Build a netmap Node for a peer with the given key and allowed_ips.
    fn make_peer_node(key_bytes: &[u8; 32], allowed_ips: Vec<&str>) -> Node {
        Node {
            id: 1,
            key: format!("nodekey:{}", hex_encode(key_bytes)),
            disco_key: "discokey:0000000000000000000000000000000000000000000000000000000000000000".into(),
            addresses: vec!["100.64.0.2/32".into()],
            allowed_ips: allowed_ips.into_iter().map(String::from).collect(),
            endpoints: vec!["1.2.3.4:41641".into()],
            derp: "127.3.3.40:1".into(),
            hostinfo: test_hostinfo(),
            name: "peer.example.com".into(),
            online: Some(true),
            machine_authorized: true,
        }
    }

    /// Generate a random keypair; returns (public key bytes, private key).
    fn generate_key() -> ([u8; 32], StaticSecret) {
        let secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
        let public = PublicKey::from(&secret);
        (*public.as_bytes(), secret)
    }

    #[test]
    fn test_new_tunnel() {
        let (_, secret) = generate_key();
        let tunnel = WgTunnel::new(secret);
        assert_eq!(tunnel.peer_count(), 0);
    }

    #[test]
    fn test_update_peers_adds_peer() {
        let (_, my_secret) = generate_key();
        let mut tunnel = WgTunnel::new(my_secret);

        let (peer_pub, _peer_secret) = generate_key();
        let node = make_peer_node(&peer_pub, vec!["100.64.0.2/32"]);

        tunnel.update_peers(&[node]);
        assert_eq!(tunnel.peer_count(), 1);
        assert!(tunnel.peers.contains_key(&peer_pub));
    }

    #[test]
    fn test_update_peers_removes_stale() {
        let (_, my_secret) = generate_key();
        let mut tunnel = WgTunnel::new(my_secret);

        let (peer_pub, _) = generate_key();
        let node = make_peer_node(&peer_pub, vec!["100.64.0.2/32"]);
        tunnel.update_peers(&[node]);
        assert_eq!(tunnel.peer_count(), 1);

        // Update with empty list — stale peer should be removed.
        tunnel.update_peers(&[]);
        assert_eq!(tunnel.peer_count(), 0);
    }

    #[test]
    fn test_find_peer_for_ip() {
        let (_, my_secret) = generate_key();
        let mut tunnel = WgTunnel::new(my_secret);

        let (peer_pub, _) = generate_key();
        let node = make_peer_node(&peer_pub, vec!["100.64.0.0/24"]);
        tunnel.update_peers(&[node]);

        let found = tunnel.find_peer_for_ip(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 5)));
        assert_eq!(found, Some(&peer_pub));
    }

    #[test]
    fn test_find_peer_no_match() {
        let (_, my_secret) = generate_key();
        let mut tunnel = WgTunnel::new(my_secret);

        let (peer_pub, _) = generate_key();
        let node = make_peer_node(&peer_pub, vec!["100.64.0.0/24"]);
        tunnel.update_peers(&[node]);

        // IP outside the allowed range.
        let found = tunnel.find_peer_for_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(found.is_none());
    }
}
|
||||
Reference in New Issue
Block a user