diff --git a/sunbeam-net/src/control/client.rs b/sunbeam-net/src/control/client.rs new file mode 100644 index 00000000..1fc488bb --- /dev/null +++ b/sunbeam-net/src/control/client.rs @@ -0,0 +1,280 @@ +use bytes::Bytes; +use h2::client::SendRequest; +use tokio::net::TcpStream; + +use crate::config::VpnConfig; +use crate::keys::NodeKeys; +use crate::noise; + +/// Client for the coordination server control protocol. +/// +/// Communicates over HTTP/2 on top of a Noise-encrypted TCP stream. +pub struct ControlClient { + pub(crate) sender: SendRequest, + /// Keep the h2 connection driver alive. + _conn_task: tokio::task::JoinHandle<()>, +} + +impl std::fmt::Debug for ControlClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ControlClient").finish_non_exhaustive() + } +} + +impl ControlClient { + /// Connect to the coordination server. + /// + /// 1. TCP connect to `coordination_url` + /// 2. Perform TS2021 Noise IK handshake + /// 3. Wrap the TCP stream in a [`noise::stream::NoiseStream`] + /// 4. Run `h2::client::handshake` over the encrypted stream + /// 5. Spawn the h2 connection driver task + pub async fn connect(config: &VpnConfig, keys: &NodeKeys) -> crate::Result { + // Parse host:port from the coordination URL. + let addr = parse_coordination_addr(&config.coordination_url)?; + + tracing::debug!("connecting to coordination server at {addr}"); + + let mut tcp = TcpStream::connect(&addr) + .await + .map_err(|e| crate::Error::Control(format!("tcp connect to {addr}: {e}")))?; + + // Resolve the server's Noise public key. + let server_public = match config.server_public_key { + Some(key) => key, + None => fetch_server_key(&addr).await?, + }; + + // Noise IK handshake (controlbase protocol). + let server_pub_key = x25519_dalek::PublicKey::from(server_public); + let result = noise::handshake::perform_handshake( + &mut tcp, + &keys.node_private, + &keys.node_public, + &server_pub_key, + ) + .await?; + + // Wrap in NoiseStream for transparent encryption. + // Pass leftover bytes from the handshake TCP buffer. + let mut noise_stream = + noise::stream::NoiseStream::new(tcp, result.tx_cipher, result.rx_cipher, result.leftover); + + // Consume the early payload (EarlyNoise) before h2 starts. + // Headscale sends this immediately after the handshake. + let early = noise_stream.consume_early_payload().await + .map_err(|e| crate::Error::Control(format!("early payload: {e}")))?; + tracing::debug!("early payload consumed ({} bytes)", early.len()); + + // h2 client handshake over the encrypted stream. + let (sender, connection) = h2::client::handshake(noise_stream) + .await + .map_err(|e| crate::Error::Control(format!("h2 handshake: {e}")))?; + + // Spawn the connection driver — it must run for the lifetime of the client. + let conn_task = tokio::spawn(async move { + if let Err(e) = connection.await { + tracing::error!("h2 connection error: {e}"); + } + }); + + tracing::debug!("control client connected"); + + Ok(Self { + sender, + _conn_task: conn_task, + }) + } + + /// Send a POST request with a JSON body and parse a JSON response. + pub(crate) async fn post_json( + &mut self, + path: &str, + body: &Req, + ) -> crate::Result { + let body_bytes = serde_json::to_vec(body)?; + + let request = http::Request::builder() + .method("POST") + .uri(path) + .header("content-type", "application/json") + .body(()) + .map_err(|e| crate::Error::Control(e.to_string()))?; + + let (response_future, mut send_stream) = self + .sender + .send_request(request, false) + .map_err(|e| crate::Error::Control(format!("send request: {e}")))?; + + // Send the JSON body and signal end-of-stream. + send_stream + .send_data(Bytes::from(body_bytes), true) + .map_err(|e| crate::Error::Control(format!("send body: {e}")))?; + + // Await the response headers. + let response = response_future + .await + .map_err(|e| crate::Error::Control(format!("response: {e}")))?; + + let status = response.status(); + let mut body = response.into_body(); + + // Collect the response body. + let mut response_bytes = Vec::new(); + while let Some(chunk) = body.data().await { + let chunk = chunk.map_err(|e| crate::Error::Control(format!("read body: {e}")))?; + response_bytes.extend_from_slice(&chunk); + body.flow_control() + .release_capacity(chunk.len()) + .map_err(|e| crate::Error::Control(format!("flow control: {e}")))?; + } + + if !status.is_success() { + let text = String::from_utf8_lossy(&response_bytes); + return Err(crate::Error::Control(format!( + "{path} returned {status}: {text}" + ))); + } + + tracing::debug!("{path} response: {}", String::from_utf8_lossy(&response_bytes)); + let parsed = serde_json::from_slice(&response_bytes)?; + Ok(parsed) + } +} + +/// Parse a coordination URL into a `host:port` address string. +/// +/// Accepts forms like `https://control.example.com`, `http://host:8080`, +/// or plain `host:port`. +fn parse_coordination_addr(url: &str) -> crate::Result { + // Strip scheme if present. + let without_scheme = if let Some(rest) = url.strip_prefix("https://") { + rest + } else if let Some(rest) = url.strip_prefix("http://") { + rest + } else { + url + }; + + // Strip trailing path. + let host_port = without_scheme.split('/').next().unwrap_or(without_scheme); + + // Add default port if not present. + if host_port.contains(':') { + Ok(host_port.to_string()) + } else if url.starts_with("http://") { + Ok(format!("{host_port}:80")) + } else { + // Default to 443 for https or unspecified. + Ok(format!("{host_port}:443")) + } +} + +/// Fetch the server's Noise public key from its `/key?v=69` endpoint. +/// +/// Headscale/Tailscale returns JSON: `{"publicKey":"mkey:","legacyPublicKey":"mkey:..."}`. +/// We parse the `publicKey` field and strip the `mkey:` prefix. +async fn fetch_server_key(addr: &str) -> crate::Result<[u8; 32]> { + let mut tcp = TcpStream::connect(addr) + .await + .map_err(|e| crate::Error::Control(format!("connect to /key: {e}")))?; + + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let host = addr.split(':').next().unwrap_or(addr); + let request = format!( + "GET /key?v=69 HTTP/1.1\r\nHost: {host}\r\nConnection: close\r\n\r\n" + ); + tcp.write_all(request.as_bytes()) + .await + .map_err(|e| crate::Error::Control(format!("write /key request: {e}")))?; + + let mut buf = Vec::with_capacity(4096); + tcp.read_to_end(&mut buf) + .await + .map_err(|e| crate::Error::Control(format!("read /key response: {e}")))?; + + let response = String::from_utf8_lossy(&buf); + + // Find the body after the blank line. + let body = response + .split("\r\n\r\n") + .nth(1) + .ok_or_else(|| crate::Error::Control("malformed /key response".into()))? + .trim(); + + // Parse JSON response: {"publicKey":"mkey:", ...} + let json: serde_json::Value = serde_json::from_str(body) + .map_err(|e| crate::Error::Control(format!("parse /key JSON: {e} body={body}")))?; + + let public_key = json + .get("publicKey") + .or_else(|| json.get("public_key")) + .and_then(|v| v.as_str()) + .ok_or_else(|| crate::Error::Control("no publicKey in /key response".into()))?; + + // Strip known prefixes: "mkey:", "nodekey:" + let hex_str = public_key + .strip_prefix("mkey:") + .or_else(|| public_key.strip_prefix("nodekey:")) + .unwrap_or(public_key); + + parse_hex_key(hex_str) +} + +/// Parse a 32-byte key from a hex string. +fn parse_hex_key(hex: &str) -> crate::Result<[u8; 32]> { + if hex.len() != 64 { + return Err(crate::Error::Control(format!( + "expected 64 hex chars for server key, got {}", + hex.len() + ))); + } + let mut key = [0u8; 32]; + for (i, byte) in key.iter_mut().enumerate() { + *byte = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16) + .map_err(|e| crate::Error::Control(format!("bad hex in server key: {e}")))?; + } + Ok(key) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_https_url() { + let addr = parse_coordination_addr("https://control.example.com").unwrap(); + assert_eq!(addr, "control.example.com:443"); + } + + #[test] + fn parse_http_url_with_port() { + let addr = parse_coordination_addr("http://localhost:8080").unwrap(); + assert_eq!(addr, "localhost:8080"); + } + + #[test] + fn parse_url_with_path() { + let addr = parse_coordination_addr("https://control.example.com/ts2021").unwrap(); + assert_eq!(addr, "control.example.com:443"); + } + + #[test] + fn parse_plain_host_port() { + let addr = parse_coordination_addr("10.0.0.1:443").unwrap(); + assert_eq!(addr, "10.0.0.1:443"); + } + + #[test] + fn parse_hex_key_valid() { + let hex = "ab".repeat(32); + let key = parse_hex_key(&hex).unwrap(); + assert_eq!(key, [0xab; 32]); + } + + #[test] + fn parse_hex_key_wrong_length() { + assert!(parse_hex_key("aabb").is_err()); + } +} diff --git a/sunbeam-net/src/control/mod.rs b/sunbeam-net/src/control/mod.rs new file mode 100644 index 00000000..405c5be7 --- /dev/null +++ b/sunbeam-net/src/control/mod.rs @@ -0,0 +1,6 @@ +pub mod client; +pub mod netmap; +pub mod register; + +pub use client::ControlClient; +pub use netmap::{MapStream, MapUpdate}; diff --git a/sunbeam-net/src/control/netmap.rs b/sunbeam-net/src/control/netmap.rs new file mode 100644 index 00000000..58f71478 --- /dev/null +++ b/sunbeam-net/src/control/netmap.rs @@ -0,0 +1,336 @@ +use bytes::{Buf, Bytes}; +use h2::RecvStream; + +use crate::proto::types::{DerpMap, MapRequest, MapResponse, Node}; + +/// A streaming reader for map updates from the coordination server. +/// +/// The map protocol uses HTTP/2 server streaming: the client sends a single +/// `MapRequest` and the server responds with a sequence of length-prefixed +/// JSON messages on the same response body. +pub struct MapStream { + body: RecvStream, + buf: bytes::BytesMut, + /// Whether we have received the first (full) message yet. + first: bool, +} + +impl std::fmt::Debug for MapStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MapStream") + .field("buffered", &self.buf.len()) + .field("first", &self.first) + .finish() + } +} + +/// Parsed map update. +#[derive(Debug)] +pub enum MapUpdate { + /// First response: full network map. + Full { + self_node: Node, + peers: Vec, + derp_map: Option, + }, + /// Delta: peers changed. + PeersChanged(Vec), + /// Delta: peers removed (by node key). + PeersRemoved(Vec), + /// Keep-alive (empty response). + KeepAlive, +} + +impl MapStream { + /// Read the next map update from the stream. + /// + /// Protocol: each update is prefixed with a 4-byte little-endian `u32` + /// length, followed by the message body. The first message is raw JSON; + /// subsequent messages are zstd-compressed JSON. + pub async fn next(&mut self) -> crate::Result> { + // Ensure we have at least 4 bytes for the length prefix. + while self.buf.len() < 4 { + match self.body.data().await { + Some(Ok(chunk)) => { + self.body + .flow_control() + .release_capacity(chunk.len()) + .map_err(|e| { + crate::Error::Control(format!("flow control: {e}")) + })?; + self.buf.extend_from_slice(&chunk); + } + Some(Err(e)) => { + return Err(crate::Error::Control(format!("read stream: {e}"))); + } + None => return Ok(None), // Stream ended. + } + } + + // Read the 4-byte LE length prefix. + let msg_len = u32::from_le_bytes([ + self.buf[0], + self.buf[1], + self.buf[2], + self.buf[3], + ]) as usize; + self.buf.advance(4); + + // Read the full message body. + while self.buf.len() < msg_len { + match self.body.data().await { + Some(Ok(chunk)) => { + self.body + .flow_control() + .release_capacity(chunk.len()) + .map_err(|e| { + crate::Error::Control(format!("flow control: {e}")) + })?; + self.buf.extend_from_slice(&chunk); + } + Some(Err(e)) => { + return Err(crate::Error::Control(format!("read stream: {e}"))); + } + None => { + return Err(crate::Error::Control(format!( + "stream ended mid-message (need {msg_len} bytes, have {})", + self.buf.len() + ))); + } + } + } + + let raw = self.buf.split_to(msg_len); + if self.first { + self.first = false; + } + + // Detect zstd compression by magic bytes (0x28 0xB5 0x2F 0xFD). + // Headscale only zstd-compresses if the client requested it via the + // `Compress` field in MapRequest; otherwise messages are raw JSON. + let json_bytes = if raw.len() >= 4 && raw[0] == 0x28 && raw[1] == 0xB5 && raw[2] == 0x2F && raw[3] == 0xFD { + zstd::stream::decode_all(raw.as_ref()).map_err(|e| { + crate::Error::Control(format!("zstd decompress: {e}")) + })? + } else { + raw.to_vec() + }; + + let resp: MapResponse = serde_json::from_slice(&json_bytes)?; + Ok(Some(classify_response(resp))) + } +} + +/// Classify a [`MapResponse`] into a [`MapUpdate`] variant based on which +/// fields are populated. +fn classify_response(resp: MapResponse) -> MapUpdate { + // Full map: has node + peers. + if let Some(self_node) = resp.node { + if let Some(peers) = resp.peers { + return MapUpdate::Full { + self_node, + peers, + derp_map: resp.derp_map, + }; + } + } + + // Delta: peers changed. + if let Some(changed) = resp.peers_changed { + return MapUpdate::PeersChanged(changed); + } + + // Delta: peers removed. + if let Some(removed) = resp.peers_removed { + return MapUpdate::PeersRemoved(removed); + } + + // Everything else is a keep-alive. + MapUpdate::KeepAlive +} + +impl super::client::ControlClient { + /// Start a map streaming session. + /// + /// Sends a `MapRequest` to `POST /machine/map` and returns a [`MapStream`] + /// that yields successive [`MapUpdate`]s as the coordination server pushes + /// network map changes. + pub async fn map_stream( + &mut self, + keys: &crate::keys::NodeKeys, + hostname: &str, + ) -> crate::Result { + let req = MapRequest { + version: 74, + node_key: keys.node_key_str(), + disco_key: keys.disco_key_str(), + stream: true, + hostinfo: super::register::build_hostinfo(hostname), + endpoints: None, + }; + + let body_bytes = serde_json::to_vec(&req)?; + + let request = http::Request::builder() + .method("POST") + .uri("/machine/map") + .header("content-type", "application/json") + .body(()) + .map_err(|e| crate::Error::Control(e.to_string()))?; + + let (response_future, mut send_stream) = self + .sender + .send_request(request, false) + .map_err(|e| crate::Error::Control(format!("send map request: {e}")))?; + + send_stream + .send_data(Bytes::from(body_bytes), true) + .map_err(|e| crate::Error::Control(format!("send map body: {e}")))?; + + let response = response_future + .await + .map_err(|e| crate::Error::Control(format!("map response: {e}")))?; + + let status = response.status(); + if !status.is_success() { + return Err(crate::Error::Control(format!( + "/machine/map returned {status}" + ))); + } + + let body = response.into_body(); + + Ok(MapStream { + body, + buf: bytes::BytesMut::new(), + first: true, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::types::HostInfo; + + fn sample_node(id: u64, key: &str) -> Node { + Node { + id, + key: key.to_string(), + disco_key: format!("discokey:{key}"), + addresses: vec!["100.64.0.1/32".to_string()], + allowed_ips: vec!["100.64.0.0/10".to_string()], + endpoints: vec![], + derp: "127.3.3.40:1".to_string(), + hostinfo: HostInfo { + go_arch: "amd64".to_string(), + go_os: "linux".to_string(), + go_version: "sunbeam-net/0.1.0".to_string(), + hostname: "test".to_string(), + os: "linux".to_string(), + os_version: String::new(), + device_model: None, + frontend_log_id: None, + backend_log_id: None, + }, + name: "test.example.com".to_string(), + online: Some(true), + machine_authorized: true, + } + } + + #[test] + fn test_map_update_classify_full() { + let resp = MapResponse { + node: Some(sample_node(1, "nodekey:aa")), + peers: Some(vec![sample_node(2, "nodekey:bb")]), + peers_changed: None, + peers_removed: None, + derp_map: None, + dns_config: None, + packet_filter: None, + domain: Some("example.com".to_string()), + collection_name: None, + }; + + let update = classify_response(resp); + match update { + MapUpdate::Full { + self_node, + peers, + derp_map, + } => { + assert_eq!(self_node.id, 1); + assert_eq!(peers.len(), 1); + assert_eq!(peers[0].id, 2); + assert!(derp_map.is_none()); + } + other => panic!("expected Full, got {other:?}"), + } + } + + #[test] + fn test_map_update_classify_delta() { + let resp = MapResponse { + node: None, + peers: None, + peers_changed: Some(vec![sample_node(3, "nodekey:cc")]), + peers_removed: None, + derp_map: None, + dns_config: None, + packet_filter: None, + domain: None, + collection_name: None, + }; + + let update = classify_response(resp); + match update { + MapUpdate::PeersChanged(changed) => { + assert_eq!(changed.len(), 1); + assert_eq!(changed[0].id, 3); + } + other => panic!("expected PeersChanged, got {other:?}"), + } + } + + #[test] + fn test_map_update_classify_peers_removed() { + let resp = MapResponse { + node: None, + peers: None, + peers_changed: None, + peers_removed: Some(vec!["nodekey:dd".to_string()]), + derp_map: None, + dns_config: None, + packet_filter: None, + domain: None, + collection_name: None, + }; + + let update = classify_response(resp); + match update { + MapUpdate::PeersRemoved(removed) => { + assert_eq!(removed, vec!["nodekey:dd"]); + } + other => panic!("expected PeersRemoved, got {other:?}"), + } + } + + #[test] + fn test_map_update_classify_keepalive() { + let resp = MapResponse { + node: None, + peers: None, + peers_changed: None, + peers_removed: None, + derp_map: None, + dns_config: None, + packet_filter: None, + domain: None, + collection_name: None, + }; + + let update = classify_response(resp); + assert!(matches!(update, MapUpdate::KeepAlive)); + } +} diff --git a/sunbeam-net/src/control/register.rs b/sunbeam-net/src/control/register.rs new file mode 100644 index 00000000..2fafca01 --- /dev/null +++ b/sunbeam-net/src/control/register.rs @@ -0,0 +1,92 @@ +use crate::keys::NodeKeys; +use crate::proto::types::{AuthInfo, HostInfo, RegisterRequest, RegisterResponse}; + +impl super::client::ControlClient { + /// Register this node with the coordination server using a pre-auth key. + /// + /// Sends a `RegisterRequest` to `POST /machine/register` and returns the + /// server's `RegisterResponse`, which includes the assigned addresses, + /// user info, and machine authorization status. + pub async fn register( + &mut self, + auth_key: &str, + hostname: &str, + keys: &NodeKeys, + ) -> crate::Result { + let req = RegisterRequest { + version: 74, // capability version + node_key: keys.node_key_str(), + old_node_key: format!("nodekey:{}", "0".repeat(64)), + auth: Some(AuthInfo { + auth_key: Some(auth_key.to_string()), + }), + hostinfo: build_hostinfo(hostname), + followup: None, + timestamp: None, + }; + + self.post_json("/machine/register", &req).await + } +} + +/// Build a [`HostInfo`] for the current platform. +pub(crate) fn build_hostinfo(hostname: &str) -> HostInfo { + HostInfo { + go_arch: std::env::consts::ARCH.to_string(), + go_os: std::env::consts::OS.to_string(), + go_version: format!("sunbeam-net/{}", env!("CARGO_PKG_VERSION")), + hostname: hostname.to_string(), + os: std::env::consts::OS.to_string(), + os_version: String::new(), + device_model: None, + frontend_log_id: None, + backend_log_id: None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_hostinfo() { + let hi = build_hostinfo("myhost"); + assert_eq!(hi.hostname, "myhost"); + assert!(!hi.go_arch.is_empty(), "go_arch should be set"); + assert!(!hi.go_os.is_empty(), "go_os should be set"); + assert!( + hi.go_version.starts_with("sunbeam-net/"), + "go_version should start with sunbeam-net/" + ); + assert!(!hi.os.is_empty(), "os should be set"); + } + + #[test] + fn test_register_request_construction() { + let keys = crate::keys::NodeKeys::generate(); + let req = RegisterRequest { + version: 74, + node_key: keys.node_key_str(), + old_node_key: String::new(), + auth: Some(AuthInfo { + auth_key: Some("tskey-auth-test123".to_string()), + }), + hostinfo: build_hostinfo("test-host"), + followup: None, + timestamp: None, + }; + + let json = serde_json::to_string(&req).unwrap(); + + // Verify key fields are present with expected values. + assert!(json.contains("\"Version\":74")); + assert!(json.contains(&format!("\"NodeKey\":\"{}\"", keys.node_key_str()))); + assert!(json.contains("\"AuthKey\":\"tskey-auth-test123\"")); + assert!(json.contains("\"Hostname\":\"test-host\"")); + + // Round-trip. + let parsed: RegisterRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.version, 74); + assert_eq!(parsed.node_key, keys.node_key_str()); + } +} diff --git a/sunbeam-net/src/lib.rs b/sunbeam-net/src/lib.rs index b03a0dd9..0494032c 100644 --- a/sunbeam-net/src/lib.rs +++ b/sunbeam-net/src/lib.rs @@ -1,6 +1,7 @@ //! sunbeam-net: Pure Rust Headscale/Tailscale-compatible VPN client. pub mod config; +pub mod control; pub mod derp; pub mod error; pub mod keys;