diff --git a/sunbeam-net/src/lib.rs b/sunbeam-net/src/lib.rs
index 4a66792d..e5b75ce5 100644
--- a/sunbeam-net/src/lib.rs
+++ b/sunbeam-net/src/lib.rs
@@ -3,6 +3,7 @@
 pub mod config;
 pub mod error;
 pub mod keys;
+pub mod noise;
 pub(crate) mod proto;
 
 pub use config::VpnConfig;
diff --git a/sunbeam-net/src/noise/framing.rs b/sunbeam-net/src/noise/framing.rs
new file mode 100644
index 00000000..15656959
--- /dev/null
+++ b/sunbeam-net/src/noise/framing.rs
@@ -0,0 +1,210 @@
+use bytes::{Buf, BufMut, BytesMut};
+use tokio_util::codec::{Decoder, Encoder};
+
+/// Record frame type (controlbase post-handshake data).
+pub const FRAME_TYPE_DATA: u8 = 0x04;
+/// Control frame type.
+pub const FRAME_TYPE_CONTROL: u8 = 0x01;
+/// Size of the frame header (1 byte type + 2 bytes length).
+pub const HEADER_SIZE: usize = 3;
+/// Maximum size of one frame on the wire, header included.
+///
+/// NOTE(review): upstream controlbase is understood to cap a whole frame
+/// (header + ciphertext) at 4096 bytes; confirm against the Tailscale source.
+pub const MAX_MESSAGE_SIZE: usize = 4096;
+/// Maximum payload size per frame (wire limit minus the header).
+pub const MAX_FRAME_SIZE: usize = MAX_MESSAGE_SIZE - HEADER_SIZE;
+
+/// A single Noise frame (type + payload).
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct NoiseFrame {
+    /// Frame type byte as it appears on the wire.
+    pub frame_type: u8,
+    /// Frame payload, without the 3-byte header.
+    pub payload: BytesMut,
+}
+
+/// Codec for encoding/decoding Noise frames on a byte stream.
+///
+/// Wire format: 1 byte type, 2 bytes big-endian payload length, then the payload.
+#[derive(Debug, Default)]
+pub struct NoiseFrameCodec;
+
+impl Decoder for NoiseFrameCodec {
+    type Item = NoiseFrame;
+    type Error = std::io::Error;
+
+    /// Decode one frame from `src`, or return `Ok(None)` if more bytes are needed.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidData` when the header announces a payload larger than
+    /// [`MAX_FRAME_SIZE`].
+    fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, Self::Error> {
+        if src.len() < HEADER_SIZE {
+            return Ok(None);
+        }
+
+        let frame_type = src[0];
+        let payload_len = usize::from(u16::from_be_bytes([src[1], src[2]]));
+
+        if payload_len > MAX_FRAME_SIZE {
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                format!("frame payload {payload_len} exceeds maximum {MAX_FRAME_SIZE}"),
+            ));
+        }
+
+        let total = HEADER_SIZE + payload_len;
+        if src.len() < total {
+            // Reserve space so the next read can fill it.
+            src.reserve(total - src.len());
+            return Ok(None);
+        }
+
+        src.advance(HEADER_SIZE);
+        let payload = src.split_to(payload_len);
+
+        Ok(Some(NoiseFrame {
+            frame_type,
+            payload,
+        }))
+    }
+}
+
+impl Encoder<NoiseFrame> for NoiseFrameCodec {
+    type Error = std::io::Error;
+
+    /// Append `item` to `dst` as header + payload.
+    ///
+    /// # Errors
+    ///
+    /// Returns `InvalidInput` when the payload is larger than [`MAX_FRAME_SIZE`].
+    fn encode(&mut self, item: NoiseFrame, dst: &mut BytesMut) -> std::result::Result<(), Self::Error> {
+        // A checked conversion keeps the length honest even if the size limit
+        // is ever raised past what the 2-byte length field can carry.
+        let len = match u16::try_from(item.payload.len()) {
+            Ok(len) if usize::from(len) <= MAX_FRAME_SIZE => len,
+            _ => {
+                return Err(std::io::Error::new(
+                    std::io::ErrorKind::InvalidInput,
+                    format!(
+                        "payload length {} exceeds maximum {MAX_FRAME_SIZE}",
+                        item.payload.len()
+                    ),
+                ));
+            }
+        };
+
+        dst.reserve(HEADER_SIZE + item.payload.len());
+        dst.put_u8(item.frame_type);
+        dst.put_u16(len);
+        dst.put(item.payload);
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn round_trip(frame_type: u8, payload: &[u8]) {
+        let mut codec = NoiseFrameCodec;
+        let frame = NoiseFrame {
+            frame_type,
+            payload: BytesMut::from(payload),
+        };
+
+        let mut buf = BytesMut::new();
+        codec.encode(frame.clone(), &mut buf).unwrap();
+
+        let decoded = codec.decode(&mut buf).unwrap().expect("should decode a frame");
+        assert_eq!(decoded, frame);
+    }
+
+    #[test]
+    fn round_trip_data_frame() {
+        round_trip(FRAME_TYPE_DATA, b"hello world");
+    }
+
+    #[test]
+    fn round_trip_control_frame() {
+        round_trip(FRAME_TYPE_CONTROL, b"\x01\x02\x03");
+    }
+
+    #[test]
+    fn round_trip_empty_payload() {
+        round_trip(FRAME_TYPE_DATA, b"");
+    }
+
+    #[test]
+    fn round_trip_max_size() {
+        let payload = vec![0xAB; MAX_FRAME_SIZE];
+        round_trip(FRAME_TYPE_DATA, &payload);
+    }
+
+    #[test]
+    fn max_frame_fills_wire_limit_exactly() {
+        assert_eq!(HEADER_SIZE + MAX_FRAME_SIZE, MAX_MESSAGE_SIZE);
+    }
+
+    #[test]
+    fn partial_header_returns_none() {
+        let mut codec = NoiseFrameCodec;
+        let mut buf = BytesMut::from(&[0x00, 0x00][..]);
+        assert!(codec.decode(&mut buf).unwrap().is_none());
+    }
+
+    #[test]
+    fn partial_payload_returns_none() {
+        let mut codec = NoiseFrameCodec;
+        // Header says 10 bytes of payload, but only 3 provided.
+        let mut buf = BytesMut::from(&[0x00, 0x00, 0x0A, 0x01, 0x02, 0x03][..]);
+        assert!(codec.decode(&mut buf).unwrap().is_none());
+    }
+
+    #[test]
+    fn oversized_payload_is_rejected_on_decode() {
+        let mut codec = NoiseFrameCodec;
+        let bad_len = (MAX_FRAME_SIZE as u16) + 1;
+        let mut buf = BytesMut::new();
+        buf.put_u8(0x00);
+        buf.put_u16(bad_len);
+        buf.extend_from_slice(&vec![0; bad_len as usize]);
+        assert!(codec.decode(&mut buf).is_err());
+    }
+
+    #[test]
+    fn oversized_payload_is_rejected_on_encode() {
+        let mut codec = NoiseFrameCodec;
+        let frame = NoiseFrame {
+            frame_type: FRAME_TYPE_DATA,
+            payload: BytesMut::from(&vec![0; MAX_FRAME_SIZE + 1][..]),
+        };
+        let mut buf = BytesMut::new();
+        assert!(codec.encode(frame, &mut buf).is_err());
+    }
+
+    #[test]
+    fn multiple_frames_in_buffer() {
+        let mut codec = NoiseFrameCodec;
+        let mut buf = BytesMut::new();
+
+        let f1 = NoiseFrame {
+            frame_type: FRAME_TYPE_DATA,
+            payload: BytesMut::from(&b"aaa"[..]),
+        };
+        let f2 = NoiseFrame {
+            frame_type: FRAME_TYPE_CONTROL,
+            payload: BytesMut::from(&b"bbb"[..]),
+        };
+        codec.encode(f1.clone(), &mut buf).unwrap();
+        codec.encode(f2.clone(), &mut buf).unwrap();
+
+        let d1 = codec.decode(&mut buf).unwrap().unwrap();
+        let d2 = codec.decode(&mut buf).unwrap().unwrap();
+        assert_eq!(d1, f1);
+        assert_eq!(d2, f2);
+        assert!(codec.decode(&mut buf).unwrap().is_none());
+    }
+}
diff --git a/sunbeam-net/src/noise/handshake.rs b/sunbeam-net/src/noise/handshake.rs
new file mode 100644
index 00000000..c776cf3a
--- /dev/null
+++ b/sunbeam-net/src/noise/handshake.rs
@@ -0,0 +1,539 @@
+//! Tailscale controlbase handshake implementation.
+//!
+//! This implements the custom Noise IK variant used by the Tailscale/Headscale
+//! control protocol. It is NOT compatible with standard Noise libraries like `snow`.
+
+use base64::Engine;
+use blake2::Digest;
+use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce, Tag};
+use hkdf::Hkdf;
+use hmac::SimpleHmac;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::net::TcpStream;
+use x25519_dalek::{PublicKey, StaticSecret};
+
+type HkdfBlake2s = Hkdf<blake2::Blake2s256, SimpleHmac<blake2::Blake2s256>>;
+
+const PROTOCOL_NAME: &[u8] = b"Noise_IK_25519_ChaChaPoly_BLAKE2s";
+const PROTOCOL_VERSION_PREFIX: &[u8] = b"Tailscale Control Protocol v";
+const PROTOCOL_VERSION: u16 = 49;
+
+// Message types
+const MSG_TYPE_INITIATION: u8 = 1;
+const MSG_TYPE_RESPONSE: u8 = 2;
+const MSG_TYPE_ERROR: u8 = 3;
+pub(crate) const MSG_TYPE_RECORD: u8 = 4;
+
+// Sizes
+const TAG_SIZE: usize = 16;
+/// Header of every server message: 1 byte type + 2 bytes big-endian length.
+const MSG_HEADER_SIZE: usize = 3;
+/// Payload of the handshake response: 32-byte ephemeral key + 16-byte tag.
+const RESPONSE_PAYLOAD_SIZE: usize = 32 + TAG_SIZE;
+/// Most bytes of HTTP upgrade response we buffer while looking for the header end.
+const MAX_HTTP_RESPONSE_SIZE: usize = 8192;
+
+/// Outcome of a successful handshake: transport keys plus transcript data.
+pub(crate) struct HandshakeResult {
+    pub tx_cipher: ChaCha20Poly1305, // client -> server
+    pub rx_cipher: ChaCha20Poly1305, // server -> client
+    pub handshake_hash: [u8; 32],
+    pub protocol_version: u16,
+    /// Leftover bytes from the TCP buffer after the handshake response.
+    /// These are the start of the first Noise transport record.
+    pub leftover: Vec<u8>,
+}
+
+/// Running handshake state: transcript hash and chaining key.
+struct SymmetricState {
+    h: [u8; 32],  // hash
+    ck: [u8; 32], // chaining key
+}
+
+impl SymmetricState {
+    /// Start from the hash of the protocol name, used as both `h` and `ck`.
+    fn initialize() -> Self {
+        let h: [u8; 32] = blake2::Blake2s256::digest(PROTOCOL_NAME).into();
+        Self { h, ck: h }
+    }
+
+    /// Fold `data` into the transcript hash: `h = BLAKE2s(h || data)`.
+    fn mix_hash(&mut self, data: &[u8]) {
+        let mut hasher = blake2::Blake2s256::new();
+        hasher.update(&self.h);
+        hasher.update(data);
+        self.h = hasher.finalize().into();
+    }
+
+    /// Run a DH, advance the chaining key via HKDF and return the derived AEAD key.
+    fn mix_dh(
+        &mut self,
+        priv_key: &StaticSecret,
+        pub_key: &PublicKey,
+    ) -> crate::Result<ChaCha20Poly1305> {
+        let shared = priv_key.diffie_hellman(pub_key);
+        let hk = HkdfBlake2s::new(Some(&self.ck), shared.as_bytes());
+        let mut output = [0u8; 64];
+        hk.expand(&[], &mut output)
+            .map_err(|e| crate::Error::Noise(format!("HKDF expand: {e}")))?;
+        self.ck.copy_from_slice(&output[..32]);
+        let key: [u8; 32] = output[32..64].try_into().unwrap();
+        ChaCha20Poly1305::new_from_slice(&key)
+            .map_err(|e| crate::Error::Noise(format!("AEAD key: {e}")))
+    }
+
+    /// Encrypt `plaintext` with an all-zero nonce and `h` as AAD, then hash the result.
+    ///
+    /// The cipher is taken by value; callers in this file pass a freshly derived one.
+    fn encrypt_and_hash(&mut self, cipher: ChaCha20Poly1305, plaintext: &[u8]) -> Vec<u8> {
+        let nonce = Nonce::default(); // all zeros
+        let mut ciphertext = plaintext.to_vec();
+        let tag = cipher
+            .encrypt_in_place_detached(&nonce, &self.h, &mut ciphertext)
+            .expect("encryption failed");
+        ciphertext.extend_from_slice(tag.as_slice());
+        self.mix_hash(&ciphertext);
+        ciphertext
+    }
+
+    /// Inverse of [`Self::encrypt_and_hash`]; fails if the tag does not verify.
+    fn decrypt_and_hash(
+        &mut self,
+        cipher: ChaCha20Poly1305,
+        ciphertext: &[u8],
+    ) -> crate::Result<Vec<u8>> {
+        if ciphertext.len() < TAG_SIZE {
+            return Err(crate::Error::Noise("ciphertext too short".into()));
+        }
+        let nonce = Nonce::default();
+        let (ct, tag_bytes) = ciphertext.split_at(ciphertext.len() - TAG_SIZE);
+        let tag = Tag::from_slice(tag_bytes);
+        let mut plaintext = ct.to_vec();
+        cipher
+            .decrypt_in_place_detached(&nonce, &self.h, &mut plaintext, tag)
+            .map_err(|_| crate::Error::Noise("AEAD decryption failed".into()))?;
+        self.mix_hash(ciphertext);
+        Ok(plaintext)
+    }
+
+    /// Derive the two transport ciphers from the final chaining key.
+    fn split(self) -> crate::Result<(ChaCha20Poly1305, ChaCha20Poly1305)> {
+        let hk = HkdfBlake2s::new(Some(&self.ck), &[]);
+        let mut output = [0u8; 64];
+        hk.expand(&[], &mut output)
+            .map_err(|e| crate::Error::Noise(format!("HKDF split: {e}")))?;
+        let k1: [u8; 32] = output[..32].try_into().unwrap();
+        let k2: [u8; 32] = output[32..64].try_into().unwrap();
+        let c1 = ChaCha20Poly1305::new_from_slice(&k1)
+            .map_err(|e| crate::Error::Noise(format!("c1 key: {e}")))?;
+        let c2 = ChaCha20Poly1305::new_from_slice(&k2)
+            .map_err(|e| crate::Error::Noise(format!("c2 key: {e}")))?;
+        Ok((c1, c2))
+    }
+}
+
+/// Prologue mixed into the transcript: the version prefix followed by the decimal version.
+fn protocol_version_prologue(version: u16) -> Vec<u8> {
+    let mut p = PROTOCOL_VERSION_PREFIX.to_vec();
+    p.extend_from_slice(version.to_string().as_bytes());
+    p
+}
+
+/// Read from `stream` until `buf` holds at least `want` bytes.
+///
+/// Never requests more than the missing amount, so bytes of a later message
+/// are left in the socket.
+async fn fill_to(stream: &mut TcpStream, buf: &mut Vec<u8>, want: usize) -> crate::Result<()> {
+    while buf.len() < want {
+        let mut chunk = vec![0u8; want - buf.len()];
+        let n = stream
+            .read(&mut chunk)
+            .await
+            .map_err(|e| crate::Error::Noise(format!("read response msg: {e}")))?;
+        if n == 0 {
+            return Err(crate::Error::ConnectionClosed);
+        }
+        buf.extend_from_slice(&chunk[..n]);
+    }
+    Ok(())
+}
+
+/// Read the HTTP upgrade response and require a `101` status.
+///
+/// Returns the bytes that were read past the end of the HTTP headers; they
+/// already belong to the controlbase stream.
+async fn read_upgrade_response(stream: &mut TcpStream) -> crate::Result<Vec<u8>> {
+    let mut resp_buf = vec![0u8; MAX_HTTP_RESPONSE_SIZE];
+    let mut total = 0;
+    let header_end = loop {
+        if total == resp_buf.len() {
+            return Err(crate::Error::Noise(format!(
+                "HTTP response headers exceed {MAX_HTTP_RESPONSE_SIZE} bytes"
+            )));
+        }
+        let n = stream
+            .read(&mut resp_buf[total..])
+            .await
+            .map_err(|e| crate::Error::Noise(format!("read response: {e}")))?;
+        if n == 0 {
+            return Err(crate::Error::ConnectionClosed);
+        }
+        // Rescan only the new bytes, plus 3 bytes of overlap in case the
+        // terminator straddles two reads.
+        let scan_from = total.saturating_sub(3);
+        total += n;
+        if let Some(pos) = resp_buf[scan_from..total]
+            .windows(4)
+            .position(|w| w == b"\r\n\r\n")
+        {
+            break scan_from + pos + 4;
+        }
+    };
+
+    let headers = String::from_utf8_lossy(&resp_buf[..header_end]);
+    // The status line is `HTTP/1.1 <code> <reason>`. Compare the code itself so
+    // a "101" elsewhere in the headers (e.g. `Content-Length: 101`) cannot match.
+    let status_code = headers
+        .lines()
+        .next()
+        .and_then(|line| line.split_whitespace().nth(1));
+    if status_code != Some("101") {
+        let body = String::from_utf8_lossy(&resp_buf[header_end..total]);
+        return Err(crate::Error::Noise(format!(
+            "expected 101, got: {headers}\nBody: {body}"
+        )));
+    }
+
+    Ok(resp_buf[header_end..total].to_vec())
+}
+
+/// Perform the TS2021 HTTP upgrade and controlbase Noise IK handshake.
+///
+/// Returns a [`HandshakeResult`] containing the transport ciphers.
+///
+/// # Errors
+///
+/// Fails when the server does not answer the upgrade with `101`, closes the
+/// connection early, sends a controlbase error message, or sends a response
+/// that is malformed or does not authenticate.
+pub(crate) async fn perform_handshake(
+    stream: &mut TcpStream,
+    machine_private: &StaticSecret,
+    machine_public: &PublicKey,
+    server_public: &PublicKey,
+) -> crate::Result<HandshakeResult> {
+    let mut s = SymmetricState::initialize();
+
+    // Prologue
+    s.mix_hash(&protocol_version_prologue(PROTOCOL_VERSION));
+
+    // <- s (mix in known server static key)
+    s.mix_hash(server_public.as_bytes());
+
+    // -> e, es, s, ss
+    let ephemeral_secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
+    let ephemeral_public = PublicKey::from(&ephemeral_secret);
+
+    // Build initiation message (101 bytes)
+    let mut init = vec![0u8; 101];
+    // Header
+    init[0..2].copy_from_slice(&PROTOCOL_VERSION.to_be_bytes());
+    init[2] = MSG_TYPE_INITIATION;
+    init[3..5].copy_from_slice(&96u16.to_be_bytes());
+    // Payload
+    init[5..37].copy_from_slice(ephemeral_public.as_bytes());
+    s.mix_hash(ephemeral_public.as_bytes());
+
+    let cipher_es = s.mix_dh(&ephemeral_secret, server_public)?;
+    let encrypted_machine_pub = s.encrypt_and_hash(cipher_es, machine_public.as_bytes());
+    init[37..85].copy_from_slice(&encrypted_machine_pub);
+
+    let cipher_ss = s.mix_dh(machine_private, server_public)?;
+    let tag = s.encrypt_and_hash(cipher_ss, &[]);
+    init[85..101].copy_from_slice(&tag);
+
+    // Base64-encode for HTTP header
+    let init_b64 = base64::engine::general_purpose::STANDARD.encode(&init);
+
+    // Send HTTP upgrade
+    let peer = stream
+        .peer_addr()
+        .map(|a| a.to_string())
+        .unwrap_or_else(|_| "localhost".into());
+    let upgrade_req = format!(
+        "POST /ts2021 HTTP/1.1\r\n\
+         Host: {peer}\r\n\
+         Upgrade: tailscale-control-protocol\r\n\
+         Connection: upgrade\r\n\
+         X-Tailscale-Handshake: {init_b64}\r\n\
+         \r\n"
+    );
+    stream
+        .write_all(upgrade_req.as_bytes())
+        .await
+        .map_err(|e| crate::Error::Noise(format!("write upgrade: {e}")))?;
+
+    // Read HTTP 101 response. Anything received after the headers is already
+    // controlbase data.
+    let mut resp_data = read_upgrade_response(stream).await?;
+
+    // Read only the 3-byte message header first: an error message can be much
+    // shorter than the 51-byte handshake response, and waiting for 51 bytes
+    // would turn a server-reported error into a bare "connection closed".
+    fill_to(stream, &mut resp_data, MSG_HEADER_SIZE).await?;
+    let msg_type = resp_data[0];
+    let payload_len = usize::from(u16::from_be_bytes([resp_data[1], resp_data[2]]));
+    let msg_end = MSG_HEADER_SIZE + payload_len;
+
+    match msg_type {
+        MSG_TYPE_RESPONSE => {
+            if payload_len != RESPONSE_PAYLOAD_SIZE {
+                return Err(crate::Error::Noise(format!(
+                    "unexpected response length {payload_len}"
+                )));
+            }
+        }
+        MSG_TYPE_ERROR => {
+            fill_to(stream, &mut resp_data, msg_end).await?;
+            let err_msg = String::from_utf8_lossy(&resp_data[MSG_HEADER_SIZE..msg_end]);
+            return Err(crate::Error::Noise(format!("server error: {err_msg}")));
+        }
+        other => {
+            return Err(crate::Error::Noise(format!(
+                "unexpected response type {other}"
+            )));
+        }
+    }
+    fill_to(stream, &mut resp_data, msg_end).await?;
+
+    let server_ephemeral_bytes: [u8; 32] = resp_data[3..35].try_into().unwrap();
+    let server_ephemeral = PublicKey::from(server_ephemeral_bytes);
+    let resp_tag = &resp_data[35..51];
+
+    // <- e, ee, se
+    s.mix_hash(server_ephemeral.as_bytes());
+    let _cipher_ee = s.mix_dh(&ephemeral_secret, &server_ephemeral)?;
+    let cipher_se = s.mix_dh(machine_private, &server_ephemeral)?;
+    s.decrypt_and_hash(cipher_se, resp_tag)?;
+
+    // Split into transport ciphers
+    let h = s.h;
+    let (c1, c2) = s.split()?;
+
+    tracing::debug!("controlbase handshake complete");
+
+    // Any bytes in resp_data beyond the 51-byte response are the start
+    // of the first Noise transport record (early payload).
+    let leftover = resp_data.split_off(msg_end);
+
+    Ok(HandshakeResult {
+        tx_cipher: c1,
+        rx_cipher: c2,
+        handshake_hash: h,
+        protocol_version: PROTOCOL_VERSION,
+        leftover,
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tokio::net::TcpListener;
+
+    /// Read from `sock` until the end of the HTTP request headers (or EOF).
+    async fn read_http_request(sock: &mut TcpStream) -> Vec<u8> {
+        let mut buf = vec![0u8; 8192];
+        let mut total = 0;
+        loop {
+            let n = sock.read(&mut buf[total..]).await.unwrap();
+            if n == 0 {
+                break;
+            }
+            total += n;
+            if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
+                break;
+            }
+        }
+        buf.truncate(total);
+        buf
+    }
+
+    /// Mock server that implements the controlbase server-side handshake.
+    async fn mock_controlbase_server(
+        mut sock: TcpStream,
+        server_private: StaticSecret,
+        server_public: PublicKey,
+    ) {
+        // Read the full HTTP request.
+        let request = read_http_request(&mut sock).await;
+
+        // Extract the X-Tailscale-Handshake header.
+        let req_str = String::from_utf8_lossy(&request);
+        let init_b64 = req_str
+            .lines()
+            .find(|l| l.starts_with("X-Tailscale-Handshake:"))
+            .map(|l| l.trim_start_matches("X-Tailscale-Handshake:").trim())
+            .expect("missing X-Tailscale-Handshake header");
+        let init_msg = base64::engine::general_purpose::STANDARD
+            .decode(init_b64)
+            .expect("bad base64");
+
+        assert_eq!(init_msg.len(), 101, "initiation message should be 101 bytes");
+
+        // Parse initiation
+        let _version = u16::from_be_bytes([init_msg[0], init_msg[1]]);
+        assert_eq!(init_msg[2], MSG_TYPE_INITIATION);
+        let _payload_len = u16::from_be_bytes([init_msg[3], init_msg[4]]);
+
+        let client_ephemeral_bytes: [u8; 32] = init_msg[5..37].try_into().unwrap();
+        let client_ephemeral = PublicKey::from(client_ephemeral_bytes);
+        let encrypted_machine_pub = &init_msg[37..85];
+        let init_tag = &init_msg[85..101];
+
+        // Server-side symmetric state
+        let mut s = SymmetricState::initialize();
+        s.mix_hash(&protocol_version_prologue(PROTOCOL_VERSION));
+        s.mix_hash(server_public.as_bytes());
+
+        // Process -> e
+        s.mix_hash(client_ephemeral.as_bytes());
+
+        // -> es: MixDH(server_private, client_ephemeral)
+        let cipher_es = s.mix_dh(&server_private, &client_ephemeral).unwrap();
+
+        // -> s: DecryptAndHash to get client machine public key
+        let client_machine_pub_bytes = s.decrypt_and_hash(cipher_es, encrypted_machine_pub).unwrap();
+        assert_eq!(client_machine_pub_bytes.len(), 32);
+        let client_machine_pub = PublicKey::from(<[u8; 32]>::try_from(&client_machine_pub_bytes[..]).unwrap());
+
+        // -> ss: MixDH(server_private, client_machine_pub)
+        let cipher_ss = s.mix_dh(&server_private, &client_machine_pub).unwrap();
+        s.decrypt_and_hash(cipher_ss, init_tag).unwrap();
+
+        // Send 101 response.
+        sock.write_all(
+            b"HTTP/1.1 101 Switching Protocols\r\n\
+              Connection: Upgrade\r\n\
+              Upgrade: tailscale-control-protocol\r\n\
+              \r\n",
+        )
+        .await
+        .unwrap();
+
+        // <- e (server ephemeral)
+        let server_ephemeral_secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
+        let server_ephemeral_public = PublicKey::from(&server_ephemeral_secret);
+
+        s.mix_hash(server_ephemeral_public.as_bytes());
+
+        // <- ee: MixDH(server_ephemeral_private, client_ephemeral)
+        let _cipher_ee = s.mix_dh(&server_ephemeral_secret, &client_ephemeral).unwrap();
+
+        // <- se: MixDH(server_ephemeral_private, client_machine_pub)
+        let cipher_se = s.mix_dh(&server_ephemeral_secret, &client_machine_pub).unwrap();
+        let tag = s.encrypt_and_hash(cipher_se, &[]);
+
+        // Build response message (51 bytes)
+        let mut resp = vec![0u8; 51];
+        resp[0] = MSG_TYPE_RESPONSE;
+        resp[1..3].copy_from_slice(&48u16.to_be_bytes());
+        resp[3..35].copy_from_slice(server_ephemeral_public.as_bytes());
+        resp[35..51].copy_from_slice(&tag);
+
+        sock.write_all(&resp).await.unwrap();
+
+        // Split for transport
+        let (_c1, _c2) = s.split().unwrap();
+    }
+
+    #[tokio::test]
+    async fn handshake_with_mock_server() {
+        // Generate server keys.
+        let server_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
+        let server_public = PublicKey::from(&server_private);
+
+        // Generate client keys.
+        let machine_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
+        let machine_public = PublicKey::from(&machine_private);
+
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        let sp = server_public;
+        let spriv = server_private.clone();
+
+        // Server task.
+        let server_handle = tokio::spawn(async move {
+            let (sock, _) = listener.accept().await.unwrap();
+            mock_controlbase_server(sock, spriv, sp).await;
+        });
+
+        // Client.
+        let mut client_stream = TcpStream::connect(addr).await.unwrap();
+        let result = perform_handshake(
+            &mut client_stream,
+            &machine_private,
+            &machine_public,
+            &server_public,
+        )
+        .await
+        .unwrap();
+
+        // Verify we got transport ciphers and can encrypt.
+        let nonce = Nonce::default();
+        let mut data = b"hello".to_vec();
+        result
+            .tx_cipher
+            .encrypt_in_place_detached(&nonce, &[], &mut data)
+            .unwrap();
+        // Ciphertext should differ from plaintext.
+        assert_ne!(&data[..], b"hello");
+
+        assert_eq!(result.protocol_version, 49);
+        assert!(result.leftover.is_empty());
+
+        server_handle.await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn short_server_error_is_reported() {
+        let server_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
+        let server_public = PublicKey::from(&server_private);
+        let machine_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
+        let machine_public = PublicKey::from(&machine_private);
+
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        let server_handle = tokio::spawn(async move {
+            let (mut sock, _) = listener.accept().await.unwrap();
+            let _request = read_http_request(&mut sock).await;
+            sock.write_all(b"HTTP/1.1 101 Switching Protocols\r\n\r\n")
+                .await
+                .unwrap();
+            // Error message whose payload is far shorter than a 48-byte response.
+            sock.write_all(&[MSG_TYPE_ERROR, 0, 4, b'n', b'o', b'p', b'e'])
+                .await
+                .unwrap();
+        });
+
+        let mut client_stream = TcpStream::connect(addr).await.unwrap();
+        let outcome = perform_handshake(
+            &mut client_stream,
+            &machine_private,
+            &machine_public,
+            &server_public,
+        )
+        .await;
+
+        match outcome {
+            Err(crate::Error::Noise(msg)) => {
+                assert!(msg.contains("server error: nope"), "unexpected message: {msg}");
+            }
+            Err(other) => panic!("expected a server error, got {other:?}"),
+            Ok(_) => panic!("handshake must not succeed"),
+        }
+
+        server_handle.await.unwrap();
+    }
+}
diff --git a/sunbeam-net/src/noise/mod.rs b/sunbeam-net/src/noise/mod.rs
new file mode 100644
index 00000000..992a7741
--- /dev/null
+++ b/sunbeam-net/src/noise/mod.rs
@@ -0,0 +1,3 @@
+pub mod framing;
+pub mod handshake;
+pub mod stream;
diff --git a/sunbeam-net/src/noise/stream.rs b/sunbeam-net/src/noise/stream.rs
new file mode 100644
index 00000000..86b108e6
--- /dev/null
+++ b/sunbeam-net/src/noise/stream.rs
@@ -0,0 +1,405 @@
+use std::pin::Pin;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::task::{Context, Poll};
+
+use bytes::BytesMut;
+use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, Nonce, Tag};
+use futures::stream::StreamExt;
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio::net::TcpStream;
+use tokio_util::codec::Framed;
+
+use super::framing::{NoiseFrame, NoiseFrameCodec, FRAME_TYPE_DATA, MAX_FRAME_SIZE};
+
+/// AEAD tag size for ChaCha20Poly1305.
+const TAG_SIZE: usize = 16;
+
+/// Maximum plaintext chunk size, leaving room for the AEAD tag (16 bytes).
+const MAX_PLAINTEXT_CHUNK: usize = MAX_FRAME_SIZE - TAG_SIZE;
+
+/// An encrypted, framed stream using the controlbase protocol over TCP.
+///
+/// Implements `AsyncRead + AsyncWrite` so it can be used transparently by `h2`
+/// and other async I/O consumers.
+///
+/// Writes are encrypted into an internal buffer by `poll_write` and reach the
+/// socket on `poll_flush`, on `poll_shutdown`, or once the buffer grows large.
+pub struct NoiseStream {
+    inner: Framed<TcpStream, NoiseFrameCodec>,
+    tx_cipher: ChaCha20Poly1305,
+    rx_cipher: ChaCha20Poly1305,
+    read_nonce: AtomicU64,
+    write_nonce: AtomicU64,
+    read_buf: BytesMut,
+    // Encrypted frames ready to be sent, accumulated during poll_write.
+    pending_write: BytesMut,
+    // Whether we have an error from a previous operation.
+    write_err: Option<std::io::Error>,
+}
+
+/// Magic prefix that marks the first record as an early payload.
+const EARLY_PAYLOAD_MAGIC: &[u8] = b"\xff\xff\xffTS";
+/// Early payload header: the 5-byte magic plus a 4-byte big-endian JSON length.
+const EARLY_PAYLOAD_HEADER_LEN: usize = 9;
+/// Local sanity bound on the early payload JSON; the length is server-controlled.
+const MAX_EARLY_PAYLOAD_LEN: usize = 1 << 20;
+/// Once this much ciphertext is buffered, `poll_write` drains before accepting more.
+const WRITE_HIGH_WATER: usize = 64 * 1024;
+
+// `Send` is derived automatically from the fields; checking it here keeps the
+// guarantee without an `unsafe impl` that could hide a future non-`Send` field.
+const _: fn() = || {
+    fn assert_send<T: Send>() {}
+    assert_send::<NoiseStream>();
+};
+
+impl Unpin for NoiseStream {}
+
+impl std::fmt::Debug for NoiseStream {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("NoiseStream")
+            .field("read_nonce", &self.read_nonce.load(Ordering::Relaxed))
+            .field("write_nonce", &self.write_nonce.load(Ordering::Relaxed))
+            .finish()
+    }
+}
+
+/// Build a 12-byte nonce: first 4 bytes zero, last 8 bytes big-endian u64 counter.
+fn build_nonce(counter: u64) -> Nonce {
+    let mut nonce = [0u8; 12];
+    nonce[4..12].copy_from_slice(&counter.to_be_bytes());
+    Nonce::from(nonce)
+}
+
+impl NoiseStream {
+    /// Wrap a TCP stream with controlbase encryption using the given transport ciphers.
+    /// `leftover` contains any bytes already read from TCP that belong to the first
+    /// Noise transport record (from the handshake buffer overflow).
+    pub fn new(
+        stream: TcpStream,
+        tx_cipher: ChaCha20Poly1305,
+        rx_cipher: ChaCha20Poly1305,
+        leftover: Vec<u8>,
+    ) -> Self {
+        let mut framed = Framed::new(stream, NoiseFrameCodec);
+        // Prepend leftover bytes so the codec can parse them as part of the first frame.
+        if !leftover.is_empty() {
+            framed.read_buffer_mut().extend_from_slice(&leftover);
+        }
+        Self {
+            inner: framed,
+            tx_cipher,
+            rx_cipher,
+            read_nonce: AtomicU64::new(0),
+            write_nonce: AtomicU64::new(0),
+            read_buf: BytesMut::with_capacity(MAX_FRAME_SIZE),
+            pending_write: BytesMut::new(),
+            write_err: None,
+        }
+    }
+
+    /// Read and consume the early payload sent by Headscale after the handshake.
+    ///
+    /// The early payload consists of multiple Noise records:
+    /// 1. 5 bytes: magic `\xff\xff\xffTS`
+    /// 2. 4 bytes: BE u32 JSON length
+    /// 3. N bytes: JSON payload
+    ///
+    /// Each part may be a separate Noise record. We read and reassemble them.
+    /// If the first record is not an early payload (i.e. it's an HTTP/2 frame),
+    /// we buffer it for h2 to read.
+    ///
+    /// Returns the early payload JSON bytes, or an empty vec if no early payload.
+    ///
+    /// # Errors
+    ///
+    /// Fails on transport or decryption errors, if the stream ends in the middle
+    /// of the payload, or if the announced JSON length exceeds the local limit.
+    pub async fn consume_early_payload(&mut self) -> std::io::Result<Vec<u8>> {
+        let mut accum = match self.next_plaintext().await? {
+            Some(plaintext) => plaintext,
+            None => return Ok(Vec::new()),
+        };
+
+        if !accum.starts_with(EARLY_PAYLOAD_MAGIC) {
+            // Not an early payload — buffer for h2
+            self.read_buf.extend_from_slice(&accum);
+            return Ok(Vec::new());
+        }
+
+        // Read more frames until we have the 4-byte length
+        self.fill_plaintext(&mut accum, EARLY_PAYLOAD_HEADER_LEN, "early payload truncated")
+            .await?;
+        let json_len = u32::from_be_bytes([accum[5], accum[6], accum[7], accum[8]]) as usize;
+        if json_len > MAX_EARLY_PAYLOAD_LEN {
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                format!("early payload length {json_len} exceeds maximum {MAX_EARLY_PAYLOAD_LEN}"),
+            ));
+        }
+
+        // Read more frames until we have the full JSON
+        let json_end = EARLY_PAYLOAD_HEADER_LEN + json_len;
+        self.fill_plaintext(&mut accum, json_end, "early payload JSON truncated")
+            .await?;
+
+        let json_bytes = accum[EARLY_PAYLOAD_HEADER_LEN..json_end].to_vec();
+        tracing::debug!("consumed early payload ({} bytes)", json_bytes.len());
+
+        // If there's any leftover data after the early payload, buffer it for h2
+        self.read_buf.extend_from_slice(&accum[json_end..]);
+
+        Ok(json_bytes)
+    }
+
+    /// Wait for the next data record and return its plaintext (`None` at EOF).
+    ///
+    /// Non-data frames are skipped, matching what `poll_read` does.
+    async fn next_plaintext(&mut self) -> std::io::Result<Option<Vec<u8>>> {
+        loop {
+            match self.inner.next().await {
+                Some(Ok(frame)) if frame.frame_type == FRAME_TYPE_DATA => {
+                    return self.decrypt_frame(&frame).map(Some);
+                }
+                Some(Ok(_)) => continue,
+                Some(Err(e)) => return Err(e),
+                None => return Ok(None),
+            }
+        }
+    }
+
+    /// Append decrypted records to `accum` until it holds at least `want` bytes.
+    ///
+    /// `eof_msg` is the error text used if the stream ends first.
+    async fn fill_plaintext(
+        &mut self,
+        accum: &mut Vec<u8>,
+        want: usize,
+        eof_msg: &'static str,
+    ) -> std::io::Result<()> {
+        while accum.len() < want {
+            match self.next_plaintext().await? {
+                Some(plaintext) => accum.extend_from_slice(&plaintext),
+                None => {
+                    return Err(std::io::Error::new(
+                        std::io::ErrorKind::UnexpectedEof,
+                        eof_msg,
+                    ));
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Decrypt a data frame and return plaintext.
+    ///
+    /// Consumes one receive nonce per call.
+    fn decrypt_frame(&self, frame: &NoiseFrame) -> std::io::Result<Vec<u8>> {
+        let counter = self.read_nonce.fetch_add(1, Ordering::Relaxed);
+        let nonce = build_nonce(counter);
+        tracing::trace!(
+            "decrypt_frame: type={:02x}, payload_len={}, nonce={}",
+            frame.frame_type,
+            frame.payload.len(),
+            counter
+        );
+
+        if frame.payload.len() < TAG_SIZE {
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::InvalidData,
+                "frame too short for AEAD tag",
+            ));
+        }
+
+        let (ct, tag_bytes) = frame.payload.split_at(frame.payload.len() - TAG_SIZE);
+        let tag = Tag::from_slice(tag_bytes);
+        let mut plaintext = ct.to_vec();
+        self.rx_cipher
+            .decrypt_in_place_detached(&nonce, &[], &mut plaintext, tag)
+            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
+        Ok(plaintext)
+    }
+
+    /// Encrypt a chunk and produce a NoiseFrame.
+    ///
+    /// Consumes one send nonce per call.
+    fn encrypt_chunk(&self, plaintext: &[u8]) -> std::io::Result<NoiseFrame> {
+        let counter = self.write_nonce.fetch_add(1, Ordering::Relaxed);
+        let nonce = build_nonce(counter);
+
+        let mut ciphertext = plaintext.to_vec();
+        let tag = self
+            .tx_cipher
+            .encrypt_in_place_detached(&nonce, &[], &mut ciphertext)
+            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
+        ciphertext.extend_from_slice(tag.as_slice());
+
+        Ok(NoiseFrame {
+            frame_type: FRAME_TYPE_DATA,
+            payload: BytesMut::from(&ciphertext[..]),
+        })
+    }
+
+    /// Write buffered ciphertext to the socket until it is empty or the socket
+    /// is not ready.
+    fn poll_drain(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        while !self.pending_write.is_empty() {
+            let n = match Pin::new(self.inner.get_mut()).poll_write(cx, &self.pending_write) {
+                Poll::Ready(Ok(n)) => n,
+                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
+                Poll::Pending => return Poll::Pending,
+            };
+            if n == 0 {
+                // A zero-length write would otherwise spin this loop forever.
+                return Poll::Ready(Err(std::io::Error::new(
+                    std::io::ErrorKind::WriteZero,
+                    "failed to write encrypted frame to socket",
+                )));
+            }
+            let _ = self.pending_write.split_to(n);
+        }
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl AsyncRead for NoiseStream {
+    /// Hand out buffered plaintext, decrypting further records as needed.
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let this = self.get_mut();
+
+        // Filling nothing into a non-empty buffer means EOF to the caller, so
+        // only return without data when the caller offered no space at all.
+        if buf.remaining() == 0 {
+            return Poll::Ready(Ok(()));
+        }
+
+        loop {
+            // If we have buffered plaintext, return it.
+            if !this.read_buf.is_empty() {
+                let to_copy = std::cmp::min(this.read_buf.len(), buf.remaining());
+                buf.put_slice(&this.read_buf[..to_copy]);
+                let _ = this.read_buf.split_to(to_copy);
+                return Poll::Ready(Ok(()));
+            }
+
+            // Poll the inner framed stream for the next frame.
+            let frame = match this.inner.poll_next_unpin(cx) {
+                Poll::Ready(Some(Ok(frame))) => frame,
+                Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
+                Poll::Ready(None) => return Poll::Ready(Ok(())), // EOF
+                Poll::Pending => return Poll::Pending,
+            };
+
+            // Skip non-data frames and go straight to the next one.
+            if frame.frame_type != FRAME_TYPE_DATA {
+                continue;
+            }
+
+            // An empty record leaves `read_buf` empty, so the loop simply moves
+            // on instead of reporting a zero-byte read (which would look like EOF).
+            match this.decrypt_frame(&frame) {
+                Ok(plaintext) => this.read_buf.extend_from_slice(&plaintext),
+                Err(e) => return Poll::Ready(Err(e)),
+            }
+        }
+    }
+}
+
+impl AsyncWrite for NoiseStream {
+    /// Encrypt `buf` into frames and queue them for the socket.
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        let this = self.get_mut();
+
+        if let Some(e) = this.write_err.take() {
+            return Poll::Ready(Err(e));
+        }
+
+        // Backpressure: without this a caller that writes faster than it
+        // flushes would grow `pending_write` without bound.
+        if this.pending_write.len() >= WRITE_HIGH_WATER {
+            match this.poll_drain(cx) {
+                Poll::Ready(Ok(())) => {}
+                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+
+        // Encrypt in chunks and buffer the frames.
+        let mut written = 0;
+        for chunk in buf.chunks(MAX_PLAINTEXT_CHUNK) {
+            match this.encrypt_chunk(chunk) {
+                Ok(frame) => {
+                    // Manually serialize the frame into pending_write.
+                    let len = frame.payload.len() as u16;
+                    this.pending_write.extend_from_slice(&[
+                        frame.frame_type,
+                        (len >> 8) as u8,
+                        len as u8,
+                    ]);
+                    this.pending_write.extend_from_slice(&frame.payload);
+                    written += chunk.len();
+                }
+                Err(e) => {
+                    if written > 0 {
+                        // Report what we encrypted so far; save error for next call.
+                        this.write_err = Some(e);
+                        return Poll::Ready(Ok(written));
+                    }
+                    return Poll::Ready(Err(e));
+                }
+            }
+        }
+
+        Poll::Ready(Ok(written))
+    }
+
+    /// Send all queued frames, then flush the underlying stream.
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        let this = self.get_mut();
+
+        // Flush pending_write bytes through the underlying stream.
+        match this.poll_drain(cx) {
+            Poll::Ready(Ok(())) => {}
+            other => return other,
+        }
+
+        Pin::new(this.inner.get_mut()).poll_flush(cx)
+    }
+
+    /// Send all queued frames, then shut down the underlying stream.
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        let this = self.get_mut();
+
+        // Frames still queued would otherwise be dropped silently on shutdown.
+        match this.poll_drain(cx) {
+            Poll::Ready(Ok(())) => {}
+            other => return other,
+        }
+
+        Pin::new(this.inner.get_mut()).poll_shutdown(cx)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn max_plaintext_chunk_leaves_room_for_tag() {
+        assert!(MAX_PLAINTEXT_CHUNK + 16 <= MAX_FRAME_SIZE);
+    }
+
+    #[test]
+    fn nonce_counter_is_big_endian() {
+        let nonce = build_nonce(1);
+        assert_eq!(nonce.as_slice(), &[0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1][..]);
+    }
+}