use std::path::Path; use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; use x25519_dalek::{PublicKey, StaticSecret}; use crate::error::ResultExt; const KEYS_FILE: &str = "keys.json"; /// The three x25519 key pairs used by a node. #[derive(Clone)] pub struct NodeKeys { pub node_private: StaticSecret, pub node_public: PublicKey, pub disco_private: StaticSecret, pub disco_public: PublicKey, pub wg_private: StaticSecret, pub wg_public: PublicKey, } impl std::fmt::Debug for NodeKeys { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("NodeKeys") .field("node_public", &self.node_key_str()) .field("disco_public", &self.disco_key_str()) .field("wg_public", &hex::encode(self.wg_public.as_bytes())) .finish() } } /// On-disk representation of persisted keys. #[derive(Serialize, Deserialize)] struct PersistedKeys { node_private: String, disco_private: String, wg_private: String, } // Hex helpers — we avoid pulling in a hex crate by implementing locally. mod hex { pub fn encode(bytes: &[u8]) -> String { bytes.iter().map(|b| format!("{b:02x}")).collect() } pub fn decode(s: &str) -> Result, String> { if s.len() % 2 != 0 { return Err("odd length hex string".into()); } (0..s.len()) .step_by(2) .map(|i| u8::from_str_radix(&s[i..i + 2], 16).map_err(|e| e.to_string())) .collect() } } impl NodeKeys { /// Generate fresh random key pairs. pub fn generate() -> Self { let node_private = StaticSecret::random_from_rng(OsRng); let node_public = PublicKey::from(&node_private); let disco_private = StaticSecret::random_from_rng(OsRng); let disco_public = PublicKey::from(&disco_private); let wg_private = StaticSecret::random_from_rng(OsRng); let wg_public = PublicKey::from(&wg_private); Self { node_private, node_public, disco_private, disco_public, wg_private, wg_public, } } /// Load keys from `state_dir/keys.json`, or generate and persist new ones. pub fn load_or_generate(state_dir: &Path) -> crate::Result { let path = state_dir.join(KEYS_FILE); if path.exists() { let data = std::fs::read_to_string(&path).ctx("reading keys file")?; let persisted: PersistedKeys = serde_json::from_str(&data)?; return Self::from_persisted(&persisted); } let keys = Self::generate(); std::fs::create_dir_all(state_dir).ctx("creating state directory")?; let persisted = keys.to_persisted(); let data = serde_json::to_string_pretty(&persisted)?; std::fs::write(&path, data).ctx("writing keys file")?; tracing::debug!("generated new node keys at {}", path.display()); Ok(keys) } /// Tailscale-style node key string: `nodekey:`. pub fn node_key_str(&self) -> String { format!("nodekey:{}", hex::encode(self.node_public.as_bytes())) } /// Tailscale-style disco key string: `discokey:`. pub fn disco_key_str(&self) -> String { format!("discokey:{}", hex::encode(self.disco_public.as_bytes())) } fn to_persisted(&self) -> PersistedKeys { PersistedKeys { node_private: hex::encode(self.node_private.as_bytes()), disco_private: hex::encode(self.disco_private.as_bytes()), wg_private: hex::encode(self.wg_private.as_bytes()), } } fn from_persisted(p: &PersistedKeys) -> crate::Result { let node_bytes = parse_key_hex(&p.node_private, "node")?; let disco_bytes = parse_key_hex(&p.disco_private, "disco")?; let wg_bytes = parse_key_hex(&p.wg_private, "wg")?; let node_private = StaticSecret::from(node_bytes); let node_public = PublicKey::from(&node_private); let disco_private = StaticSecret::from(disco_bytes); let disco_public = PublicKey::from(&disco_private); let wg_private = StaticSecret::from(wg_bytes); let wg_public = PublicKey::from(&wg_private); Ok(Self { node_private, node_public, disco_private, disco_public, wg_private, wg_public, }) } } fn parse_key_hex(s: &str, name: &str) -> crate::Result<[u8; 32]> { let bytes = hex::decode(s).map_err(|e| crate::Error::Other(format!("bad {name} key hex: {e}")))?; let arr: [u8; 32] = bytes .try_into() .map_err(|_| crate::Error::Other(format!("{name} key must be 32 bytes")))?; Ok(arr) } #[cfg(test)] mod tests { use super::*; use tempfile::TempDir; #[test] fn generate_produces_valid_keys() { let keys = NodeKeys::generate(); assert!(keys.node_key_str().starts_with("nodekey:")); assert!(keys.disco_key_str().starts_with("discokey:")); assert_eq!(keys.node_key_str().len(), "nodekey:".len() + 64); } #[test] fn load_or_generate_round_trips() { let dir = TempDir::new().unwrap(); let keys1 = NodeKeys::load_or_generate(dir.path()).unwrap(); let keys2 = NodeKeys::load_or_generate(dir.path()).unwrap(); assert_eq!(keys1.node_key_str(), keys2.node_key_str()); assert_eq!(keys1.disco_key_str(), keys2.disco_key_str()); } #[test] fn hex_round_trip() { let input = [0xde, 0xad, 0xbe, 0xef]; let encoded = super::hex::encode(&input); assert_eq!(encoded, "deadbeef"); let decoded = super::hex::decode(&encoded).unwrap(); assert_eq!(decoded, input); } }