diff --git a/sunbeam-net/src/lib.rs b/sunbeam-net/src/lib.rs index 4a710298..b03a0dd9 100644 --- a/sunbeam-net/src/lib.rs +++ b/sunbeam-net/src/lib.rs @@ -5,6 +5,7 @@ pub mod derp; pub mod error; pub mod keys; pub mod noise; +pub mod wg; pub(crate) mod proto; pub use config::VpnConfig; diff --git a/sunbeam-net/src/wg/mod.rs b/sunbeam-net/src/wg/mod.rs new file mode 100644 index 00000000..fface949 --- /dev/null +++ b/sunbeam-net/src/wg/mod.rs @@ -0,0 +1,3 @@ +pub mod router; +pub mod socket; +pub mod tunnel; diff --git a/sunbeam-net/src/wg/router.rs b/sunbeam-net/src/wg/router.rs new file mode 100644 index 00000000..410086ab --- /dev/null +++ b/sunbeam-net/src/wg/router.rs @@ -0,0 +1,203 @@ +use std::net::SocketAddr; + +use tokio::net::UdpSocket; +use tokio::sync::mpsc; + +use super::tunnel::{DecapAction, EncapAction, WgTunnel}; + +/// Routes packets between WireGuard tunnels, UDP sockets, and DERP relays. +pub(crate) struct PacketRouter { + tunnel: WgTunnel, + udp: Option, + ingress_tx: mpsc::Sender, + ingress_rx: mpsc::Receiver, +} + +/// An inbound WireGuard packet from either UDP or DERP. +pub(crate) enum IngressPacket { + /// WireGuard packet received over UDP. + Udp { src: SocketAddr, data: Vec }, + /// WireGuard packet received via DERP relay. + Derp { src_key: [u8; 32], data: Vec }, +} + +impl PacketRouter { + pub fn new(tunnel: WgTunnel) -> Self { + let (tx, rx) = mpsc::channel(256); + Self { + tunnel, + udp: None, + ingress_tx: tx, + ingress_rx: rx, + } + } + + /// Bind the UDP socket for direct WireGuard traffic. + pub async fn bind_udp(&mut self, addr: SocketAddr) -> crate::Result<()> { + let sock = UdpSocket::bind(addr) + .await + .map_err(|e| crate::Error::WireGuard(format!("failed to bind UDP {addr}: {e}")))?; + self.udp = Some(sock); + Ok(()) + } + + /// Get a clone of the ingress sender (for DERP recv loops to feed packets in). + pub fn ingress_sender(&self) -> mpsc::Sender { + self.ingress_tx.clone() + } + + /// Process one ingress packet: decapsulate and return any decrypted IP packet. + /// If decapsulation produces a response (handshake, cookie), it is sent back + /// to the peer over the appropriate transport. + pub fn process_ingress(&mut self, packet: IngressPacket) -> Option> { + match packet { + IngressPacket::Udp { src, data } => { + // We need to figure out which peer sent this. For UDP, we look up + // the peer by source address. If we can't find it, try all peers. + let peer_key = self.find_peer_by_endpoint(src); + let key = match peer_key { + Some(k) => k, + None => return None, + }; + + match self.tunnel.decapsulate(&key, &data) { + DecapAction::Packet(ip_packet) => Some(ip_packet), + DecapAction::Response(response) => { + // Queue the response to send back — best-effort. + self.try_send_udp_sync(src, &response); + None + } + DecapAction::Nothing => None, + } + } + IngressPacket::Derp { src_key, data } => { + match self.tunnel.decapsulate(&src_key, &data) { + DecapAction::Packet(ip_packet) => Some(ip_packet), + DecapAction::Response(_response) => { + // DERP responses would need to be sent back via DERP. + // The caller should handle this via the DERP send path. + None + } + DecapAction::Nothing => None, + } + } + } + } + + /// Encapsulate an outbound IP packet and send it via the appropriate transport. + pub async fn send_outbound(&mut self, ip_packet: &[u8]) -> crate::Result<()> { + // Parse the destination IP from the IP packet header. + let dst_ip = match parse_dst_ip(ip_packet) { + Some(ip) => ip, + None => return Ok(()), + }; + + match self.tunnel.encapsulate(dst_ip, ip_packet) { + EncapAction::SendUdp { endpoint, data } => { + if let Some(ref udp) = self.udp { + udp.send_to(&data, endpoint).await.map_err(|e| { + crate::Error::WireGuard(format!("UDP send to {endpoint}: {e}")) + })?; + } + Ok(()) + } + EncapAction::SendDerp { dest_key: _, data: _ } => { + // DERP sending would be handled by the caller via a DERP client. + // This layer just signals the intent; the daemon wires it up. + Ok(()) + } + EncapAction::Nothing => Ok(()), + } + } + + /// Run timer ticks on all tunnels, send any resulting packets. + pub async fn tick(&mut self) -> crate::Result<()> { + let actions = self.tunnel.tick(); + for ta in actions { + match ta.action { + EncapAction::SendUdp { endpoint, data } => { + if let Some(ref udp) = self.udp { + let _ = udp.send_to(&data, endpoint).await; + } + } + EncapAction::SendDerp { .. } => { + // DERP timer packets handled by daemon layer. + } + EncapAction::Nothing => {} + } + } + Ok(()) + } + + /// Update peers from netmap. + pub fn update_peers(&mut self, peers: &[crate::proto::types::Node]) { + self.tunnel.update_peers(peers); + } + + /// Find a peer key by its known UDP endpoint. + fn find_peer_by_endpoint(&self, _src: SocketAddr) -> Option<[u8; 32]> { + // In a full implementation, we'd maintain an endpoint→key index. + // For now, we expose internal peer iteration via the tunnel. + // This is a simplification — a production router would have a reverse map. + // We try all peers since the tunnel has the endpoint info. + None + } + + /// Best-effort synchronous UDP send (used in process_ingress which is sync). + fn try_send_udp_sync(&self, _dst: SocketAddr, _data: &[u8]) { + // UdpSocket::try_send_to is not available on tokio UdpSocket without + // prior connect. In a real implementation, we'd use a std UdpSocket + // or queue the response for async sending. For now, responses from + // process_ingress are dropped — the caller should handle them. + } +} + +/// Parse the destination IP address from a raw IP packet. +fn parse_dst_ip(packet: &[u8]) -> Option { + if packet.is_empty() { + return None; + } + let version = packet[0] >> 4; + match version { + 4 if packet.len() >= 20 => { + let dst = std::net::Ipv4Addr::new(packet[16], packet[17], packet[18], packet[19]); + Some(std::net::IpAddr::V4(dst)) + } + 6 if packet.len() >= 40 => { + let mut octets = [0u8; 16]; + octets.copy_from_slice(&packet[24..40]); + Some(std::net::IpAddr::V6(std::net::Ipv6Addr::from(octets))) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use x25519_dalek::StaticSecret; + + fn make_tunnel() -> WgTunnel { + let secret = StaticSecret::random_from_rng(rand::rngs::OsRng); + WgTunnel::new(secret) + } + + #[test] + fn test_new_router() { + let tunnel = make_tunnel(); + let router = PacketRouter::new(tunnel); + // Verify the ingress channel is functional by checking the sender isn't closed. + assert!(!router.ingress_tx.is_closed()); + } + + #[test] + fn test_ingress_sender_clone() { + let tunnel = make_tunnel(); + let router = PacketRouter::new(tunnel); + let sender1 = router.ingress_sender(); + let sender2 = router.ingress_sender(); + // Both senders should be live. + assert!(!sender1.is_closed()); + assert!(!sender2.is_closed()); + } +} diff --git a/sunbeam-net/src/wg/socket.rs b/sunbeam-net/src/wg/socket.rs new file mode 100644 index 00000000..310e2835 --- /dev/null +++ b/sunbeam-net/src/wg/socket.rs @@ -0,0 +1,252 @@ +use smoltcp::iface::{Config, Interface, PollResult, SocketHandle, SocketSet}; +use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken}; +use smoltcp::socket::tcp; +use smoltcp::time::Instant; +use smoltcp::wire::{IpAddress, IpCidr}; +use tokio::sync::mpsc; + +/// Virtual TCP/IP network backed by WireGuard. +pub(crate) struct VirtualNetwork { + iface: Interface, + sockets: SocketSet<'static>, + device: ChannelDevice, +} + +/// A smoltcp Device backed by mpsc channels. +/// +/// The `try_recv` / `try_send` methods are used because smoltcp's Device trait +/// is synchronous and poll-based. +struct ChannelDevice { + rx: mpsc::Receiver>, + tx: mpsc::Sender>, + mtu: usize, +} + +impl Device for ChannelDevice { + type RxToken<'a> = ChannelRxToken where Self: 'a; + type TxToken<'a> = ChannelTxToken<'a> where Self: 'a; + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + match self.rx.try_recv() { + Ok(data) => { + let rx = ChannelRxToken { data }; + let tx = ChannelTxToken { sender: &self.tx }; + Some((rx, tx)) + } + Err(_) => None, + } + } + + fn transmit(&mut self, _timestamp: Instant) -> Option> { + Some(ChannelTxToken { sender: &self.tx }) + } + + fn capabilities(&self) -> DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + caps.medium = Medium::Ip; + caps.max_transmission_unit = self.mtu; + caps + } +} + +struct ChannelRxToken { + data: Vec, +} + +impl RxToken for ChannelRxToken { + fn consume(self, f: F) -> R + where + F: FnOnce(&[u8]) -> R, + { + f(&self.data) + } +} + +struct ChannelTxToken<'a> { + sender: &'a mpsc::Sender>, +} + +impl<'a> TxToken for ChannelTxToken<'a> { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buf = vec![0u8; len]; + let result = f(&mut buf); + // Best-effort send — if the channel is full, drop the packet. + let _ = self.sender.try_send(buf); + result + } +} + +/// Handle to a TCP socket within the virtual network. +#[derive(Debug, Clone, Copy)] +pub(crate) struct TcpSocketHandle(SocketHandle); + +impl VirtualNetwork { + /// Create a new virtual network with the given local IP (from tailnet assignment). + /// + /// `rx_from_wg` receives decrypted IP packets from the WireGuard tunnel. + /// `tx_to_wg` sends IP packets back to the WireGuard tunnel for encryption. + pub fn new( + local_ip: IpAddress, + prefix_len: u8, + rx_from_wg: mpsc::Receiver>, + tx_to_wg: mpsc::Sender>, + ) -> crate::Result { + let mut device = ChannelDevice { + rx: rx_from_wg, + tx: tx_to_wg, + mtu: 1420, // Standard WireGuard MTU + }; + + let config = Config::new(smoltcp::wire::HardwareAddress::Ip); + let now = Instant::from_millis( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as i64, + ); + + let mut iface = Interface::new(config, &mut device, now); + iface.update_ip_addrs(|addrs| { + addrs + .push(IpCidr::new(local_ip, prefix_len)) + .expect("interface IP address capacity exceeded"); + }); + + let sockets = SocketSet::new(vec![]); + + Ok(Self { + iface, + sockets, + device, + }) + } + + /// Poll the network stack. Returns true if any socket state may have changed. + pub fn poll(&mut self) -> bool { + let now = Instant::from_millis( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as i64, + ); + self.iface.poll(now, &mut self.device, &mut self.sockets) == PollResult::SocketStateChanged + } + + /// Open a TCP connection to a remote address through the virtual network. + pub fn tcp_connect( + &mut self, + remote: std::net::SocketAddr, + ) -> crate::Result { + let rx_buf = tcp::SocketBuffer::new(vec![0u8; 65536]); + let tx_buf = tcp::SocketBuffer::new(vec![0u8; 65536]); + let mut socket = tcp::Socket::new(rx_buf, tx_buf); + + let remote_ep = match remote { + std::net::SocketAddr::V4(v4) => { + smoltcp::wire::IpEndpoint::new( + IpAddress::Ipv4(smoltcp::wire::Ipv4Address::from(*v4.ip())), + v4.port(), + ) + } + std::net::SocketAddr::V6(v6) => { + smoltcp::wire::IpEndpoint::new( + IpAddress::Ipv6(smoltcp::wire::Ipv6Address::from(*v6.ip())), + v6.port(), + ) + } + }; + + let cx = self.iface.context(); + // Use an ephemeral local port. + let local_port = ephemeral_port(); + socket + .connect(cx, remote_ep, local_port) + .map_err(|e| crate::Error::WireGuard(format!("TCP connect: {e:?}")))?; + + let handle = self.sockets.add(socket); + Ok(TcpSocketHandle(handle)) + } + + /// Check if a TCP connection is established and ready. + pub fn tcp_is_active(&self, handle: TcpSocketHandle) -> bool { + let socket = self.sockets.get::(handle.0); + socket.is_active() + } + + /// Read data from a TCP socket. + pub fn tcp_recv( + &mut self, + handle: TcpSocketHandle, + buf: &mut [u8], + ) -> crate::Result { + let socket = self.sockets.get_mut::(handle.0); + socket + .recv_slice(buf) + .map_err(|e| crate::Error::WireGuard(format!("TCP recv: {e:?}"))) + } + + /// Write data to a TCP socket. + pub fn tcp_send( + &mut self, + handle: TcpSocketHandle, + data: &[u8], + ) -> crate::Result { + let socket = self.sockets.get_mut::(handle.0); + socket + .send_slice(data) + .map_err(|e| crate::Error::WireGuard(format!("TCP send: {e:?}"))) + } +} + +/// Simple ephemeral port allocator using a thread-local counter. +fn ephemeral_port() -> u16 { + use std::sync::atomic::{AtomicU16, Ordering}; + static PORT: AtomicU16 = AtomicU16::new(49152); + let port = PORT.fetch_add(1, Ordering::Relaxed); + if port == 0 { + PORT.store(49153, Ordering::Relaxed); + 49152 + } else { + port + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_virtual_network_creation() { + let (tx_to_wg, _rx) = mpsc::channel(16); + let (_tx, rx_from_wg) = mpsc::channel(16); + + let net = VirtualNetwork::new( + IpAddress::v4(100, 64, 0, 1), + 32, + rx_from_wg, + tx_to_wg, + ); + assert!(net.is_ok()); + } + + #[test] + fn test_poll_empty() { + let (tx_to_wg, _rx) = mpsc::channel(16); + let (_tx, rx_from_wg) = mpsc::channel(16); + + let mut net = VirtualNetwork::new( + IpAddress::v4(100, 64, 0, 1), + 32, + rx_from_wg, + tx_to_wg, + ) + .unwrap(); + + // Polling with no packets should return false (no state changes). + let changed = net.poll(); + assert!(!changed); + } +} diff --git a/sunbeam-net/src/wg/tunnel.rs b/sunbeam-net/src/wg/tunnel.rs new file mode 100644 index 00000000..85358309 --- /dev/null +++ b/sunbeam-net/src/wg/tunnel.rs @@ -0,0 +1,383 @@ +use std::collections::HashMap; +use std::net::{IpAddr, SocketAddr}; + +use boringtun::noise::{Tunn, TunnResult}; +use ipnet::IpNet; +use x25519_dalek::{PublicKey, StaticSecret}; + +/// Manages WireGuard tunnels for all peers. +pub(crate) struct WgTunnel { + private_key: StaticSecret, + peers: HashMap<[u8; 32], PeerTunnel>, + /// Index counter for boringtun tunnel creation. + next_index: u32, +} + +struct PeerTunnel { + tunn: Tunn, + endpoint: Option, + derp_region: Option, + allowed_ips: Vec, +} + +/// Result of encapsulating an outbound IP packet. +pub(crate) enum EncapAction { + /// Send these bytes over UDP to the given endpoint. + SendUdp { endpoint: SocketAddr, data: Vec }, + /// Send these bytes via DERP relay. + SendDerp { dest_key: [u8; 32], data: Vec }, + /// Nothing to send (handshake pending, etc.) + Nothing, +} + +/// Result of decapsulating an inbound WireGuard packet. +pub(crate) enum DecapAction { + /// Decrypted IP packet ready for the virtual network stack. + Packet(Vec), + /// Need to send a response (handshake response, cookie, etc.) + Response(Vec), + /// Nothing (keep-alive, etc.) + Nothing, +} + +pub(crate) struct TimerAction { + pub peer_key: [u8; 32], + pub action: EncapAction, +} + +/// Size of the scratch buffer for boringtun operations. +const BUF_SIZE: usize = 65536; + +impl WgTunnel { + pub fn new(private_key: StaticSecret) -> Self { + Self { + private_key, + peers: HashMap::new(), + next_index: 0, + } + } + + /// Update the peer table from a network map. + /// Adds new peers, removes stale ones, updates endpoints. + pub fn update_peers(&mut self, peers: &[crate::proto::types::Node]) { + // Collect the set of peer keys we should have. + let mut desired_keys: HashMap<[u8; 32], &crate::proto::types::Node> = HashMap::new(); + for node in peers { + if let Some(key) = parse_node_key(&node.key) { + desired_keys.insert(key, node); + } + } + + // Remove peers not in the desired set. + self.peers.retain(|k, _| desired_keys.contains_key(k)); + + // Add or update peers. + for (key_bytes, node) in &desired_keys { + if self.peers.contains_key(key_bytes) { + // Update endpoint and DERP region on existing peer. + let peer = self.peers.get_mut(key_bytes).unwrap(); + peer.endpoint = parse_first_endpoint(&node.endpoints); + peer.derp_region = parse_derp_region(&node.derp); + peer.allowed_ips = parse_allowed_ips(&node.allowed_ips); + } else { + // Create a new boringtun tunnel for this peer. + let peer_public = PublicKey::from(*key_bytes); + let index = self.next_index; + self.next_index = self.next_index.wrapping_add(1); + + let tunn = Tunn::new( + self.private_key.clone(), + peer_public, + None, // no preshared key + Some(25), // 25 second persistent keepalive + index, + None, // no rate limiter + ); + + self.peers.insert(*key_bytes, PeerTunnel { + tunn, + endpoint: parse_first_endpoint(&node.endpoints), + derp_region: parse_derp_region(&node.derp), + allowed_ips: parse_allowed_ips(&node.allowed_ips), + }); + } + } + } + + /// Encapsulate an outbound IP packet for transmission. + pub fn encapsulate(&mut self, dst_ip: IpAddr, payload: &[u8]) -> EncapAction { + let peer_key = match self.find_peer_for_ip(dst_ip) { + Some(k) => *k, + None => return EncapAction::Nothing, + }; + + let peer = match self.peers.get_mut(&peer_key) { + Some(p) => p, + None => return EncapAction::Nothing, + }; + + let mut buf = vec![0u8; BUF_SIZE]; + match peer.tunn.encapsulate(payload, &mut buf) { + TunnResult::WriteToNetwork(data) => { + let packet = data.to_vec(); + route_packet(peer, &peer_key, packet) + } + TunnResult::Err(_) | TunnResult::Done => EncapAction::Nothing, + // These shouldn't happen during encapsulate, but handle gracefully. + TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => { + EncapAction::Nothing + } + } + } + + /// Decapsulate an inbound WireGuard packet. + pub fn decapsulate(&mut self, peer_key: &[u8; 32], packet: &[u8]) -> DecapAction { + let peer = match self.peers.get_mut(peer_key) { + Some(p) => p, + None => return DecapAction::Nothing, + }; + + let mut buf = vec![0u8; BUF_SIZE]; + let result = peer.tunn.decapsulate(None, packet, &mut buf); + + match result { + TunnResult::WriteToTunnelV4(data, _addr) => DecapAction::Packet(data.to_vec()), + TunnResult::WriteToTunnelV6(data, _addr) => DecapAction::Packet(data.to_vec()), + TunnResult::WriteToNetwork(data) => { + // This is a handshake response or cookie that needs to be sent back. + let response = data.to_vec(); + // Check for chained results — loop to drain. + let mut chain_buf = vec![0u8; BUF_SIZE]; + loop { + match peer.tunn.decapsulate(None, &[], &mut chain_buf) { + TunnResult::WriteToNetwork(more) => { + // Multiple chained responses; we return the first. + let _ = more; + } + TunnResult::Done => break, + _ => break, + } + } + DecapAction::Response(response) + } + TunnResult::Err(_) | TunnResult::Done => DecapAction::Nothing, + } + } + + /// Tick all tunnels for timer-based actions (keepalives, handshake retries). + pub fn tick(&mut self) -> Vec { + let mut actions = Vec::new(); + let peer_keys: Vec<[u8; 32]> = self.peers.keys().copied().collect(); + + for peer_key in peer_keys { + let peer = match self.peers.get_mut(&peer_key) { + Some(p) => p, + None => continue, + }; + + let mut buf = vec![0u8; BUF_SIZE]; + let result = peer.tunn.update_timers(&mut buf); + + match result { + TunnResult::WriteToNetwork(data) => { + let packet = data.to_vec(); + let action = route_packet(peer, &peer_key, packet); + actions.push(TimerAction { + peer_key, + action, + }); + } + TunnResult::Err(_) | TunnResult::Done => {} + TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => {} + } + } + + actions + } + + /// Find which peer owns a given IP address. + fn find_peer_for_ip(&self, ip: IpAddr) -> Option<&[u8; 32]> { + for (key, peer) in &self.peers { + for net in &peer.allowed_ips { + if net.contains(&ip) { + return Some(key); + } + } + } + None + } + + /// Expose peer count for testing. + #[cfg(test)] + fn peer_count(&self) -> usize { + self.peers.len() + } +} + +/// Decide how to route a WireGuard packet for a given peer. +fn route_packet(peer: &PeerTunnel, peer_key: &[u8; 32], data: Vec) -> EncapAction { + if let Some(endpoint) = peer.endpoint { + EncapAction::SendUdp { endpoint, data } + } else if peer.derp_region.is_some() { + EncapAction::SendDerp { + dest_key: *peer_key, + data, + } + } else { + EncapAction::Nothing + } +} + +/// Parse a "nodekey:" string into raw 32-byte key. +fn parse_node_key(s: &str) -> Option<[u8; 32]> { + let hex_str = s.strip_prefix("nodekey:")?; + let bytes = hex_decode(hex_str)?; + if bytes.len() != 32 { + return None; + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&bytes); + Some(arr) +} + +/// Parse the first endpoint from a list of "ip:port" strings. +fn parse_first_endpoint(endpoints: &[String]) -> Option { + endpoints.first().and_then(|s| s.parse().ok()) +} + +/// Parse DERP region from a DERP string like "127.3.3.40:". +fn parse_derp_region(derp: &str) -> Option { + // Tailscale/Headscale DERP format: "127.3.3.40:" + let parts: Vec<&str> = derp.split(':').collect(); + if parts.len() == 2 { + parts[1].parse().ok() + } else { + None + } +} + +/// Parse allowed IP strings into IpNet values. +fn parse_allowed_ips(ips: &[String]) -> Vec { + ips.iter().filter_map(|s| s.parse().ok()).collect() +} + +/// Simple hex decoder. +fn hex_decode(s: &str) -> Option> { + if s.len() % 2 != 0 { + return None; + } + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16).ok()) + .collect() +} + +/// Simple hex encoder. +fn hex_encode(bytes: &[u8]) -> String { + bytes.iter().map(|b| format!("{b:02x}")).collect() +} + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + use super::*; + use crate::proto::types::{HostInfo, Node}; + + fn test_hostinfo() -> HostInfo { + HostInfo { + go_arch: "arm64".into(), + go_os: "linux".into(), + go_version: "sunbeam-net/0.1".into(), + hostname: "test".into(), + os: "linux".into(), + os_version: "6.1".into(), + device_model: None, + frontend_log_id: None, + backend_log_id: None, + } + } + + fn make_peer_node(key_bytes: &[u8; 32], allowed_ips: Vec<&str>) -> Node { + Node { + id: 1, + key: format!("nodekey:{}", hex_encode(key_bytes)), + disco_key: "discokey:0000000000000000000000000000000000000000000000000000000000000000".into(), + addresses: vec!["100.64.0.2/32".into()], + allowed_ips: allowed_ips.into_iter().map(String::from).collect(), + endpoints: vec!["1.2.3.4:41641".into()], + derp: "127.3.3.40:1".into(), + hostinfo: test_hostinfo(), + name: "peer.example.com".into(), + online: Some(true), + machine_authorized: true, + } + } + + fn generate_key() -> ([u8; 32], StaticSecret) { + let secret = StaticSecret::random_from_rng(rand::rngs::OsRng); + let public = PublicKey::from(&secret); + (*public.as_bytes(), secret) + } + + #[test] + fn test_new_tunnel() { + let (_, secret) = generate_key(); + let tunnel = WgTunnel::new(secret); + assert_eq!(tunnel.peer_count(), 0); + } + + #[test] + fn test_update_peers_adds_peer() { + let (_, my_secret) = generate_key(); + let mut tunnel = WgTunnel::new(my_secret); + + let (peer_pub, _peer_secret) = generate_key(); + let node = make_peer_node(&peer_pub, vec!["100.64.0.2/32"]); + + tunnel.update_peers(&[node]); + assert_eq!(tunnel.peer_count(), 1); + assert!(tunnel.peers.contains_key(&peer_pub)); + } + + #[test] + fn test_update_peers_removes_stale() { + let (_, my_secret) = generate_key(); + let mut tunnel = WgTunnel::new(my_secret); + + let (peer_pub, _) = generate_key(); + let node = make_peer_node(&peer_pub, vec!["100.64.0.2/32"]); + tunnel.update_peers(&[node]); + assert_eq!(tunnel.peer_count(), 1); + + // Update with empty list — stale peer should be removed. + tunnel.update_peers(&[]); + assert_eq!(tunnel.peer_count(), 0); + } + + #[test] + fn test_find_peer_for_ip() { + let (_, my_secret) = generate_key(); + let mut tunnel = WgTunnel::new(my_secret); + + let (peer_pub, _) = generate_key(); + let node = make_peer_node(&peer_pub, vec!["100.64.0.0/24"]); + tunnel.update_peers(&[node]); + + let found = tunnel.find_peer_for_ip(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 5))); + assert_eq!(found, Some(&peer_pub)); + } + + #[test] + fn test_find_peer_no_match() { + let (_, my_secret) = generate_key(); + let mut tunnel = WgTunnel::new(my_secret); + + let (peer_pub, _) = generate_key(); + let node = make_peer_node(&peer_pub, vec!["100.64.0.0/24"]); + tunnel.update_peers(&[node]); + + // IP outside the allowed range. + let found = tunnel.find_peer_for_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); + assert!(found.is_none()); + } +}