feat(net): add Noise IK + HTTP/2 stream layer
Tailscale's TS2021 protocol layers HTTP/2 over an encrypted Noise IK channel reached via HTTP CONNECT-style upgrade. Add the lower half: - noise/handshake: hand-rolled Noise_IK_25519_ChaChaPoly_BLAKE2s initiator with HKDF + ChaCha20-Poly1305 (no snow dependency) - noise/framing: 3-byte frame codec (1-byte type + 2-byte BE length) - noise/stream: NoiseStream implementing AsyncRead + AsyncWrite over the framed channel so the h2 crate can sit on top
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod keys;
|
pub mod keys;
|
||||||
|
pub mod noise;
|
||||||
pub(crate) mod proto;
|
pub(crate) mod proto;
|
||||||
|
|
||||||
pub use config::VpnConfig;
|
pub use config::VpnConfig;
|
||||||
|
|||||||
181
sunbeam-net/src/noise/framing.rs
Normal file
181
sunbeam-net/src/noise/framing.rs
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
use bytes::{Buf, BufMut, BytesMut};
|
||||||
|
use tokio_util::codec::{Decoder, Encoder};
|
||||||
|
|
||||||
|
/// Record frame type (controlbase post-handshake data).
|
||||||
|
pub const FRAME_TYPE_DATA: u8 = 0x04;
|
||||||
|
/// Control frame type.
|
||||||
|
pub const FRAME_TYPE_CONTROL: u8 = 0x01;
|
||||||
|
/// Maximum payload size per frame.
|
||||||
|
pub const MAX_FRAME_SIZE: usize = 4096;
|
||||||
|
/// Size of the frame header (1 byte type + 2 bytes length).
|
||||||
|
pub const HEADER_SIZE: usize = 3;
|
||||||
|
|
||||||
|
/// A single Noise frame (type + payload).
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct NoiseFrame {
|
||||||
|
pub frame_type: u8,
|
||||||
|
pub payload: BytesMut,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Codec for encoding/decoding Noise frames on a byte stream.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct NoiseFrameCodec;
|
||||||
|
|
||||||
|
impl Decoder for NoiseFrameCodec {
|
||||||
|
type Item = NoiseFrame;
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
fn decode(&mut self, src: &mut BytesMut) -> std::result::Result<Option<Self::Item>, Self::Error> {
|
||||||
|
if src.len() < HEADER_SIZE {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let frame_type = src[0];
|
||||||
|
let payload_len = u16::from_be_bytes([src[1], src[2]]) as usize;
|
||||||
|
|
||||||
|
if payload_len > MAX_FRAME_SIZE {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("frame payload {payload_len} exceeds maximum {MAX_FRAME_SIZE}"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let total = HEADER_SIZE + payload_len;
|
||||||
|
if src.len() < total {
|
||||||
|
// Reserve space so the next read can fill it.
|
||||||
|
src.reserve(total - src.len());
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
src.advance(HEADER_SIZE);
|
||||||
|
let payload = src.split_to(payload_len);
|
||||||
|
|
||||||
|
Ok(Some(NoiseFrame {
|
||||||
|
frame_type,
|
||||||
|
payload,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encoder<NoiseFrame> for NoiseFrameCodec {
|
||||||
|
type Error = std::io::Error;
|
||||||
|
|
||||||
|
fn encode(&mut self, item: NoiseFrame, dst: &mut BytesMut) -> std::result::Result<(), Self::Error> {
|
||||||
|
if item.payload.len() > MAX_FRAME_SIZE {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidInput,
|
||||||
|
format!(
|
||||||
|
"payload length {} exceeds maximum {MAX_FRAME_SIZE}",
|
||||||
|
item.payload.len()
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = item.payload.len() as u16;
|
||||||
|
dst.reserve(HEADER_SIZE + item.payload.len());
|
||||||
|
dst.put_u8(item.frame_type);
|
||||||
|
dst.put_u16(len);
|
||||||
|
dst.put(item.payload);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn round_trip(frame_type: u8, payload: &[u8]) {
|
||||||
|
let mut codec = NoiseFrameCodec;
|
||||||
|
let frame = NoiseFrame {
|
||||||
|
frame_type,
|
||||||
|
payload: BytesMut::from(payload),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
codec.encode(frame.clone(), &mut buf).unwrap();
|
||||||
|
|
||||||
|
let decoded = codec.decode(&mut buf).unwrap().expect("should decode a frame");
|
||||||
|
assert_eq!(decoded, frame);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn round_trip_data_frame() {
|
||||||
|
round_trip(FRAME_TYPE_DATA, b"hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn round_trip_control_frame() {
|
||||||
|
round_trip(FRAME_TYPE_CONTROL, b"\x01\x02\x03");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn round_trip_empty_payload() {
|
||||||
|
round_trip(FRAME_TYPE_DATA, b"");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn round_trip_max_size() {
|
||||||
|
let payload = vec![0xAB; MAX_FRAME_SIZE];
|
||||||
|
round_trip(FRAME_TYPE_DATA, &payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn partial_header_returns_none() {
|
||||||
|
let mut codec = NoiseFrameCodec;
|
||||||
|
let mut buf = BytesMut::from(&[0x00, 0x00][..]);
|
||||||
|
assert!(codec.decode(&mut buf).unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn partial_payload_returns_none() {
|
||||||
|
let mut codec = NoiseFrameCodec;
|
||||||
|
// Header says 10 bytes of payload, but only 3 provided.
|
||||||
|
let mut buf = BytesMut::from(&[0x00, 0x00, 0x0A, 0x01, 0x02, 0x03][..]);
|
||||||
|
assert!(codec.decode(&mut buf).unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oversized_payload_is_rejected_on_decode() {
|
||||||
|
let mut codec = NoiseFrameCodec;
|
||||||
|
let bad_len = (MAX_FRAME_SIZE as u16) + 1;
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
buf.put_u8(0x00);
|
||||||
|
buf.put_u16(bad_len);
|
||||||
|
buf.extend_from_slice(&vec![0; bad_len as usize]);
|
||||||
|
assert!(codec.decode(&mut buf).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oversized_payload_is_rejected_on_encode() {
|
||||||
|
let mut codec = NoiseFrameCodec;
|
||||||
|
let frame = NoiseFrame {
|
||||||
|
frame_type: FRAME_TYPE_DATA,
|
||||||
|
payload: BytesMut::from(&vec![0; MAX_FRAME_SIZE + 1][..]),
|
||||||
|
};
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
assert!(codec.encode(frame, &mut buf).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn multiple_frames_in_buffer() {
|
||||||
|
let mut codec = NoiseFrameCodec;
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
|
||||||
|
let f1 = NoiseFrame {
|
||||||
|
frame_type: FRAME_TYPE_DATA,
|
||||||
|
payload: BytesMut::from(&b"aaa"[..]),
|
||||||
|
};
|
||||||
|
let f2 = NoiseFrame {
|
||||||
|
frame_type: FRAME_TYPE_CONTROL,
|
||||||
|
payload: BytesMut::from(&b"bbb"[..]),
|
||||||
|
};
|
||||||
|
codec.encode(f1.clone(), &mut buf).unwrap();
|
||||||
|
codec.encode(f2.clone(), &mut buf).unwrap();
|
||||||
|
|
||||||
|
let d1 = codec.decode(&mut buf).unwrap().unwrap();
|
||||||
|
let d2 = codec.decode(&mut buf).unwrap().unwrap();
|
||||||
|
assert_eq!(d1, f1);
|
||||||
|
assert_eq!(d2, f2);
|
||||||
|
assert!(codec.decode(&mut buf).unwrap().is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
428
sunbeam-net/src/noise/handshake.rs
Normal file
428
sunbeam-net/src/noise/handshake.rs
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
//! Tailscale controlbase handshake implementation.
|
||||||
|
//!
|
||||||
|
//! This implements the custom Noise IK variant used by the Tailscale/Headscale
|
||||||
|
//! control protocol. It is NOT compatible with standard Noise libraries like `snow`.
|
||||||
|
|
||||||
|
use base64::Engine;
|
||||||
|
use blake2::Digest;
|
||||||
|
use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit, Nonce, Tag};
|
||||||
|
use hkdf::Hkdf;
|
||||||
|
use hmac::SimpleHmac;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use x25519_dalek::{PublicKey, StaticSecret};
|
||||||
|
|
||||||
|
type HkdfBlake2s = Hkdf<blake2::Blake2s256, SimpleHmac<blake2::Blake2s256>>;
|
||||||
|
|
||||||
|
const PROTOCOL_NAME: &[u8] = b"Noise_IK_25519_ChaChaPoly_BLAKE2s";
|
||||||
|
const PROTOCOL_VERSION_PREFIX: &[u8] = b"Tailscale Control Protocol v";
|
||||||
|
const PROTOCOL_VERSION: u16 = 49;
|
||||||
|
|
||||||
|
// Message types
|
||||||
|
const MSG_TYPE_INITIATION: u8 = 1;
|
||||||
|
const MSG_TYPE_RESPONSE: u8 = 2;
|
||||||
|
const MSG_TYPE_ERROR: u8 = 3;
|
||||||
|
pub(crate) const MSG_TYPE_RECORD: u8 = 4;
|
||||||
|
|
||||||
|
// Sizes
|
||||||
|
const TAG_SIZE: usize = 16;
|
||||||
|
|
||||||
|
pub(crate) struct HandshakeResult {
|
||||||
|
pub tx_cipher: ChaCha20Poly1305, // client -> server
|
||||||
|
pub rx_cipher: ChaCha20Poly1305, // server -> client
|
||||||
|
pub handshake_hash: [u8; 32],
|
||||||
|
pub protocol_version: u16,
|
||||||
|
/// Leftover bytes from the TCP buffer after the handshake response.
|
||||||
|
/// These are the start of the first Noise transport record.
|
||||||
|
pub leftover: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SymmetricState {
|
||||||
|
h: [u8; 32], // hash
|
||||||
|
ck: [u8; 32], // chaining key
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SymmetricState {
|
||||||
|
fn initialize() -> Self {
|
||||||
|
let h: [u8; 32] = blake2::Blake2s256::digest(PROTOCOL_NAME).into();
|
||||||
|
Self { h, ck: h }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mix_hash(&mut self, data: &[u8]) {
|
||||||
|
let mut hasher = blake2::Blake2s256::new();
|
||||||
|
hasher.update(&self.h);
|
||||||
|
hasher.update(data);
|
||||||
|
self.h = hasher.finalize().into();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mix_dh(
|
||||||
|
&mut self,
|
||||||
|
priv_key: &StaticSecret,
|
||||||
|
pub_key: &PublicKey,
|
||||||
|
) -> crate::Result<ChaCha20Poly1305> {
|
||||||
|
let shared = priv_key.diffie_hellman(pub_key);
|
||||||
|
let hk = HkdfBlake2s::new(Some(&self.ck), shared.as_bytes());
|
||||||
|
let mut output = [0u8; 64];
|
||||||
|
hk.expand(&[], &mut output)
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("HKDF expand: {e}")))?;
|
||||||
|
self.ck.copy_from_slice(&output[..32]);
|
||||||
|
let key: [u8; 32] = output[32..64].try_into().unwrap();
|
||||||
|
ChaCha20Poly1305::new_from_slice(&key)
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("AEAD key: {e}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encrypt_and_hash(&mut self, cipher: ChaCha20Poly1305, plaintext: &[u8]) -> Vec<u8> {
|
||||||
|
let nonce = Nonce::default(); // all zeros
|
||||||
|
let mut ciphertext = plaintext.to_vec();
|
||||||
|
let tag = cipher
|
||||||
|
.encrypt_in_place_detached(&nonce, &self.h, &mut ciphertext)
|
||||||
|
.expect("encryption failed");
|
||||||
|
ciphertext.extend_from_slice(tag.as_slice());
|
||||||
|
self.mix_hash(&ciphertext);
|
||||||
|
ciphertext
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrypt_and_hash(
|
||||||
|
&mut self,
|
||||||
|
cipher: ChaCha20Poly1305,
|
||||||
|
ciphertext: &[u8],
|
||||||
|
) -> crate::Result<Vec<u8>> {
|
||||||
|
if ciphertext.len() < TAG_SIZE {
|
||||||
|
return Err(crate::Error::Noise("ciphertext too short".into()));
|
||||||
|
}
|
||||||
|
let nonce = Nonce::default();
|
||||||
|
let (ct, tag_bytes) = ciphertext.split_at(ciphertext.len() - TAG_SIZE);
|
||||||
|
let tag = Tag::from_slice(tag_bytes);
|
||||||
|
let mut plaintext = ct.to_vec();
|
||||||
|
cipher
|
||||||
|
.decrypt_in_place_detached(&nonce, &self.h, &mut plaintext, tag)
|
||||||
|
.map_err(|_| crate::Error::Noise("AEAD decryption failed".into()))?;
|
||||||
|
self.mix_hash(ciphertext);
|
||||||
|
Ok(plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split(self) -> crate::Result<(ChaCha20Poly1305, ChaCha20Poly1305)> {
|
||||||
|
let hk = HkdfBlake2s::new(Some(&self.ck), &[]);
|
||||||
|
let mut output = [0u8; 64];
|
||||||
|
hk.expand(&[], &mut output)
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("HKDF split: {e}")))?;
|
||||||
|
let k1: [u8; 32] = output[..32].try_into().unwrap();
|
||||||
|
let k2: [u8; 32] = output[32..64].try_into().unwrap();
|
||||||
|
let c1 = ChaCha20Poly1305::new_from_slice(&k1)
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("c1 key: {e}")))?;
|
||||||
|
let c2 = ChaCha20Poly1305::new_from_slice(&k2)
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("c2 key: {e}")))?;
|
||||||
|
Ok((c1, c2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn protocol_version_prologue(version: u16) -> Vec<u8> {
|
||||||
|
let mut p = PROTOCOL_VERSION_PREFIX.to_vec();
|
||||||
|
p.extend_from_slice(version.to_string().as_bytes());
|
||||||
|
p
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform the TS2021 HTTP upgrade and controlbase Noise IK handshake.
|
||||||
|
///
|
||||||
|
/// Returns a [`HandshakeResult`] containing the transport ciphers.
|
||||||
|
pub(crate) async fn perform_handshake(
|
||||||
|
stream: &mut TcpStream,
|
||||||
|
machine_private: &StaticSecret,
|
||||||
|
machine_public: &PublicKey,
|
||||||
|
server_public: &PublicKey,
|
||||||
|
) -> crate::Result<HandshakeResult> {
|
||||||
|
let mut s = SymmetricState::initialize();
|
||||||
|
|
||||||
|
// Prologue
|
||||||
|
s.mix_hash(&protocol_version_prologue(PROTOCOL_VERSION));
|
||||||
|
|
||||||
|
// <- s (mix in known server static key)
|
||||||
|
s.mix_hash(server_public.as_bytes());
|
||||||
|
|
||||||
|
// -> e, es, s, ss
|
||||||
|
let ephemeral_secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
|
||||||
|
let ephemeral_public = PublicKey::from(&ephemeral_secret);
|
||||||
|
|
||||||
|
// Build initiation message (101 bytes)
|
||||||
|
let mut init = vec![0u8; 101];
|
||||||
|
// Header
|
||||||
|
init[0..2].copy_from_slice(&PROTOCOL_VERSION.to_be_bytes());
|
||||||
|
init[2] = MSG_TYPE_INITIATION;
|
||||||
|
init[3..5].copy_from_slice(&96u16.to_be_bytes());
|
||||||
|
// Payload
|
||||||
|
init[5..37].copy_from_slice(ephemeral_public.as_bytes());
|
||||||
|
s.mix_hash(ephemeral_public.as_bytes());
|
||||||
|
|
||||||
|
let cipher_es = s.mix_dh(&ephemeral_secret, server_public)?;
|
||||||
|
let encrypted_machine_pub = s.encrypt_and_hash(cipher_es, machine_public.as_bytes());
|
||||||
|
init[37..85].copy_from_slice(&encrypted_machine_pub);
|
||||||
|
|
||||||
|
let cipher_ss = s.mix_dh(machine_private, server_public)?;
|
||||||
|
let tag = s.encrypt_and_hash(cipher_ss, &[]);
|
||||||
|
init[85..101].copy_from_slice(&tag);
|
||||||
|
|
||||||
|
// Base64-encode for HTTP header
|
||||||
|
let init_b64 = base64::engine::general_purpose::STANDARD.encode(&init);
|
||||||
|
|
||||||
|
// Send HTTP upgrade
|
||||||
|
let peer = stream
|
||||||
|
.peer_addr()
|
||||||
|
.map(|a| a.to_string())
|
||||||
|
.unwrap_or_else(|_| "localhost".into());
|
||||||
|
let upgrade_req = format!(
|
||||||
|
"POST /ts2021 HTTP/1.1\r\n\
|
||||||
|
Host: {peer}\r\n\
|
||||||
|
Upgrade: tailscale-control-protocol\r\n\
|
||||||
|
Connection: upgrade\r\n\
|
||||||
|
X-Tailscale-Handshake: {init_b64}\r\n\
|
||||||
|
\r\n"
|
||||||
|
);
|
||||||
|
stream
|
||||||
|
.write_all(upgrade_req.as_bytes())
|
||||||
|
.await
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("write upgrade: {e}")))?;
|
||||||
|
|
||||||
|
// Read HTTP 101 response
|
||||||
|
let mut resp_buf = vec![0u8; 8192];
|
||||||
|
let mut total = 0;
|
||||||
|
loop {
|
||||||
|
let n = stream
|
||||||
|
.read(&mut resp_buf[total..])
|
||||||
|
.await
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("read response: {e}")))?;
|
||||||
|
if n == 0 {
|
||||||
|
return Err(crate::Error::ConnectionClosed);
|
||||||
|
}
|
||||||
|
total += n;
|
||||||
|
if resp_buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let header_end = resp_buf[..total]
|
||||||
|
.windows(4)
|
||||||
|
.position(|w| w == b"\r\n\r\n")
|
||||||
|
.map(|p| p + 4)
|
||||||
|
.ok_or_else(|| crate::Error::Noise("no header boundary".into()))?;
|
||||||
|
|
||||||
|
let headers = String::from_utf8_lossy(&resp_buf[..header_end]);
|
||||||
|
if !headers.contains("101") {
|
||||||
|
let body = String::from_utf8_lossy(&resp_buf[header_end..total]);
|
||||||
|
return Err(crate::Error::Noise(format!(
|
||||||
|
"expected 101, got: {headers}\nBody: {body}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read server response message (51 bytes: 3 header + 48 payload)
|
||||||
|
// Some bytes may already be in resp_buf after the HTTP headers.
|
||||||
|
let mut resp_data = resp_buf[header_end..total].to_vec();
|
||||||
|
while resp_data.len() < 51 {
|
||||||
|
let mut tmp = vec![0u8; 51 - resp_data.len()];
|
||||||
|
let n = stream
|
||||||
|
.read(&mut tmp)
|
||||||
|
.await
|
||||||
|
.map_err(|e| crate::Error::Noise(format!("read response msg: {e}")))?;
|
||||||
|
if n == 0 {
|
||||||
|
return Err(crate::Error::ConnectionClosed);
|
||||||
|
}
|
||||||
|
resp_data.extend_from_slice(&tmp[..n]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
if resp_data[0] == MSG_TYPE_ERROR {
|
||||||
|
let err_len = u16::from_be_bytes([resp_data[1], resp_data[2]]) as usize;
|
||||||
|
let err_msg =
|
||||||
|
String::from_utf8_lossy(&resp_data[3..3 + err_len.min(resp_data.len() - 3)]);
|
||||||
|
return Err(crate::Error::Noise(format!("server error: {err_msg}")));
|
||||||
|
}
|
||||||
|
if resp_data[0] != MSG_TYPE_RESPONSE {
|
||||||
|
return Err(crate::Error::Noise(format!(
|
||||||
|
"unexpected response type {}",
|
||||||
|
resp_data[0]
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let server_ephemeral_bytes: [u8; 32] = resp_data[3..35].try_into().unwrap();
|
||||||
|
let server_ephemeral = PublicKey::from(server_ephemeral_bytes);
|
||||||
|
let resp_tag = &resp_data[35..51];
|
||||||
|
|
||||||
|
// <- e, ee, se
|
||||||
|
s.mix_hash(server_ephemeral.as_bytes());
|
||||||
|
let _cipher_ee = s.mix_dh(&ephemeral_secret, &server_ephemeral)?;
|
||||||
|
let cipher_se = s.mix_dh(machine_private, &server_ephemeral)?;
|
||||||
|
s.decrypt_and_hash(cipher_se, resp_tag)?;
|
||||||
|
|
||||||
|
// Split into transport ciphers
|
||||||
|
let h = s.h;
|
||||||
|
let (c1, c2) = s.split()?;
|
||||||
|
|
||||||
|
tracing::debug!("controlbase handshake complete");
|
||||||
|
|
||||||
|
// Any bytes in resp_data beyond the 51-byte response are the start
|
||||||
|
// of the first Noise transport record (early payload).
|
||||||
|
let leftover = if resp_data.len() > 51 {
|
||||||
|
resp_data[51..].to_vec()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(HandshakeResult {
|
||||||
|
tx_cipher: c1,
|
||||||
|
rx_cipher: c2,
|
||||||
|
handshake_hash: h,
|
||||||
|
protocol_version: PROTOCOL_VERSION,
|
||||||
|
leftover,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
/// Mock server that implements the controlbase server-side handshake.
|
||||||
|
async fn mock_controlbase_server(
|
||||||
|
mut sock: TcpStream,
|
||||||
|
server_private: StaticSecret,
|
||||||
|
server_public: PublicKey,
|
||||||
|
) {
|
||||||
|
// Read the full HTTP request.
|
||||||
|
let mut buf = vec![0u8; 8192];
|
||||||
|
let mut total = 0;
|
||||||
|
loop {
|
||||||
|
let n = sock.read(&mut buf[total..]).await.unwrap();
|
||||||
|
if n == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
total += n;
|
||||||
|
if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the X-Tailscale-Handshake header.
|
||||||
|
let req_str = String::from_utf8_lossy(&buf[..total]);
|
||||||
|
let init_b64 = req_str
|
||||||
|
.lines()
|
||||||
|
.find(|l| l.starts_with("X-Tailscale-Handshake:"))
|
||||||
|
.map(|l| l.trim_start_matches("X-Tailscale-Handshake:").trim())
|
||||||
|
.expect("missing X-Tailscale-Handshake header");
|
||||||
|
let init_msg = base64::engine::general_purpose::STANDARD
|
||||||
|
.decode(init_b64)
|
||||||
|
.expect("bad base64");
|
||||||
|
|
||||||
|
assert_eq!(init_msg.len(), 101, "initiation message should be 101 bytes");
|
||||||
|
|
||||||
|
// Parse initiation
|
||||||
|
let _version = u16::from_be_bytes([init_msg[0], init_msg[1]]);
|
||||||
|
assert_eq!(init_msg[2], MSG_TYPE_INITIATION);
|
||||||
|
let _payload_len = u16::from_be_bytes([init_msg[3], init_msg[4]]);
|
||||||
|
|
||||||
|
let client_ephemeral_bytes: [u8; 32] = init_msg[5..37].try_into().unwrap();
|
||||||
|
let client_ephemeral = PublicKey::from(client_ephemeral_bytes);
|
||||||
|
let encrypted_machine_pub = &init_msg[37..85];
|
||||||
|
let init_tag = &init_msg[85..101];
|
||||||
|
|
||||||
|
// Server-side symmetric state
|
||||||
|
let mut s = SymmetricState::initialize();
|
||||||
|
s.mix_hash(&protocol_version_prologue(PROTOCOL_VERSION));
|
||||||
|
s.mix_hash(server_public.as_bytes());
|
||||||
|
|
||||||
|
// Process -> e
|
||||||
|
s.mix_hash(client_ephemeral.as_bytes());
|
||||||
|
|
||||||
|
// -> es: MixDH(server_private, client_ephemeral)
|
||||||
|
let cipher_es = s.mix_dh(&server_private, &client_ephemeral).unwrap();
|
||||||
|
|
||||||
|
// -> s: DecryptAndHash to get client machine public key
|
||||||
|
let client_machine_pub_bytes = s.decrypt_and_hash(cipher_es, encrypted_machine_pub).unwrap();
|
||||||
|
assert_eq!(client_machine_pub_bytes.len(), 32);
|
||||||
|
let client_machine_pub = PublicKey::from(<[u8; 32]>::try_from(&client_machine_pub_bytes[..]).unwrap());
|
||||||
|
|
||||||
|
// -> ss: MixDH(server_private, client_machine_pub)
|
||||||
|
let cipher_ss = s.mix_dh(&server_private, &client_machine_pub).unwrap();
|
||||||
|
s.decrypt_and_hash(cipher_ss, init_tag).unwrap();
|
||||||
|
|
||||||
|
// Send 101 response.
|
||||||
|
sock.write_all(
|
||||||
|
b"HTTP/1.1 101 Switching Protocols\r\n\
|
||||||
|
Connection: Upgrade\r\n\
|
||||||
|
Upgrade: tailscale-control-protocol\r\n\
|
||||||
|
\r\n",
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// <- e (server ephemeral)
|
||||||
|
let server_ephemeral_secret = StaticSecret::random_from_rng(rand::rngs::OsRng);
|
||||||
|
let server_ephemeral_public = PublicKey::from(&server_ephemeral_secret);
|
||||||
|
|
||||||
|
s.mix_hash(server_ephemeral_public.as_bytes());
|
||||||
|
|
||||||
|
// <- ee: MixDH(server_ephemeral_private, client_ephemeral)
|
||||||
|
let _cipher_ee = s.mix_dh(&server_ephemeral_secret, &client_ephemeral).unwrap();
|
||||||
|
|
||||||
|
// <- se: MixDH(server_ephemeral_private, client_machine_pub)
|
||||||
|
let cipher_se = s.mix_dh(&server_ephemeral_secret, &client_machine_pub).unwrap();
|
||||||
|
let tag = s.encrypt_and_hash(cipher_se, &[]);
|
||||||
|
|
||||||
|
// Build response message (51 bytes)
|
||||||
|
let mut resp = vec![0u8; 51];
|
||||||
|
resp[0] = MSG_TYPE_RESPONSE;
|
||||||
|
resp[1..3].copy_from_slice(&48u16.to_be_bytes());
|
||||||
|
resp[3..35].copy_from_slice(server_ephemeral_public.as_bytes());
|
||||||
|
resp[35..51].copy_from_slice(&tag);
|
||||||
|
|
||||||
|
sock.write_all(&resp).await.unwrap();
|
||||||
|
|
||||||
|
// Split for transport
|
||||||
|
let (_c1, _c2) = s.split().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn handshake_with_mock_server() {
|
||||||
|
// Generate server keys.
|
||||||
|
let server_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
|
||||||
|
let server_public = PublicKey::from(&server_private);
|
||||||
|
|
||||||
|
// Generate client keys.
|
||||||
|
let machine_private = StaticSecret::random_from_rng(rand::rngs::OsRng);
|
||||||
|
let machine_public = PublicKey::from(&machine_private);
|
||||||
|
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let sp = server_public;
|
||||||
|
let spriv = server_private.clone();
|
||||||
|
|
||||||
|
// Server task.
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
let (sock, _) = listener.accept().await.unwrap();
|
||||||
|
mock_controlbase_server(sock, spriv, sp).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Client.
|
||||||
|
let mut client_stream = TcpStream::connect(addr).await.unwrap();
|
||||||
|
let result = perform_handshake(
|
||||||
|
&mut client_stream,
|
||||||
|
&machine_private,
|
||||||
|
&machine_public,
|
||||||
|
&server_public,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Verify we got transport ciphers and can encrypt.
|
||||||
|
let nonce = Nonce::default();
|
||||||
|
let mut data = b"hello".to_vec();
|
||||||
|
result
|
||||||
|
.tx_cipher
|
||||||
|
.encrypt_in_place_detached(&nonce, &[], &mut data)
|
||||||
|
.unwrap();
|
||||||
|
// Ciphertext should differ from plaintext.
|
||||||
|
assert_ne!(&data[..], b"hello");
|
||||||
|
|
||||||
|
assert_eq!(result.protocol_version, 49);
|
||||||
|
|
||||||
|
server_handle.await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
3
sunbeam-net/src/noise/mod.rs
Normal file
3
sunbeam-net/src/noise/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod framing;
|
||||||
|
pub mod handshake;
|
||||||
|
pub mod stream;
|
||||||
318
sunbeam-net/src/noise/stream.rs
Normal file
318
sunbeam-net/src/noise/stream.rs
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, Nonce, Tag};
|
||||||
|
use futures::stream::StreamExt;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio_util::codec::Framed;
|
||||||
|
|
||||||
|
use super::framing::{NoiseFrame, NoiseFrameCodec, FRAME_TYPE_DATA, MAX_FRAME_SIZE};
|
||||||
|
|
||||||
|
/// AEAD tag size for ChaCha20Poly1305.
|
||||||
|
const TAG_SIZE: usize = 16;
|
||||||
|
|
||||||
|
/// Maximum plaintext chunk size, leaving room for the AEAD tag (16 bytes).
|
||||||
|
const MAX_PLAINTEXT_CHUNK: usize = MAX_FRAME_SIZE - TAG_SIZE;
|
||||||
|
|
||||||
|
/// An encrypted, framed stream using the controlbase protocol over TCP.
|
||||||
|
///
|
||||||
|
/// Implements `AsyncRead + AsyncWrite` so it can be used transparently by `h2`
|
||||||
|
/// and other async I/O consumers.
|
||||||
|
pub struct NoiseStream {
|
||||||
|
inner: Framed<TcpStream, NoiseFrameCodec>,
|
||||||
|
tx_cipher: ChaCha20Poly1305,
|
||||||
|
rx_cipher: ChaCha20Poly1305,
|
||||||
|
read_nonce: AtomicU64,
|
||||||
|
write_nonce: AtomicU64,
|
||||||
|
read_buf: BytesMut,
|
||||||
|
// Encrypted frames ready to be sent, accumulated during poll_write.
|
||||||
|
pending_write: BytesMut,
|
||||||
|
// Whether we have an error from a previous operation.
|
||||||
|
write_err: Option<std::io::Error>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: ChaCha20Poly1305 is Send, all other fields are Send.
|
||||||
|
unsafe impl Send for NoiseStream {}
|
||||||
|
|
||||||
|
impl Unpin for NoiseStream {}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for NoiseStream {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("NoiseStream")
|
||||||
|
.field("read_nonce", &self.read_nonce.load(Ordering::Relaxed))
|
||||||
|
.field("write_nonce", &self.write_nonce.load(Ordering::Relaxed))
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a 12-byte nonce: first 4 bytes zero, last 8 bytes little-endian u64 counter.
|
||||||
|
fn build_nonce(counter: u64) -> Nonce {
|
||||||
|
let mut nonce = [0u8; 12];
|
||||||
|
nonce[4..12].copy_from_slice(&counter.to_be_bytes());
|
||||||
|
Nonce::from(nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NoiseStream {
|
||||||
|
/// Wrap a TCP stream with controlbase encryption using the given transport ciphers.
|
||||||
|
/// `leftover` contains any bytes already read from TCP that belong to the first
|
||||||
|
/// Noise transport record (from the handshake buffer overflow).
|
||||||
|
pub fn new(
|
||||||
|
stream: TcpStream,
|
||||||
|
tx_cipher: ChaCha20Poly1305,
|
||||||
|
rx_cipher: ChaCha20Poly1305,
|
||||||
|
leftover: Vec<u8>,
|
||||||
|
) -> Self {
|
||||||
|
let mut framed = Framed::new(stream, NoiseFrameCodec);
|
||||||
|
// Prepend leftover bytes so the codec can parse them as part of the first frame.
|
||||||
|
if !leftover.is_empty() {
|
||||||
|
framed.read_buffer_mut().extend_from_slice(&leftover);
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
inner: framed,
|
||||||
|
tx_cipher,
|
||||||
|
rx_cipher,
|
||||||
|
read_nonce: AtomicU64::new(0),
|
||||||
|
write_nonce: AtomicU64::new(0),
|
||||||
|
read_buf: BytesMut::with_capacity(MAX_FRAME_SIZE),
|
||||||
|
pending_write: BytesMut::new(),
|
||||||
|
write_err: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read and consume the early payload sent by Headscale after the handshake.
|
||||||
|
///
|
||||||
|
/// The early payload consists of multiple Noise records:
|
||||||
|
/// 1. 5 bytes: magic `\xff\xff\xffTS`
|
||||||
|
/// 2. 4 bytes: BE u32 JSON length
|
||||||
|
/// 3. N bytes: JSON payload
|
||||||
|
///
|
||||||
|
/// Each part may be a separate Noise record. We read and reassemble them.
|
||||||
|
/// If the first record is not an early payload (i.e. it's an HTTP/2 frame),
|
||||||
|
/// we buffer it for h2 to read.
|
||||||
|
///
|
||||||
|
/// Returns the early payload JSON bytes, or an empty vec if no early payload.
|
||||||
|
pub async fn consume_early_payload(&mut self) -> std::io::Result<Vec<u8>> {
|
||||||
|
// Accumulate plaintext from multiple records
|
||||||
|
let mut accum = Vec::new();
|
||||||
|
|
||||||
|
// Read first frame
|
||||||
|
let frame = match self.inner.next().await {
|
||||||
|
Some(Ok(frame)) => frame,
|
||||||
|
Some(Err(e)) => return Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
|
||||||
|
None => return Ok(vec![]),
|
||||||
|
};
|
||||||
|
let plaintext = self.decrypt_frame(&frame)?;
|
||||||
|
accum.extend_from_slice(&plaintext);
|
||||||
|
|
||||||
|
// Check for early payload magic
|
||||||
|
if accum.len() < 5
|
||||||
|
|| accum[0] != 0xff
|
||||||
|
|| accum[1] != 0xff
|
||||||
|
|| accum[2] != 0xff
|
||||||
|
|| accum[3] != b'T'
|
||||||
|
|| accum[4] != b'S'
|
||||||
|
{
|
||||||
|
// Not an early payload — buffer for h2
|
||||||
|
self.read_buf.extend_from_slice(&accum);
|
||||||
|
return Ok(vec![]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read more frames until we have the 4-byte length
|
||||||
|
while accum.len() < 9 {
|
||||||
|
let frame = match self.inner.next().await {
|
||||||
|
Some(Ok(f)) => f,
|
||||||
|
Some(Err(e)) => return Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
|
||||||
|
None => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "early payload truncated")),
|
||||||
|
};
|
||||||
|
accum.extend_from_slice(&self.decrypt_frame(&frame)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let json_len = u32::from_be_bytes([accum[5], accum[6], accum[7], accum[8]]) as usize;
|
||||||
|
|
||||||
|
// Read more frames until we have the full JSON
|
||||||
|
while accum.len() < 9 + json_len {
|
||||||
|
let frame = match self.inner.next().await {
|
||||||
|
Some(Ok(f)) => f,
|
||||||
|
Some(Err(e)) => return Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
|
||||||
|
None => return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "early payload JSON truncated")),
|
||||||
|
};
|
||||||
|
accum.extend_from_slice(&self.decrypt_frame(&frame)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let json_bytes = accum[9..9 + json_len].to_vec();
|
||||||
|
tracing::debug!("consumed early payload ({} bytes)", json_bytes.len());
|
||||||
|
|
||||||
|
// If there's any leftover data after the early payload, buffer it for h2
|
||||||
|
if accum.len() > 9 + json_len {
|
||||||
|
self.read_buf.extend_from_slice(&accum[9 + json_len..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(json_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt a data frame and return plaintext.
|
||||||
|
fn decrypt_frame(&self, frame: &NoiseFrame) -> std::io::Result<Vec<u8>> {
|
||||||
|
let counter = self.read_nonce.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let nonce = build_nonce(counter);
|
||||||
|
tracing::trace!("decrypt_frame: type={:02x}, payload_len={}, nonce={}", frame.frame_type, frame.payload.len(), counter);
|
||||||
|
|
||||||
|
if frame.payload.len() < TAG_SIZE {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
"frame too short for AEAD tag",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (ct, tag_bytes) = frame.payload.split_at(frame.payload.len() - TAG_SIZE);
|
||||||
|
let tag = Tag::from_slice(tag_bytes);
|
||||||
|
let mut plaintext = ct.to_vec();
|
||||||
|
self.rx_cipher
|
||||||
|
.decrypt_in_place_detached(&nonce, &[], &mut plaintext, tag)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
|
||||||
|
Ok(plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encrypt a chunk and produce a NoiseFrame.
|
||||||
|
fn encrypt_chunk(&self, plaintext: &[u8]) -> std::io::Result<NoiseFrame> {
|
||||||
|
let counter = self.write_nonce.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let nonce = build_nonce(counter);
|
||||||
|
|
||||||
|
let mut ciphertext = plaintext.to_vec();
|
||||||
|
let tag = self
|
||||||
|
.tx_cipher
|
||||||
|
.encrypt_in_place_detached(&nonce, &[], &mut ciphertext)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
|
||||||
|
ciphertext.extend_from_slice(tag.as_slice());
|
||||||
|
|
||||||
|
Ok(NoiseFrame {
|
||||||
|
frame_type: FRAME_TYPE_DATA,
|
||||||
|
payload: BytesMut::from(&ciphertext[..]),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for NoiseStream {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
// If we have buffered plaintext, return it.
|
||||||
|
if !this.read_buf.is_empty() {
|
||||||
|
let to_copy = std::cmp::min(this.read_buf.len(), buf.remaining());
|
||||||
|
buf.put_slice(&this.read_buf[..to_copy]);
|
||||||
|
let _ = this.read_buf.split_to(to_copy);
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Poll the inner framed stream for the next frame.
|
||||||
|
match this.inner.poll_next_unpin(cx) {
|
||||||
|
Poll::Ready(Some(Ok(frame))) => {
|
||||||
|
if frame.frame_type != FRAME_TYPE_DATA {
|
||||||
|
// Skip non-data frames; wake to try again.
|
||||||
|
cx.waker().wake_by_ref();
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
match this.decrypt_frame(&frame) {
|
||||||
|
Ok(plaintext) => {
|
||||||
|
let to_copy = std::cmp::min(plaintext.len(), buf.remaining());
|
||||||
|
buf.put_slice(&plaintext[..to_copy]);
|
||||||
|
if to_copy < plaintext.len() {
|
||||||
|
this.read_buf.extend_from_slice(&plaintext[to_copy..]);
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
Err(e) => Poll::Ready(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(e)),
|
||||||
|
Poll::Ready(None) => Poll::Ready(Ok(())), // EOF
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for NoiseStream {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
if let Some(e) = this.write_err.take() {
|
||||||
|
return Poll::Ready(Err(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt in chunks and buffer the frames.
|
||||||
|
let mut written = 0;
|
||||||
|
for chunk in buf.chunks(MAX_PLAINTEXT_CHUNK) {
|
||||||
|
match this.encrypt_chunk(chunk) {
|
||||||
|
Ok(frame) => {
|
||||||
|
// Manually serialize the frame into pending_write.
|
||||||
|
let len = frame.payload.len() as u16;
|
||||||
|
this.pending_write.extend_from_slice(&[
|
||||||
|
frame.frame_type,
|
||||||
|
(len >> 8) as u8,
|
||||||
|
len as u8,
|
||||||
|
]);
|
||||||
|
this.pending_write.extend_from_slice(&frame.payload);
|
||||||
|
written += chunk.len();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if written > 0 {
|
||||||
|
// Report what we encrypted so far; save error for next call.
|
||||||
|
this.write_err = Some(e);
|
||||||
|
return Poll::Ready(Ok(written));
|
||||||
|
}
|
||||||
|
return Poll::Ready(Err(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Ok(written))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
// Flush pending_write bytes through the underlying stream.
|
||||||
|
if !this.pending_write.is_empty() {
|
||||||
|
let inner_stream = this.inner.get_mut();
|
||||||
|
loop {
|
||||||
|
if this.pending_write.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
match Pin::new(&mut *inner_stream).poll_write(cx, &this.pending_write) {
|
||||||
|
Poll::Ready(Ok(n)) => {
|
||||||
|
let _ = this.pending_write.split_to(n);
|
||||||
|
}
|
||||||
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Pin::new(this.inner.get_mut()).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
Pin::new(this.inner.get_mut()).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn max_plaintext_chunk_leaves_room_for_tag() {
|
||||||
|
assert!(MAX_PLAINTEXT_CHUNK + 16 <= MAX_FRAME_SIZE);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user