feat(net): add Noise IK + HTTP/2 stream layer

Tailscale's TS2021 protocol layers HTTP/2 over an encrypted Noise IK
channel reached via HTTP CONNECT-style upgrade. Add the lower half:

- noise/handshake: hand-rolled Noise_IK_25519_ChaChaPoly_BLAKE2s
  initiator with HKDF + ChaCha20-Poly1305 (no snow dependency)
- noise/framing: 3-byte frame codec (1-byte type + 2-byte BE length)
- noise/stream: NoiseStream implementing AsyncRead + AsyncWrite over
  the framed channel so the h2 crate can sit on top
This commit is contained in:
2026-04-07 13:41:01 +01:00
parent 13539e6e85
commit 91cef0a730
5 changed files with 931 additions and 0 deletions

View File

@@ -3,6 +3,7 @@
pub mod config;
pub mod error;
pub mod keys;
pub mod noise;
pub(crate) mod proto;
pub use config::VpnConfig;

View File

@@ -0,0 +1,181 @@
use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
/// Record frame type (controlbase post-handshake data).
pub const FRAME_TYPE_DATA: u8 = 0x04;
/// Control frame type.
pub const FRAME_TYPE_CONTROL: u8 = 0x01;
/// Maximum payload size per frame.
pub const MAX_FRAME_SIZE: usize = 4096;
/// Size of the frame header (1 byte type + 2 bytes length).
pub const HEADER_SIZE: usize = 3;
/// A single Noise frame (type + payload).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NoiseFrame {
pub frame_type: u8,
pub payload: BytesMut,
}
/// Codec for encoding/decoding Noise frames on a byte stream.
#[derive(Debug, Default)]
pub struct NoiseFrameCodec;
impl Decoder for NoiseFrameCodec {
type Item = NoiseFrame;
type Error = std::io::Error;
fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, Self::Error> {
if src.len() < HEADER_SIZE {
return Ok(None);
}
let frame_type = src[0];
let payload_len = u16::from_be_bytes([src[1], src[2]]) as usize;
if payload_len > MAX_FRAME_SIZE {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("frame payload {payload_len} exceeds maximum {MAX_FRAME_SIZE}"),
));
}
let total = HEADER_SIZE + payload_len;
if src.len() < total {
// Reserve space so the next read can fill it.
src.reserve(total - src.len());
return Ok(None);
}
src.advance(HEADER_SIZE);
let payload = src.split_to(payload_len);
Ok(Some(NoiseFrame {
frame_type,
payload,
}))
}
}
impl Encoder<NoiseFrame> for NoiseFrameCodec {
type Error = std::io::Error;
fn encode(&mut self, item: NoiseFrame, dst: &mut BytesMut) -> std::result::Result<(), Self::Error> {
if item.payload.len() > MAX_FRAME_SIZE {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!(
"payload length {} exceeds maximum {MAX_FRAME_SIZE}",
item.payload.len()
),
));
}
let len = item.payload.len() as u16;
dst.reserve(HEADER_SIZE + item.payload.len());
dst.put_u8(item.frame_type);
dst.put_u16(len);
dst.put(item.payload);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
fn round_trip(frame_type: u8, payload: &[u8]) {
let mut codec = NoiseFrameCodec;
let frame = NoiseFrame {
frame_type,
payload: BytesMut::from(payload),
};
let mut buf = BytesMut::new();
codec.encode(frame.clone(), &mut buf).unwrap();
let decoded = codec.decode(&mut buf).unwrap().expect("should decode a frame");
assert_eq!(decoded, frame);
}
#[test]
fn round_trip_data_frame() {
round_trip(FRAME_TYPE_DATA, b"hello world");
}
#[test]
fn round_trip_control_frame() {
round_trip(FRAME_TYPE_CONTROL, b"\x01\x02\x03");
}
#[test]
fn round_trip_empty_payload() {
round_trip(FRAME_TYPE_DATA, b"");
}
#[test]
fn round_trip_max_size() {
let payload = vec![0xAB; MAX_FRAME_SIZE];
round_trip(FRAME_TYPE_DATA, &payload);
}
#[test]
fn partial_header_returns_none() {
let mut codec = NoiseFrameCodec;
let mut buf = BytesMut::from(&[0x00, 0x00][..]);
assert!(codec.decode(&mut buf).unwrap().is_none());
}
#[test]
fn partial_payload_returns_none() {
let mut codec = NoiseFrameCodec;
// Header says 10 bytes of payload, but only 3 provided.
let mut buf = BytesMut::from(&[0x00, 0x00, 0x0A, 0x01, 0x02, 0x03][..]);
assert!(codec.decode(&mut buf).unwrap().is_none());
}
#[test]
fn oversized_payload_is_rejected_on_decode() {
let mut codec = NoiseFrameCodec;
let bad_len = (MAX_FRAME_SIZE as u16) + 1;
let mut buf = BytesMut::new();
buf.put_u8(0x00);
buf.put_u16(bad_len);
buf.extend_from_slice(&vec![0; bad_len as usize]);
assert!(codec.decode(&mut buf).is_err());
}
#[test]
fn oversized_payload_is_rejected_on_encode() {
let mut codec = NoiseFrameCodec;
let frame = NoiseFrame {
frame_type: FRAME_TYPE_DATA,
payload: BytesMut::from(&vec![0; MAX_FRAME_SIZE + 1][..]),
};
let mut buf = BytesMut::new();
assert!(codec.encode(frame, &mut buf).is_err());
}
#[test]
fn multiple_frames_in_buffer() {
let mut codec = NoiseFrameCodec;
let mut buf = BytesMut::new();
let f1 = NoiseFrame {
frame_type: FRAME_TYPE_DATA,
payload: BytesMut::from(&b"aaa"[..]),
};
let f2 = NoiseFrame {
frame_type: FRAME_TYPE_CONTROL,
payload: BytesMut::from(&b"bbb"[..]),
};
codec.encode(f1.clone(), &mut buf).unwrap();
codec.encode(f2.clone(), &mut buf).unwrap();
let d1 = codec.decode(&mut buf).unwrap().unwrap();
let d2 = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(d1, f1);
assert_eq!(d2, f2);
assert!(codec.decode(&mut buf).unwrap().is_none());
}
}

