diff --git a/sunbeam-net/src/lib.rs b/sunbeam-net/src/lib.rs index 0494032c..27d261cb 100644 --- a/sunbeam-net/src/lib.rs +++ b/sunbeam-net/src/lib.rs @@ -6,6 +6,7 @@ pub mod derp; pub mod error; pub mod keys; pub mod noise; +pub mod proxy; pub mod wg; pub(crate) mod proto; diff --git a/sunbeam-net/src/proxy/engine.rs b/sunbeam-net/src/proxy/engine.rs new file mode 100644 index 00000000..ae6c426c --- /dev/null +++ b/sunbeam-net/src/proxy/engine.rs @@ -0,0 +1,255 @@ +//! NetworkEngine: bridges async I/O with smoltcp's synchronous polling. +//! +//! The engine owns the virtual TCP/IP stack (smoltcp) and manages proxy +//! connections that tunnel through WireGuard. It runs as a single async +//! task with a poll loop. + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::Duration; + +use smoltcp::wire::IpAddress; +use tokio::net::TcpStream; +use tokio::sync::mpsc; + +use crate::wg::socket::VirtualNetwork; + +/// Commands sent to the engine from the proxy listener. +pub(crate) enum EngineCommand { + /// A new local TCP connection to proxy through the VPN. + NewConnection { + local: TcpStream, + remote: SocketAddr, + }, +} + +/// The network engine bridges local TCP connections through the smoltcp +/// virtual network and WireGuard tunnel. +pub(crate) struct NetworkEngine { + vnet: VirtualNetwork, + /// Receive commands (new connections). + cmd_rx: mpsc::Receiver, + /// IP packets coming from WireGuard (decrypted) → smoltcp device. + _wg_to_smoltcp_tx: mpsc::Sender>, + /// IP packets from smoltcp device → to be WireGuard encrypted. + smoltcp_to_wg_rx: mpsc::Receiver>, + /// Send encrypted WG IP packets outward (to the daemon for routing). + outbound_tx: mpsc::Sender>, + /// Active proxy connections. + connections: HashMap, + next_id: u64, +} + +struct ProxyConnection { + local: TcpStream, + handle: crate::wg::socket::TcpSocketHandle, + /// Buffer for data read from local TCP, waiting to be sent to smoltcp. + local_buf: Vec, + /// Buffer for data read from smoltcp, waiting to be written to local TCP. + remote_buf: Vec, + /// Whether the local read side is done (EOF or error). + local_read_done: bool, + /// Whether the remote (smoltcp) side is done. + remote_done: bool, +} + +/// Channels for wiring the engine to the WireGuard layer. +pub(crate) struct EngineChannels { + /// Send IP packets into the engine (from WG decap). + pub wg_to_engine_tx: mpsc::Sender>, + /// Receive IP packets from the engine (for WG encap). + pub engine_to_wg_rx: mpsc::Receiver>, + /// Send commands to the engine. + pub cmd_tx: mpsc::Sender, +} + +impl NetworkEngine { + /// Create a new engine with the given local VPN IP address. + /// + /// Returns the engine and a set of channels for communicating with it. + pub fn new( + local_ip: IpAddress, + prefix_len: u8, + ) -> crate::Result<(Self, EngineChannels)> { + // Channels between WG and smoltcp device + let (wg_to_smoltcp_tx, wg_to_smoltcp_rx) = mpsc::channel::>(256); + let (smoltcp_to_wg_tx, smoltcp_to_wg_rx) = mpsc::channel::>(256); + + // Channel for outbound IP packets (engine → daemon for WG encap) + let (outbound_tx, engine_to_wg_rx) = mpsc::channel::>(256); + + // Command channel + let (cmd_tx, cmd_rx) = mpsc::channel(32); + + let vnet = VirtualNetwork::new( + local_ip, + prefix_len, + wg_to_smoltcp_rx, + smoltcp_to_wg_tx, + )?; + + let engine = Self { + vnet, + cmd_rx, + _wg_to_smoltcp_tx: wg_to_smoltcp_tx.clone(), + smoltcp_to_wg_rx, + outbound_tx, + connections: HashMap::new(), + next_id: 0, + }; + + let channels = EngineChannels { + wg_to_engine_tx: wg_to_smoltcp_tx, + engine_to_wg_rx, + cmd_tx, + }; + + Ok((engine, channels)) + } + + /// Run the engine poll loop. This should be spawned as a tokio task. + pub async fn run(mut self, cancel: tokio_util::sync::CancellationToken) { + let mut poll_interval = tokio::time::interval(Duration::from_millis(5)); + poll_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + + loop { + tokio::select! { + _ = cancel.cancelled() => { + tracing::debug!("network engine shutting down"); + return; + } + cmd = self.cmd_rx.recv() => { + match cmd { + Some(EngineCommand::NewConnection { local, remote }) => { + self.add_connection(local, remote); + } + None => return, // command channel closed + } + } + _ = poll_interval.tick() => { + self.poll_cycle().await; + } + } + } + } + + fn add_connection(&mut self, local: TcpStream, remote: SocketAddr) { + match self.vnet.tcp_connect(remote) { + Ok(handle) => { + let id = self.next_id; + self.next_id += 1; + tracing::debug!("proxy connection {id} → {remote}"); + self.connections.insert(id, ProxyConnection { + local, + handle, + local_buf: Vec::with_capacity(8192), + remote_buf: Vec::with_capacity(8192), + local_read_done: false, + remote_done: false, + }); + } + Err(e) => { + tracing::warn!("failed to open virtual TCP to {remote}: {e}"); + } + } + } + + async fn poll_cycle(&mut self) { + // 1. Forward outbound IP packets from smoltcp → WG + while let Ok(ip_packet) = self.smoltcp_to_wg_rx.try_recv() { + let _ = self.outbound_tx.try_send(ip_packet); + } + + // 2. Poll smoltcp + self.vnet.poll(); + + // 3. Bridge each proxy connection + let mut to_remove = Vec::new(); + let conn_ids: Vec = self.connections.keys().copied().collect(); + + for id in conn_ids { + let conn = self.connections.get_mut(&id).unwrap(); + let done = Self::bridge_connection(&mut self.vnet, conn).await; + if done { + to_remove.push(id); + } + } + + for id in to_remove { + tracing::debug!("proxy connection {id} closed"); + self.connections.remove(&id); + } + + // 4. Poll again after bridging to flush any new outbound data + self.vnet.poll(); + + // Drain outbound again + while let Ok(ip_packet) = self.smoltcp_to_wg_rx.try_recv() { + let _ = self.outbound_tx.try_send(ip_packet); + } + } + + /// Bridge data between a local TCP stream and a smoltcp socket. + /// Returns true if the connection is done and should be removed. + async fn bridge_connection( + vnet: &mut VirtualNetwork, + conn: &mut ProxyConnection, + ) -> bool { + // Local → smoltcp: try to read from local TCP (non-blocking) + if !conn.local_read_done && conn.local_buf.len() < 32768 { + let mut tmp = [0u8; 8192]; + match conn.local.try_read(&mut tmp) { + Ok(0) => { + conn.local_read_done = true; + } + Ok(n) => { + conn.local_buf.extend_from_slice(&tmp[..n]); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {} + Err(_) => { + conn.local_read_done = true; + } + } + } + + // Send buffered local data to smoltcp socket + if !conn.local_buf.is_empty() { + match vnet.tcp_send(conn.handle, &conn.local_buf) { + Ok(n) if n > 0 => { + conn.local_buf.drain(..n); + } + _ => {} + } + } + + // smoltcp → local: read from smoltcp socket + let mut tmp = [0u8; 8192]; + match vnet.tcp_recv(conn.handle, &mut tmp) { + Ok(n) if n > 0 => { + conn.remote_buf.extend_from_slice(&tmp[..n]); + } + Err(_) => { + conn.remote_done = true; + } + _ => {} + } + + // Write buffered smoltcp data to local TCP + if !conn.remote_buf.is_empty() { + match conn.local.try_write(&conn.remote_buf) { + Ok(n) if n > 0 => { + conn.remote_buf.drain(..n); + } + _ => {} + } + } + + // Check if the virtual socket is still alive + if !vnet.tcp_is_active(conn.handle) && conn.remote_buf.is_empty() { + conn.remote_done = true; + } + + // Done when both sides are finished + conn.local_read_done && conn.remote_done && conn.local_buf.is_empty() && conn.remote_buf.is_empty() + } +} diff --git a/sunbeam-net/src/proxy/mod.rs b/sunbeam-net/src/proxy/mod.rs new file mode 100644 index 00000000..5a0feb13 --- /dev/null +++ b/sunbeam-net/src/proxy/mod.rs @@ -0,0 +1,3 @@ +pub mod engine; +pub mod tcp; +pub(crate) use tcp::TcpProxy; diff --git a/sunbeam-net/src/proxy/tcp.rs b/sunbeam-net/src/proxy/tcp.rs new file mode 100644 index 00000000..f8ef64c1 --- /dev/null +++ b/sunbeam-net/src/proxy/tcp.rs @@ -0,0 +1,161 @@ +use std::net::SocketAddr; +use tokio::net::TcpListener; + +/// TCP proxy that accepts local connections and forwards them through the VPN. +/// +/// Currently uses direct TCP as a placeholder until the WireGuard virtual +/// network is wired in. +pub(crate) struct TcpProxy { + bind_addr: SocketAddr, + remote_addr: SocketAddr, +} + +impl TcpProxy { + pub fn new(bind_addr: SocketAddr, remote_addr: SocketAddr) -> Self { + Self { bind_addr, remote_addr } + } + + /// Run the proxy, accepting connections until the cancellation token fires. + pub async fn run(&self, cancel: tokio_util::sync::CancellationToken) -> crate::Result<()> { + let listener = TcpListener::bind(self.bind_addr).await + .map_err(|e| crate::Error::Io { + context: format!("bind proxy {}", self.bind_addr), + source: e, + })?; + + tracing::info!("TCP proxy listening on {}", self.bind_addr); + + loop { + tokio::select! { + accept = listener.accept() => { + let (stream, peer) = accept + .map_err(|e| crate::Error::Io { context: "accept proxy".into(), source: e })?; + tracing::debug!("proxy connection from {peer}"); + let remote = self.remote_addr; + tokio::spawn(async move { + if let Err(e) = handle_proxy_connection(stream, remote).await { + tracing::debug!("proxy connection error: {e}"); + } + }); + } + _ = cancel.cancelled() => { + tracing::info!("TCP proxy shutting down"); + return Ok(()); + } + } + } + } +} + +/// Handle a single proxied connection. +/// TODO: Route through WireGuard virtual network instead of direct TCP. +async fn handle_proxy_connection( + mut local: tokio::net::TcpStream, + remote_addr: SocketAddr, +) -> crate::Result<()> { + // For now, direct TCP connect as placeholder. + // When wg/socket.rs is ready, this will dial through the virtual network. + let mut remote = tokio::net::TcpStream::connect(remote_addr).await + .map_err(|e| crate::Error::Io { + context: format!("connect to {remote_addr}"), + source: e, + })?; + + tokio::io::copy_bidirectional(&mut local, &mut remote).await + .map_err(|e| crate::Error::Io { + context: "proxy copy".into(), + source: e, + })?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpListener, TcpStream}; + + #[tokio::test] + async fn test_proxy_bind_and_accept() { + let cancel = tokio_util::sync::CancellationToken::new(); + + // Use port 0 for OS-assigned port; we need to know the actual port. + // Start a dummy remote server so the proxy connection doesn't fail. + let remote_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let remote_addr = remote_listener.local_addr().unwrap(); + + // Bind a temporary listener to find a free port, then release it. + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let proxy_addr = listener.local_addr().unwrap(); + drop(listener); + + let proxy = TcpProxy::new(proxy_addr, remote_addr); + let proxy_cancel = cancel.clone(); + let proxy_task = tokio::spawn(async move { proxy.run(proxy_cancel).await }); + + // Give it a moment to bind. + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + // Connect to the proxy. + let stream = TcpStream::connect(proxy_addr).await.unwrap(); + assert!(stream.peer_addr().is_ok()); + + // Accept on the remote side so the proxy connection completes. + let _ = remote_listener.accept().await.unwrap(); + + cancel.cancel(); + let _ = proxy_task.await; + } + + #[tokio::test] + async fn test_proxy_forwarding() { + let cancel = tokio_util::sync::CancellationToken::new(); + + // Start a mock echo server. + let echo_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let echo_addr = echo_listener.local_addr().unwrap(); + + tokio::spawn(async move { + while let Ok((mut stream, _)) = echo_listener.accept().await { + tokio::spawn(async move { + let mut buf = [0u8; 1024]; + loop { + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if stream.write_all(&buf[..n]).await.is_err() { + break; + } + } + }); + } + }); + + // Start the proxy pointing at the echo server. + // Find a free port for the proxy. + let tmp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let proxy_addr = tmp_listener.local_addr().unwrap(); + drop(tmp_listener); + + let proxy = TcpProxy::new(proxy_addr, echo_addr); + let proxy_cancel = cancel.clone(); + tokio::spawn(async move { proxy.run(proxy_cancel).await }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + // Connect through the proxy and send data. + let mut stream = TcpStream::connect(proxy_addr).await.unwrap(); + + let payload = b"hello through the proxy"; + stream.write_all(payload).await.unwrap(); + + let mut buf = vec![0u8; payload.len()]; + stream.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, payload); + + cancel.cancel(); + } +}