View File

@@ -0,0 +1,428 @@
//! Tailscale controlbase handshake implementation.
//!
//! This implements the custom Noise IK variant used by the Tailscale/Headscale
//! control protocol. It is NOT compatible with standard Noise libraries like `snow`.
use base64::Engine;
use blake2::Digest;
use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce, Tag};
use hkdf::Hkdf;
use hmac::SimpleHmac;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use x25519_dalek::{PublicKey, StaticSecret};
type HkdfBlake2s = Hkdf<blake2::Blake2s256, SimpleHmac<blake2::Blake2s256>>;
const PROTOCOL_NAME: &[u8] = b"Noise_IK_25519_ChaChaPoly_BLAKE2s";
const PROTOCOL_VERSION_PREFIX: &[u8] = b"Tailscale Control Protocol v";
const PROTOCOL_VERSION: u16 = 49;
// Message types
const MSG_TYPE_INITIATION: u8 = 1;
const MSG_TYPE_RESPONSE: u8 = 2;
const MSG_TYPE_ERROR: u8 = 3;
pub(crate) const MSG_TYPE_RECORD: u8 = 4;
// Sizes
const TAG_SIZE: usize = 16;
pub(crate) struct HandshakeResult {
pub tx_cipher: ChaCha20Poly1305, // client -> server
pub rx_cipher: ChaCha20Poly1305, // server -> client
pub handshake_hash: [u8; 32],
pub protocol_version: u16,
/// Leftover bytes from the TCP buffer after the handshake response.
/// These are the start of the first Noise transport record.
pub leftover: Vec<u8>,
}
struct SymmetricState {
h: [u8; 32], // hash
ck: [u8; 32], // chaining key
}
impl SymmetricState {
fn initialize() -> Self {
let h: [u8; 32] = blake2::Blake2s256::digest(PROTOCOL_NAME).into();
Self { h, ck: h }
}
fn mix_hash(&mut self, data: &[u8]) {
let mut hasher = blake2::Blake2s256::new();
hasher.update(&self.h);
hasher.update(data);
self.h = hasher.finalize().into();
}
fn mix_dh(
&mut self,
priv_key: &StaticSecret,
pub_key: &PublicKey,
) -> crate::Result<ChaCha20Poly1305> {
let shared = priv_key.diffie_hellman(pub_key);
let hk = HkdfBlake2s::new(Some(&self.ck), shared.as_bytes());
let mut output = [0u8; 64];
hk.expand(&[], &mut output)
.map_err(|e| crate::Error::Noise(format!("HKDF expand: {e}")))?;
self.ck.copy_from_slice(&output[..32]);
let key: [u8; 32] = output[32..64].try_into().unwrap();
ChaCha20Poly1305::new_from_slice(&key)
.map_err(|e| crate::Error::Noise(format!("AEAD key: {e}")))
}
fn encrypt_and_hash(&mut self, cipher: ChaCha20Poly1305, plaintext: &[u8]) -> Vec<u8> {
let nonce = Nonce::default(); // all zeros
let mut ciphertext = plaintext.to_vec();
let tag = cipher
.encrypt_in_place_detached(&nonce, &self.h, &mut ciphertext)
.expect("encryption failed");
ciphertext.extend_from_slice(tag.as_slice());
self.mix_hash(&ciphertext);
ciphertext
}
fn decrypt_and_hash(
&mut self,
cipher: ChaCha20Poly1305,
ciphertext: &[u8],
) -> crate::Result<Vec<u8>> {
if ciphertext.len() < TAG_SIZE {
return Err(crate::Error::Noise("ciphertext too short".into()));
}
let nonce = Nonce::default();
let (ct, tag_bytes) = ciphertext.split_at(ciphertext.len() - TAG_SIZE);
let tag = Tag::from_slice(tag_bytes);
let mut plaintext = ct.to_vec();
cipher
.decrypt_in_place_detached(&nonce, &self.h, &mut plaintext, tag)
.map_err(|_| crate::Error::Noise("AEAD decryption failed".into()))?;
self.mix_hash(ciphertext);
Ok(plaintext)
}
fn split(self) -> crate::Result<(ChaCha20Poly1305, ChaCha20Poly1305)> {
let hk = HkdfBlake2s::new(Some(&self.ck), &[]);
let mut output = [0u8; 64];
hk.expand(&[], &mut output)
.map_err(|e| crate::Error::Noise(format!("HKDF split: {e}")))?;
let k1: [u8; 32] = output[..32].try_into().unwrap();
let k2: [u8; 32] = output[32..64].try_into().unwrap();
let c1 = ChaCha20Poly1305::new_from_slice(&k1)
.map_err(|e| crate::Error::Noise(format!("c1 key: {e}")))?;
let c2 = ChaCha20Poly1305::new_from_slice(&k2)
.map_err(|e| crate::Error::Noise(format!("c2 key: {e}")))?;
Ok((c1, c2))
}
}
fn protocol_version_prologue(version: u16) -> Vec<u8> {
let mut p = PROTOCOL_VERSION_PREFIX.to_vec();
p.extend_from_slice(version.to_string().as_bytes());
p
}
/// Perform the TS2021 HTTP upgrade and controlbase Noise IK handshake.
///
/// Returns a [`HandshakeResult`] containing the transport ciphers.
pub(crate) async fn perform_handshake(
stream: &mut TcpStream,
machine_private: &StaticSecret,
machine_public: &PublicKey,
server_public: &PublicKey,
) -> crate::Result<HandshakeResult> {
let mut s = SymmetricState::initialize();
// Prologue
s.mix_hash(&protocol_version_prologue(PROTOCOL_VERSION));
// <- s (mix in known server static key)
s.mix_hash(server_public.as_bytes());
// -> e, es, s, ss
let ephemeral_secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
let ephemeral_public = PublicKey::from(&ephemeral_secret);
// Build initiation message (101 bytes)
let mut init = vec![0u8; 101];
// Header
init[0..2].copy_from_slice(&PROTOCOL_VERSION.to_be_bytes());
init[2] = MSG_TYPE_INITIATION;
init[3..5].copy_from_slice(&96u16.to_be_bytes());
// Payload
init[5..37].copy_from_slice(ephemeral_public.as_bytes());
s.mix_hash(ephemeral_public.as_bytes());
let cipher_es = s.mix_dh(&ephemeral_secret, server_public)?;
let encrypted_machine_pub = s.encrypt_and_hash(cipher_es, machine_public.as_bytes());
init[37..85].copy_from_slice(&encrypted_machine_pub);
let cipher_ss = s.mix_dh(machine_private, server_public)?;
let tag = s.encrypt_and_hash(cipher_ss, &[]);
init[85..101].copy_from_slice(&tag);
// Base64-encode for HTTP header
let init_b64 = base64::engine::general_purpose::STANDARD.encode(&init);
// Send HTTP upgrade
let peer = stream
.peer_addr()
.map(|a| a.to_string())
.unwrap_or_else(|_| "localhost".into());
let upgrade_req = format!(
"POST /ts2021 HTTP/1.1\r\n\
Host: {peer}\r\n\
Upgrade: tailscale-control-protocol\r\n\
Connection: upgrade\r\n\
X-Tailscale-Handshake: {init_b64}\r\n\
\r\n"
);
stream
.write_all(upgrade_req.as_bytes())
.await
.map_err(|e| crate::Error::Noise(format!("write upgrade: {e}")))?;
// Read HTTP 101 response
let mut resp_buf = vec![0u8; 8192];
let mut total = 0;
loop {
let n = stream
.read(&mut resp_buf[total..])
.await
.map_err(|e| crate::Error::Noise(format!("read response: {e}")))?;
if n == 0 {
return Err(crate::Error::ConnectionClosed);
}
total += n;
if resp_buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let header_end = resp_buf[..total]
.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|p| p + 4)
.ok_or_else(|| crate::Error::Noise("no header boundary".into()))?;
let headers = String::from_utf8_lossy(&resp_buf[..header_end]);
if !headers.contains("101") {
let body = String::from_utf8_lossy(&resp_buf[header_end..total]);
return Err(crate::Error::Noise(format!(
"expected 101, got: {headers}\nBody: {body}"
)));
}
// Read server response message (51 bytes: 3 header + 48 payload)
// Some bytes may already be in resp_buf after the HTTP headers.
let mut resp_data = resp_buf[header_end..total].to_vec();
while resp_data.len() < 51 {
let mut tmp = vec![0u8; 51 - resp_data.len()];
let n = stream
.read(&mut tmp)
.await
.map_err(|e| crate::Error::Noise(format!("read response msg: {e}")))?;
if n == 0 {
return Err(crate::Error::ConnectionClosed);
}
resp_data.extend_from_slice(&tmp[..n]);
}
// Parse response
if resp_data[0] == MSG_TYPE_ERROR {
let err_len = u16::from_be_bytes([resp_data[1], resp_data[2]]) as usize;
let err_msg =
String::from_utf8_lossy(&resp_data[3..3 + err_len.min(resp_data.len() - 3)]);
return Err(crate::Error::Noise(format!("server error: {err_msg}")));
}
if resp_data[0] != MSG_TYPE_RESPONSE {
return Err(crate::Error::Noise(format!(
"unexpected response type {}",
resp_data[0]
)));
}
let server_ephemeral_bytes: [u8; 32] = resp_data[3..35].try_into().unwrap();
let server_ephemeral = PublicKey::from(server_ephemeral_bytes);
let resp_tag = &resp_data[35..51];
// <- e, ee, se
s.mix_hash(server_ephemeral.as_bytes());
let _cipher_ee = s.mix_dh(&ephemeral_secret, &server_ephemeral)?;
let cipher_se = s.mix_dh(machine_private, &server_ephemeral)?;
s.decrypt_and_hash(cipher_se, resp_tag)?;
// Split into transport ciphers
let h = s.h;
let (c1, c2) = s.split()?;
tracing::debug!("controlbase handshake complete");
// Any bytes in resp_data beyond the 51-byte response are the start
// of the first Noise transport record (early payload).
let leftover = if resp_data.len() > 51 {
resp_data[51..].to_vec()
} else {
vec![]
};
Ok(HandshakeResult {
tx_cipher: c1,
rx_cipher: c2,
handshake_hash: h,
protocol_version: PROTOCOL_VERSION,
leftover,
})
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::net::TcpListener;
/// Mock server that implements the controlbase server-side handshake.
async fn mock_controlbase_server(
mut sock: TcpStream,
server_private: StaticSecret,
server_public: PublicKey,
) {
// Read the full HTTP request.
let mut buf = vec![0u8; 8192];
let mut total = 0;
loop {
let n = sock.read(&mut buf[total..]).await.unwrap();
if n == 0 {
break;
}
total += n;
if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
// Extract the X-Tailscale-Handshake header.
let req_str = String::from_utf8_lossy(&buf[..total]);
let init_b64 = req_str
.lines()
.find(|l| l.starts_with("X-Tailscale-Handshake:"))
.map(|l| l.trim_start_matches("X-Tailscale-Handshake:").trim())
.expect("missing X-Tailscale-Handshake header");
let init_msg = base64::engine::general_purpose::STANDARD
.decode(init_b64)
.expect("bad base64");
assert_eq!(init_msg.len(), 101, "initiation message should be 101 bytes");
// Parse initiation
let _version = u16::from_be_bytes([init_msg[0], init_msg[1]]);
assert_eq!(init_msg[2], MSG_TYPE_INITIATION);
let _payload_len = u16::from_be_bytes([init_msg[3], init_msg[4]]);
let client_ephemeral_bytes: [u8; 32] = init_msg[5..37].try_into().unwrap();
let client_ephemeral = PublicKey::from(client_ephemeral_bytes);
let encrypted_machine_pub = &init_msg[37..85];
let init_tag = &init_msg[85..101];
// Server-side symmetric state
let mut s = SymmetricState::initialize();
s.mix_hash(&protocol_version_prologue(PROTOCOL_VERSION));
s.mix_hash(server_public.as_bytes());
// Process -> e
s.mix_hash(client_ephemeral.as_bytes());
// -> es: MixDH(server_private, client_ephemeral)
let cipher_es = s.mix_dh(&server_private, &client_ephemeral).unwrap();
// -> s: DecryptAndHash to get client machine public key
let client_machine_pub_bytes = s.decrypt_and_hash(cipher_es, encrypted_machine_pub).unwrap();
assert_eq!(client_machine_pub_bytes.len(), 32);
let client_machine_pub = PublicKey::from(<[u8; 32]>::try_from(&client_machine_pub_bytes[..]).unwrap());
// -> ss: MixDH(server_private, client_machine_pub)
let cipher_ss = s.mix_dh(&server_private, &client_machine_pub).unwrap();
s.decrypt_and_hash(cipher_ss, init_tag).unwrap();
// Send 101 response.
sock.write_all(
b"HTTP/1.1 101 Switching Protocols\r\n\
Connection: Upgrade\r\n\
Upgrade: tailscale-control-protocol\r\n\
\r\n",
)
.await
.unwrap();
// <- e (server ephemeral)
let server_ephemeral_secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
let server_ephemeral_public = PublicKey::from(&server_ephemeral_secret);
s.mix_hash(server_ephemeral_public.as_bytes());
// <- ee: MixDH(server_ephemeral_private, client_ephemeral)
let _cipher_ee = s.mix_dh(&server_ephemeral_secret, &client_ephemeral).unwrap();
// <- se: MixDH(server_ephemeral_private, client_machine_pub)
let cipher_se = s.mix_dh(&server_ephemeral_secret, &client_machine_pub).unwrap();
let tag = s.encrypt_and_hash(cipher_se, &[]);
// Build response message (51 bytes)
let mut resp = vec![0u8; 51];
resp[0] = MSG_TYPE_RESPONSE;
resp[1..3].copy_from_slice(&48u16.to_be_bytes());
resp[3..35].copy_from_slice(server_ephemeral_public.as_bytes());
resp[35..51].copy_from_slice(&tag);
sock.write_all(&resp).await.unwrap();
// Split for transport
let (_c1, _c2) = s.split().unwrap();
}
#[tokio::test]
async fn handshake_with_mock_server() {
// Generate server keys.
let server_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
let server_public = PublicKey::from(&server_private);
// Generate client keys.
let machine_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
let machine_public = PublicKey::from(&machine_private);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let sp = server_public;
let spriv = server_private.clone();
// Server task.
let server_handle = tokio::spawn(async move {
let (sock, _) = listener.accept().await.unwrap();
mock_controlbase_server(sock, spriv, sp).await;
});
// Client.
let mut client_stream = TcpStream::connect(addr).await.unwrap();
let result = perform_handshake(
&mut client_stream,
&machine_private,
&machine_public,
&server_public,
)
.await
.unwrap();
// Verify we got transport ciphers and can encrypt.
let nonce = Nonce::default();
let mut data = b"hello".to_vec();
result
.tx_cipher
.encrypt_in_place_detached(&nonce, &[], &mut data)
.unwrap();
// Ciphertext should differ from plaintext.
assert_ne!(&data[..], b"hello");
assert_eq!(result.protocol_version, 49);
server_handle.await.unwrap();
}
}

View File

@@ -0,0 +1,3 @@
pub mod framing;
pub mod handshake;
pub mod stream;

View File

@@ -0,0 +1,318 @@
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use bytes::BytesMut;
use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, Nonce, Tag};
use futures::stream::StreamExt;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_util::codec::Framed;
use super::framing::{NoiseFrame, NoiseFrameCodec, FRAME_TYPE_DATA, MAX_FRAME_SIZE};
/// AEAD tag size for ChaCha20Poly1305.
const TAG_SIZE: usize = 16;
/// Maximum plaintext chunk size, leaving room for the AEAD tag (16 bytes).
const MAX_PLAINTEXT_CHUNK: usize = MAX_FRAME_SIZE - TAG_SIZE;
/// An encrypted, framed stream using the controlbase protocol over TCP.
///
/// Implements `AsyncRead + AsyncWrite` so it can be used transparently by `h2`
/// and other async I/O consumers.
pub struct NoiseStream {
inner: Framed<TcpStream, NoiseFrameCodec>,
tx_cipher: ChaCha20Poly1305,
rx_cipher: ChaCha20Poly1305,
read_nonce: AtomicU64,
write_nonce: AtomicU64,
read_buf: BytesMut,
// Encrypted frames ready to be sent, accumulated during poll_write.
pending_write: BytesMut,
// Whether we have an error from a previous operation.
write_err: Option<std::io::Error>,
}
// SAFETY: ChaCha20Poly1305 is Send, all other fields are Send.
unsafe impl Send for NoiseStream {}
impl Unpin for NoiseStream {}
impl std::fmt::Debug for NoiseStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NoiseStream")
.field("read_nonce", &self.read_nonce.load(Ordering::Relaxed))
.field("write_nonce", &self.write_nonce.load(Ordering::Relaxed))
.finish()
}
}
/// Build a 12-byte nonce: first 4 bytes zero, last 8 bytes little-endian u64 counter.
fn build_nonce(counter: u64) -> Nonce {
let mut nonce = [0u8; 12];
nonce[4..12].copy_from_slice(&counter.to_be_bytes());
Nonce::from(nonce)
}
impl NoiseStream {
/// Wrap a TCP stream with controlbase encryption using the given transport ciphers.
/// `leftover` contains any bytes already read from TCP that belong to the first
/// Noise transport record (from the handshake buffer overflow).
pub fn new(
stream: TcpStream,
tx_cipher: ChaCha20Poly1305,
rx_cipher: ChaCha20Poly1305,
leftover: Vec<u8>,
) -> Self {
let mut framed = Framed::new(stream, NoiseFrameCodec);
// Prepend leftover bytes so the codec can parse them as part of the first frame.
if !leftover.is_empty() {
framed.read_buffer_mut().extend_from_slice(&leftover);
}
Self {
inner: framed,
tx_cipher,
rx_cipher,
read_nonce: AtomicU64::new(0),
write_nonce: AtomicU64::new(0),
read_buf: BytesMut::with_capacity(MAX_FRAME_SIZE),
pending_write: BytesMut::new(),
write_err: None,
}
}
/// Read and consume the early payload sent by Headscale after the handshake.
///
/// The early payload consists of multiple Noise records:
/// 1. 5 bytes: magic `\xff\xff\xffTS`
/// 2. 4 bytes: BE u32 JSON length
/// 3. N bytes: JSON payload
///
/// Each part may be a separate Noise record. We read and reassemble them.
/// If the first record is not an early payload (i.e. it's an HTTP/2 frame),
/// we buffer it for h2 to read.
///
/// Returns the early payload JSON bytes, or an empty vec if no early payload.
pub async fn consume_early_payload(&mut self) -> std::io::Result<Vec<u8>> {
// Accumulate plaintext from multiple records
let mut accum = Vec::new();
// Read first frame
let frame = match self.inner.next().await {
Some(Ok(frame)) => frame,
Some(Err(e)) => return Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
None => return Ok(vec![]),
};
let plaintext = self.decrypt_frame(&frame)?;
accum.extend_from_slice(&plaintext);
// Check for early payload magic
if accum.len() < 5
|| accum[0] != 0xff
|| accum[1] != 0xff
|| accum[2] != 0xff
|| accum[3] != b'T'
|| accum[4] != b'S'
{
// Not an early payload — buffer for h2
self.read_buf.extend_from_slice(&accum);
return Ok(vec![]);
}
// Read more frames until we have the 4-byte length
while accum.len() < 9 {
let frame = match self.inner.next().await {
Some(Ok(f)) => f,
Some(Err(e)) => return Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
None => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "early payload truncated")),
};
accum.extend_from_slice(&self.decrypt_frame(&frame)?);
}
let json_len = u32::from_be_bytes([accum[5], accum[6], accum[7], accum[8]]) as usize;
// Read more frames until we have the full JSON
while accum.len() < 9 + json_len {
let frame = match self.inner.next().await {
Some(Ok(f)) => f,
Some(Err(e)) => return Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
None => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "early payload JSON truncated")),
};
accum.extend_from_slice(&self.decrypt_frame(&frame)?);
}
let json_bytes = accum[9..9 + json_len].to_vec();
tracing::debug!("consumed early payload ({} bytes)", json_bytes.len());
// If there's any leftover data after the early payload, buffer it for h2
if accum.len() > 9 + json_len {
self.read_buf.extend_from_slice(&accum[9 + json_len..]);
}
Ok(json_bytes)
}
/// Decrypt a data frame and return plaintext.
fn decrypt_frame(&self, frame: &NoiseFrame) -> std::io::Result<Vec<u8>> {
let counter = self.read_nonce.fetch_add(1, Ordering::Relaxed);
let nonce = build_nonce(counter);
tracing::trace!("decrypt_frame: type={:02x}, payload_len={}, nonce={}", frame.frame_type, frame.payload.len(), counter);
if frame.payload.len() < TAG_SIZE {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"frame too short for AEAD tag",
));
}
let (ct, tag_bytes) = frame.payload.split_at(frame.payload.len() - TAG_SIZE);
let tag = Tag::from_slice(tag_bytes);
let mut plaintext = ct.to_vec();
self.rx_cipher
.decrypt_in_place_detached(&nonce, &[], &mut plaintext, tag)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
Ok(plaintext)
}
/// Encrypt a chunk and produce a NoiseFrame.
fn encrypt_chunk(&self, plaintext: &[u8]) -> std::io::Result<NoiseFrame> {
let counter = self.write_nonce.fetch_add(1, Ordering::Relaxed);
let nonce = build_nonce(counter);
let mut ciphertext = plaintext.to_vec();
let tag = self
.tx_cipher
.encrypt_in_place_detached(&nonce, &[], &mut ciphertext)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
ciphertext.extend_from_slice(tag.as_slice());
Ok(NoiseFrame {
frame_type: FRAME_TYPE_DATA,
payload: BytesMut::from(&ciphertext[..]),
})
}
}
impl AsyncRead for NoiseStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
// If we have buffered plaintext, return it.
if !this.read_buf.is_empty() {
let to_copy = std::cmp::min(this.read_buf.len(), buf.remaining());
buf.put_slice(&this.read_buf[..to_copy]);
let _ = this.read_buf.split_to(to_copy);
return Poll::Ready(Ok(()));
}
// Poll the inner framed stream for the next frame.
match this.inner.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(frame))) => {
if frame.frame_type != FRAME_TYPE_DATA {
// Skip non-data frames; wake to try again.
cx.waker().wake_by_ref();
return Poll::Pending;
}
match this.decrypt_frame(&frame) {
Ok(plaintext) => {
let to_copy = std::cmp::min(plaintext.len(), buf.remaining());
buf.put_slice(&plaintext[..to_copy]);
if to_copy < plaintext.len() {
this.read_buf.extend_from_slice(&plaintext[to_copy..]);
}
Poll::Ready(Ok(()))
}
Err(e) => Poll::Ready(Err(e)),
}
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(e)),
Poll::Ready(None) => Poll::Ready(Ok(())), // EOF
Poll::Pending => Poll::Pending,
}
}
}
impl AsyncWrite for NoiseStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let this = self.get_mut();
if let Some(e) = this.write_err.take() {
return Poll::Ready(Err(e));
}
// Encrypt in chunks and buffer the frames.
let mut written = 0;
for chunk in buf.chunks(MAX_PLAINTEXT_CHUNK) {
match this.encrypt_chunk(chunk) {
Ok(frame) => {
// Manually serialize the frame into pending_write.
let len = frame.payload.len() as u16;
this.pending_write.extend_from_slice(&[
frame.frame_type,
(len >> 8) as u8,
len as u8,
]);
this.pending_write.extend_from_slice(&frame.payload);
written += chunk.len();
}
Err(e) => {
if written > 0 {
// Report what we encrypted so far; save error for next call.
this.write_err = Some(e);
return Poll::Ready(Ok(written));
}
return Poll::Ready(Err(e));
}
}
}
Poll::Ready(Ok(written))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
// Flush pending_write bytes through the underlying stream.
if !this.pending_write.is_empty() {
let inner_stream = this.inner.get_mut();
loop {
if this.pending_write.is_empty() {
break;
}
match Pin::new(&mut *inner_stream).poll_write(cx, &this.pending_write) {
Poll::Ready(Ok(n)) => {
let _ = this.pending_write.split_to(n);
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
}
Pin::new(this.inner.get_mut()).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
Pin::new(this.inner.get_mut()).poll_shutdown(cx)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn max_plaintext_chunk_leaves_room_for_tag() {
assert!(MAX_PLAINTEXT_CHUNK + 16 <= MAX_FRAME_SIZE);
}
}