use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};

/// Per-node atomic bandwidth counters. Zero contention on the hot path
/// (single `fetch_add` per counter per request).
pub struct BandwidthTracker {
    /// Bytes received since last broadcast (reset each cycle).
    bytes_in: AtomicU64,
    /// Bytes sent since last broadcast (reset each cycle).
    bytes_out: AtomicU64,
    /// Requests since last broadcast (reset each cycle).
    request_count: AtomicU64,
    /// Monotonic total bytes received (never reset).
    cumulative_in: AtomicU64,
    /// Monotonic total bytes sent (never reset).
    cumulative_out: AtomicU64,
}

impl Default for BandwidthTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl BandwidthTracker {
    /// Create a tracker with all counters at zero.
    pub fn new() -> Self {
        Self {
            bytes_in: AtomicU64::new(0),
            bytes_out: AtomicU64::new(0),
            request_count: AtomicU64::new(0),
            cumulative_in: AtomicU64::new(0),
            cumulative_out: AtomicU64::new(0),
        }
    }

    /// Record a completed request's byte counts.
    ///
    /// Hot path: five `Relaxed` `fetch_add`s, no locks.
    #[inline]
    pub fn record(&self, bytes_in: u64, bytes_out: u64) {
        self.bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
        self.bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
        self.request_count.fetch_add(1, Ordering::Relaxed);
        self.cumulative_in.fetch_add(bytes_in, Ordering::Relaxed);
        self.cumulative_out.fetch_add(bytes_out, Ordering::Relaxed);
    }

    /// Take a snapshot and reset the per-interval counters.
    ///
    /// Each counter is swapped/read individually (not as one atomic unit), so
    /// a `record` racing with this call may land partly in this snapshot and
    /// partly in the next; the cumulative totals never lose bytes.
    pub fn snapshot_and_reset(&self) -> BandwidthSnapshot {
        BandwidthSnapshot {
            bytes_in: self.bytes_in.swap(0, Ordering::Relaxed),
            bytes_out: self.bytes_out.swap(0, Ordering::Relaxed),
            request_count: self.request_count.swap(0, Ordering::Relaxed),
            cumulative_in: self.cumulative_in.load(Ordering::Relaxed),
            cumulative_out: self.cumulative_out.load(Ordering::Relaxed),
        }
    }
}

/// Point-in-time copy of one node's bandwidth counters.
#[derive(Debug, Clone)]
pub struct BandwidthSnapshot {
    pub bytes_in: u64,
    pub bytes_out: u64,
    pub request_count: u64,
    pub cumulative_in: u64,
    pub cumulative_out: u64,
}

/// Aggregated bandwidth state from all cluster peers.
pub struct ClusterBandwidthState {
    /// Peer ID -> latest report. Keys are the 32-byte sender IDs taken from
    /// received gossip messages, i.e. remote-controlled input, so this uses
    /// the std `HashMap` (keyed SipHash) rather than `FxHashMap` to avoid
    /// hash-flooding by a malicious peer.
    peers: RwLock<HashMap<[u8; 32], PeerEntry>>,
    /// Sum of all peers' cumulative bytes in (updated on each report).
    pub total_bytes_in: AtomicU64,
    /// Sum of all peers' cumulative bytes out.
    pub total_bytes_out: AtomicU64,
    /// Number of active (non-stale) peers.
    pub peer_count: AtomicU64,
    /// Stale peer timeout.
    stale_timeout_secs: u64,
}

/// Last report received from a single peer.
struct PeerEntry {
    cumulative_in: u64,
    cumulative_out: u64,
    last_seen: Instant,
}

impl ClusterBandwidthState {
    /// `stale_timeout_secs`: how long a peer may stay silent before
    /// `evict_stale` drops it from the aggregate.
    pub fn new(stale_timeout_secs: u64) -> Self {
        Self {
            peers: RwLock::new(HashMap::new()),
            total_bytes_in: AtomicU64::new(0),
            total_bytes_out: AtomicU64::new(0),
            peer_count: AtomicU64::new(0),
            stale_timeout_secs,
        }
    }

    /// Update a peer's bandwidth state from a received report and refresh the
    /// aggregate counters.
    pub fn update_peer(&self, peer_id: [u8; 32], cumulative_in: u64, cumulative_out: u64) {
        let mut peers = self.peers.write().unwrap();
        peers.insert(
            peer_id,
            PeerEntry {
                cumulative_in,
                cumulative_out,
                last_seen: Instant::now(),
            },
        );
        self.recalculate(&peers);
    }

    /// Remove peers that haven't reported within the stale timeout.
    pub fn evict_stale(&self) {
        let mut peers = self.peers.write().unwrap();
        // `Instant - Duration` panics if the result precedes the platform's
        // earliest representable instant (possible with a large timeout very
        // early in process life). `checked_sub` avoids the panic; when no
        // cutoff is representable, nothing can be old enough to evict.
        if let Some(cutoff) =
            Instant::now().checked_sub(Duration::from_secs(self.stale_timeout_secs))
        {
            peers.retain(|_, entry| entry.last_seen > cutoff);
        }
        self.recalculate(&peers);
    }

    /// Recompute the published aggregate counters from the peer map.
    /// Callers hold the write lock and pass the guarded map by reference.
    fn recalculate(&self, peers: &HashMap<[u8; 32], PeerEntry>) {
        let mut total_in = 0u64;
        let mut total_out = 0u64;
        for entry in peers.values() {
            total_in = total_in.saturating_add(entry.cumulative_in);
            total_out = total_out.saturating_add(entry.cumulative_out);
        }
        self.total_bytes_in.store(total_in, Ordering::Relaxed);
        self.total_bytes_out.store(total_out, Ordering::Relaxed);
        self.peer_count.store(peers.len() as u64, Ordering::Relaxed);
    }
}
/// Aggregate bandwidth rate across the entire cluster, computed from a
/// sliding window of samples from all nodes (local + remote).
///
/// Each broadcast cycle produces one sample per node. With a 5s broadcast
/// interval and 30s window, the deque holds ~6 × node_count entries — tiny.
pub struct BandwidthMeter {
    /// Time-ordered samples; pruned on insert and re-filtered on read.
    samples: RwLock<VecDeque<Sample>>,
    /// Sliding-window length.
    window: Duration,
}

/// One bandwidth report (local snapshot or remote peer delta).
struct Sample {
    time: Instant,
    bytes_in: u64,
    bytes_out: u64,
}

/// Snapshot of the aggregate cluster-wide bandwidth rate.
/// All rates are in bytes/sec. Use the `*_mib_per_sec` methods for MiB/s (power-of-2).
#[derive(Debug, Clone, Copy)]
pub struct AggregateRate {
    /// Inbound bytes/sec across all nodes.
    pub bytes_in_per_sec: f64,
    /// Outbound bytes/sec across all nodes.
    pub bytes_out_per_sec: f64,
    /// Total (in + out) bytes/sec.
    pub total_per_sec: f64,
    /// Number of samples in the window.
    pub sample_count: usize,
}

const BYTES_PER_MIB: f64 = 1_048_576.0; // 1024 * 1024

impl AggregateRate {
    /// Inbound rate in MiB/s (power-of-2).
    pub fn in_mib_per_sec(&self) -> f64 {
        self.bytes_in_per_sec / BYTES_PER_MIB
    }

    /// Outbound rate in MiB/s (power-of-2).
    pub fn out_mib_per_sec(&self) -> f64 {
        self.bytes_out_per_sec / BYTES_PER_MIB
    }

    /// Total rate in MiB/s (power-of-2).
    pub fn total_mib_per_sec(&self) -> f64 {
        self.total_per_sec / BYTES_PER_MIB
    }
}

impl BandwidthMeter {
    /// `window_secs`: sliding-window length in seconds. A zero-length window
    /// is degenerate and reports all-zero rates (see `aggregate_rate`).
    pub fn new(window_secs: u64) -> Self {
        Self {
            samples: RwLock::new(VecDeque::new()),
            window: Duration::from_secs(window_secs),
        }
    }

    /// Record a bandwidth sample (from local broadcast or remote peer report).
    pub fn record_sample(&self, bytes_in: u64, bytes_out: u64) {
        let now = Instant::now();
        let mut samples = self.samples.write().unwrap();
        samples.push_back(Sample {
            time: now,
            bytes_in,
            bytes_out,
        });
        // Evict samples outside the window. `checked_sub` because
        // `Instant - Duration` panics on underflow (large window very early
        // in process life); with no representable cutoff nothing can be
        // outside the window yet.
        if let Some(cutoff) = now.checked_sub(self.window) {
            while samples.front().is_some_and(|s| s.time < cutoff) {
                samples.pop_front();
            }
        }
    }

    /// Compute the aggregate bandwidth rate over the sliding window.
    ///
    /// The divisor is always the full window length, so a window that is not
    /// yet full under-reports the rate rather than spiking it.
    pub fn aggregate_rate(&self) -> AggregateRate {
        let now = Instant::now();
        let samples = self.samples.read().unwrap();
        // `None` cutoff (underflow) means every sample is in the window.
        let cutoff = now.checked_sub(self.window);

        let mut total_in = 0u64;
        let mut total_out = 0u64;
        let mut count = 0usize;

        for s in samples.iter() {
            if cutoff.map_or(true, |c| s.time >= c) {
                total_in = total_in.saturating_add(s.bytes_in);
                total_out = total_out.saturating_add(s.bytes_out);
                count += 1;
            }
        }

        let window_secs = self.window.as_secs_f64();
        // Guard the degenerate zero-length window: dividing by 0.0 would
        // produce inf/NaN rates, which would poison the limiter comparison.
        if window_secs <= 0.0 {
            return AggregateRate {
                bytes_in_per_sec: 0.0,
                bytes_out_per_sec: 0.0,
                total_per_sec: 0.0,
                sample_count: count,
            };
        }
        let bytes_in_per_sec = total_in as f64 / window_secs;
        let bytes_out_per_sec = total_out as f64 / window_secs;

        AggregateRate {
            bytes_in_per_sec,
            bytes_out_per_sec,
            total_per_sec: bytes_in_per_sec + bytes_out_per_sec,
            sample_count: count,
        }
    }
}

/// Cluster-wide bandwidth limiter. Compares the aggregate rate from the
/// `BandwidthMeter` against a configurable cap (bytes/sec). The limit is
/// stored as an `AtomicU64` so it can be updated at runtime (e.g. when a
/// license quota changes via gossip).
pub struct BandwidthLimiter {
    /// Max total (in + out) bytes/sec across the cluster. 0 = unlimited.
    limit_bytes_per_sec: AtomicU64,
    meter: std::sync::Arc<BandwidthMeter>,
}

/// Result of a bandwidth limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BandwidthLimitResult {
    Allow,
    Reject,
}

impl BandwidthLimiter {
    /// `limit_bytes_per_sec`: initial cap; 0 disables limiting.
    pub fn new(meter: std::sync::Arc<BandwidthMeter>, limit_bytes_per_sec: u64) -> Self {
        Self {
            limit_bytes_per_sec: AtomicU64::new(limit_bytes_per_sec),
            meter,
        }
    }

    /// Check whether the cluster is currently over its bandwidth cap.
    /// Unlimited (0) short-circuits without touching the meter.
    #[inline]
    pub fn check(&self) -> BandwidthLimitResult {
        let limit = self.limit_bytes_per_sec.load(Ordering::Relaxed);
        if limit == 0 {
            return BandwidthLimitResult::Allow;
        }
        let rate = self.meter.aggregate_rate();
        if rate.total_per_sec > limit as f64 {
            BandwidthLimitResult::Reject
        } else {
            BandwidthLimitResult::Allow
        }
    }

    /// Update the bandwidth cap at runtime (e.g. from a license update).
    pub fn set_limit(&self, bytes_per_sec: u64) {
        self.limit_bytes_per_sec.store(bytes_per_sec, Ordering::Relaxed);
    }

    /// Current limit in bytes/sec (0 = unlimited).
    pub fn limit(&self) -> u64 {
        self.limit_bytes_per_sec.load(Ordering::Relaxed)
    }

    /// Current aggregate rate snapshot.
    pub fn current_rate(&self) -> AggregateRate {
        self.meter.aggregate_rate()
    }
}

/// Convert Gbps (base-10, as used in networking/billing) to bytes/sec.
/// 1 Gbps = 1_000_000_000 bits/sec = 125_000_000 bytes/sec.
/// Float→int `as` casts saturate: negative input yields 0, oversized input
/// clamps to `u64::MAX`.
pub fn gbps_to_bytes_per_sec(gbps: f64) -> u64 {
    (gbps * 125_000_000.0) as u64
}
+        meter.record_sample(500_000_000, 100_000_000); // 500MB in, 100MB out
+        meter.record_sample(50_000_000, 10_000_000); // 50MB in, 10MB out
+
+        let rate = meter.aggregate_rate();
+        assert_eq!(rate.sample_count, 2);
+        // total_in = 550MB over 30s window = ~18.3 MB/s
+        let expected_in = 550_000_000.0 / 30.0;
+        assert!(
+            (rate.bytes_in_per_sec - expected_in).abs() < 1.0,
+            "expected ~{expected_in}, got {}",
+            rate.bytes_in_per_sec
+        );
+        let expected_out = 110_000_000.0 / 30.0;
+        assert!(
+            (rate.bytes_out_per_sec - expected_out).abs() < 1.0,
+            "expected ~{expected_out}, got {}",
+            rate.bytes_out_per_sec
+        );
+        assert!(
+            (rate.total_per_sec - (expected_in + expected_out)).abs() < 1.0,
+        );
+    }
+
+    #[test]
+    fn meter_evicts_old_samples() {
+        // Use a 1-second window so we can test eviction quickly.
+        // NOTE(review): real-time sleep makes this test take >1.1s wall time.
+        let meter = BandwidthMeter::new(1);
+        meter.record_sample(1000, 2000);
+        std::thread::sleep(std::time::Duration::from_millis(1100));
+        // Sample should be evicted.
+        meter.record_sample(500, 600);
+
+        let rate = meter.aggregate_rate();
+        assert_eq!(rate.sample_count, 1, "old sample should be evicted");
+        // Only the second sample should be counted.
+        assert!((rate.bytes_in_per_sec - 500.0).abs() < 1.0);
+    }
+
+    #[test]
+    fn meter_empty_returns_zero() {
+        let meter = BandwidthMeter::new(30);
+        let rate = meter.aggregate_rate();
+        assert_eq!(rate.sample_count, 0);
+        assert_eq!(rate.bytes_in_per_sec, 0.0);
+        assert_eq!(rate.bytes_out_per_sec, 0.0);
+        assert_eq!(rate.total_per_sec, 0.0);
+    }
+
+    #[test]
+    fn cluster_state_aggregation() {
+        let state = ClusterBandwidthState::new(30);
+        state.update_peer([1u8; 32], 1000, 2000);
+        state.update_peer([2u8; 32], 3000, 4000);
+
+        assert_eq!(state.total_bytes_in.load(Ordering::Relaxed), 4000);
+        assert_eq!(state.total_bytes_out.load(Ordering::Relaxed), 6000);
+        assert_eq!(state.peer_count.load(Ordering::Relaxed), 2);
+
+        // Update existing peer.
+        state.update_peer([1u8; 32], 1500, 2500);
+        assert_eq!(state.total_bytes_in.load(Ordering::Relaxed), 4500);
+        assert_eq!(state.total_bytes_out.load(Ordering::Relaxed), 6500);
+        assert_eq!(state.peer_count.load(Ordering::Relaxed), 2);
+    }
+
+    #[test]
+    fn limiter_allows_when_unlimited() {
+        let meter = std::sync::Arc::new(BandwidthMeter::new(30));
+        meter.record_sample(999_999_999, 999_999_999);
+        let limiter = BandwidthLimiter::new(meter, 0); // 0 = unlimited
+        assert_eq!(limiter.check(), BandwidthLimitResult::Allow);
+    }
+
+    #[test]
+    fn limiter_allows_under_cap() {
+        let meter = std::sync::Arc::new(BandwidthMeter::new(30));
+        // 1 GiB total over 30s = ~33 MiB/s ≈ ~35 MB/s — well under 1 Gbps
+        meter.record_sample(500_000_000, 500_000_000);
+        let limiter = BandwidthLimiter::new(meter, gbps_to_bytes_per_sec(1.0));
+        assert_eq!(limiter.check(), BandwidthLimitResult::Allow);
+    }
+
+    #[test]
+    fn limiter_rejects_over_cap() {
+        let meter = std::sync::Arc::new(BandwidthMeter::new(1)); // 1s window
+        // 200 MB total in 1s window = 200 MB/s > 125 MB/s (1 Gbps)
+        meter.record_sample(100_000_000, 100_000_000);
+        let limiter = BandwidthLimiter::new(meter, gbps_to_bytes_per_sec(1.0));
+        assert_eq!(limiter.check(), BandwidthLimitResult::Reject);
+    }
+
+    #[test]
+    fn limiter_set_limit_runtime() {
+        let meter = std::sync::Arc::new(BandwidthMeter::new(1));
+        meter.record_sample(100_000_000, 100_000_000); // 200 MB/s
+        let limiter = BandwidthLimiter::new(meter, gbps_to_bytes_per_sec(1.0));
+        assert_eq!(limiter.check(), BandwidthLimitResult::Reject);
+
+        // Raise the limit to 10 Gbps → should now allow.
+        limiter.set_limit(gbps_to_bytes_per_sec(10.0));
+        assert_eq!(limiter.check(), BandwidthLimitResult::Allow);
+        assert_eq!(limiter.limit(), gbps_to_bytes_per_sec(10.0));
+    }
+
+    #[test]
+    fn gbps_conversion() {
+        assert_eq!(gbps_to_bytes_per_sec(1.0), 125_000_000);
+        assert_eq!(gbps_to_bytes_per_sec(10.0), 1_250_000_000);
+        assert_eq!(gbps_to_bytes_per_sec(0.0), 0);
+    }
+}
diff --git a/src/cluster/messages.rs b/src/cluster/messages.rs
new file mode 100644
index 0000000..62a93f2
--- /dev/null
+++ b/src/cluster/messages.rs
@@ -0,0 +1,144 @@
+use serde::{Deserialize, Serialize};
+
+/// Envelope for all cluster gossip messages.
+/// Serialized with bincode before broadcast.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ClusterMessage {
+    /// Wire-format version (currently always 1 in this patch).
+    pub version: u8,
+    /// Sender's 32-byte endpoint ID.
+    /// NOTE(review): this field is self-reported by the sender and is used as
+    /// the peer key on receive without verification here — confirm the gossip
+    /// transport authenticates it.
+    pub sender: [u8; 32],
+    pub payload: Payload,
+}
+
+/// One payload variant per gossip channel (bandwidth, models, leader, license).
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Payload {
+    BandwidthReport {
+        timestamp: u64,
+        bytes_in: u64,
+        bytes_out: u64,
+        request_count: u64,
+        cumulative_in: u64,
+        cumulative_out: u64,
+    },
+    ModelAnnounce {
+        model_type: String,
+        hash: [u8; 32],
+        total_size: u64,
+        chunk_count: u32,
+    },
+    ModelChunk {
+        hash: [u8; 32],
+        chunk_index: u32,
+        data: Vec<u8>,
+    },
+    LeaderHeartbeat {
+        term: u64,
+        leader_id: [u8; 32],
+    },
+    LicenseQuota {
+        max_bytes: u64,
+        current_bytes: u64,
+    },
+}
+
+impl ClusterMessage {
+    /// Serialize to bincode bytes for broadcast.
+    pub fn encode(&self) -> Result<Vec<u8>, bincode::Error> {
+        bincode::serialize(self)
+    }
+
+    /// Deserialize a received gossip payload.
+    pub fn decode(data: &[u8]) -> Result<Self, bincode::Error> {
+        bincode::deserialize(data)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn roundtrip_bandwidth_report() {
+        let msg = ClusterMessage {
+            version: 1,
+            sender: [42u8; 32],
+            payload: Payload::BandwidthReport {
+                timestamp: 1234567890,
+                bytes_in: 1000,
+                bytes_out: 2000,
+                request_count: 50,
+                cumulative_in: 100_000,
+                cumulative_out: 200_000,
+            },
+        };
+        let encoded = msg.encode().unwrap();
+        let decoded = ClusterMessage::decode(&encoded).unwrap();
+        assert_eq!(decoded.version, 1);
+        assert_eq!(decoded.sender, [42u8; 32]);
+        match decoded.payload {
+            Payload::BandwidthReport {
+                timestamp,
+                bytes_in,
+                bytes_out,
+                request_count,
+                cumulative_in,
+                cumulative_out,
+            } => {
+                assert_eq!(timestamp, 1234567890);
+                assert_eq!(bytes_in, 1000);
+                assert_eq!(bytes_out, 2000);
+                assert_eq!(request_count, 50);
+                assert_eq!(cumulative_in, 100_000);
+                assert_eq!(cumulative_out, 200_000);
+            }
+            _ => panic!("wrong payload variant"),
+        }
+    }
+
+    #[test]
+    fn roundtrip_model_announce() {
+        let msg = ClusterMessage {
+            version: 1,
+            sender: [1u8; 32],
+            payload: Payload::ModelAnnounce {
+                model_type: "scanner".to_string(),
+                hash: [0xAA; 32],
+                total_size: 1_000_000,
+                chunk_count: 16,
+            },
+        };
+        let encoded = msg.encode().unwrap();
+        let decoded = ClusterMessage::decode(&encoded).unwrap();
+        match decoded.payload {
+            Payload::ModelAnnounce {
+                model_type,
+                total_size,
+                chunk_count,
+                ..
+            } => {
+                assert_eq!(model_type, "scanner");
+                assert_eq!(total_size, 1_000_000);
+                assert_eq!(chunk_count, 16);
+            }
+            _ => panic!("wrong payload variant"),
+        }
+    }
+
+    #[test]
+    fn roundtrip_leader_heartbeat() {
+        let msg = ClusterMessage {
+            version: 1,
+            sender: [7u8; 32],
+            payload: Payload::LeaderHeartbeat {
+                term: 3,
+                leader_id: [7u8; 32],
+            },
+        };
+        let encoded = msg.encode().unwrap();
+        let decoded = ClusterMessage::decode(&encoded).unwrap();
+        match decoded.payload {
+            Payload::LeaderHeartbeat { term, leader_id } => {
+                assert_eq!(term, 3);
+                assert_eq!(leader_id, [7u8; 32]);
+            }
+            _ => panic!("wrong payload variant"),
+        }
+    }
+}
diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs
new file mode 100644
index 0000000..e8cebc7
--- /dev/null
+++ b/src/cluster/mod.rs
@@ -0,0 +1,104 @@
+pub mod bandwidth;
+pub mod messages;
+pub mod node;
+
+use std::sync::Arc;
+
+use anyhow::Result;
+use tokio::sync::watch;
+
+use crate::config::ClusterConfig;
+use bandwidth::{
+    gbps_to_bytes_per_sec, BandwidthLimiter, BandwidthMeter, BandwidthTracker,
+    ClusterBandwidthState,
+};
+
+/// Live handles into the cluster subsystem, shared with the rest of the process.
+/// Dropping the handle signals the cluster thread to shut down.
+pub struct ClusterHandle {
+    pub bandwidth: Arc<BandwidthTracker>,
+    pub cluster_bandwidth: Arc<ClusterBandwidthState>,
+    /// Sliding-window aggregate bandwidth rate across the cluster.
+    pub meter: Arc<BandwidthMeter>,
+    /// Cluster-wide bandwidth limiter (0 = unlimited).
+    pub limiter: Arc<BandwidthLimiter>,
+    pub endpoint_id: iroh::PublicKey,
+    /// Watch channel used to signal shutdown to the cluster thread.
+    shutdown_tx: watch::Sender<bool>,
+}
+
+impl ClusterHandle {
+    /// Signal shutdown. Fire-and-forget: does not wait for the thread to exit.
+    pub fn shutdown(&self) {
+        let _ = self.shutdown_tx.send(true);
+    }
+}
+
+impl Drop for ClusterHandle {
+    fn drop(&mut self) {
+        self.shutdown();
+    }
+}
+
+/// Spawn the cluster subsystem on a dedicated OS thread with its own tokio runtime.
+/// Returns a handle for bandwidth recording and cluster state queries.
+///
+/// NOTE(review): blocks on `blocking_recv` below, so this must be called from
+/// synchronous startup code — calling it from within an async runtime context
+/// would panic. Confirm all call sites.
+pub fn spawn_cluster(cfg: &ClusterConfig) -> Result<ClusterHandle> {
+    let (shutdown_tx, shutdown_rx) = watch::channel(false);
+
+    let stale_timeout = cfg
+        .bandwidth
+        .as_ref()
+        .map(|b| b.stale_peer_timeout_secs)
+        .unwrap_or(30);
+
+    let meter_window = cfg
+        .bandwidth
+        .as_ref()
+        .map(|b| b.meter_window_secs)
+        .unwrap_or(30);
+
+    // Default: 1 Gbps cap. Updated at runtime via license gossip.
+    let limit_bytes_per_sec = gbps_to_bytes_per_sec(1.0);
+
+    let bandwidth = Arc::new(BandwidthTracker::new());
+    let cluster_bandwidth = Arc::new(ClusterBandwidthState::new(stale_timeout));
+    let meter = Arc::new(BandwidthMeter::new(meter_window));
+    let limiter = Arc::new(BandwidthLimiter::new(meter.clone(), limit_bytes_per_sec));
+
+    let bw = bandwidth.clone();
+    let cbw = cluster_bandwidth.clone();
+    let m = meter.clone();
+    let cluster_cfg = cfg.clone();
+
+    let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
+
+    // NOTE(review): the thread's JoinHandle is discarded — shutdown() only
+    // signals, it never joins, so process exit does not wait for the cluster
+    // thread. Confirm that is acceptable.
+    std::thread::Builder::new()
+        .name("cluster".into())
+        .spawn(move || {
+            let rt = tokio::runtime::Builder::new_multi_thread()
+                .worker_threads(2)
+                .enable_all()
+                .thread_name("cluster-worker")
+                .build()
+                .expect("cluster runtime");
+
+            rt.block_on(node::run_cluster(
+                &cluster_cfg,
+                bw,
+                cbw,
+                m,
+                shutdown_rx,
+                ready_tx,
+            ));
+        })?;
+
+    // Wait for the cluster to initialize (or fail). The double `?` unwraps the
+    // channel recv error, then the Result sent by run_cluster.
+    let endpoint_id = ready_rx
+        .blocking_recv()
+        .map_err(|_| anyhow::anyhow!("cluster thread exited before initialization"))??;
+
+    Ok(ClusterHandle {
+        bandwidth,
+        cluster_bandwidth,
+        meter,
+        limiter,
+        endpoint_id,
+        shutdown_tx,
+    })
+}
diff --git a/src/cluster/node.rs b/src/cluster/node.rs
new file mode 100644
index 0000000..93b1c20
--- /dev/null
+++ b/src/cluster/node.rs
@@ -0,0 +1,435 @@
+use std::net::SocketAddr;
+use std::path::Path;
+use std::sync::Arc;
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+use anyhow::{Context, Result};
+use futures::stream::StreamExt;
+use iroh::protocol::Router;
+use iroh::{Endpoint, RelayMode, SecretKey};
+use iroh_gossip::net::Gossip;
+use iroh_gossip::{api::Event, proto::TopicId, ALPN};
+use tokio::sync::watch;
+
+use crate::cluster::bandwidth::{BandwidthMeter, BandwidthTracker, ClusterBandwidthState};
+use crate::cluster::messages::{ClusterMessage, Payload};
+use crate::config::ClusterConfig;
+use crate::metrics;
+
+/// Derive a deterministic TopicId from tenant UUID and channel name.
+/// blake3 of a versioned path string, so every node of a tenant derives the
+/// same topic independently.
+pub fn derive_topic(tenant: &str, channel: &str) -> TopicId {
+    let input = format!("/sunbeam-proxy/1.0/{tenant}/{channel}");
+    let hash = blake3::hash(input.as_bytes());
+    TopicId::from_bytes(*hash.as_bytes())
+}
+
+/// Load or generate a persistent ed25519 identity key.
+/// NOTE(review): the key file is written with default permissions — consider
+/// restricting to 0600 on unix; confirm deployment expectations.
+fn load_or_generate_key(path: &Path) -> Result<SecretKey> {
+    if path.exists() {
+        let data = std::fs::read(path).context("reading node key")?;
+        let bytes: [u8; 32] = data
+            .try_into()
+            .map_err(|_| anyhow::anyhow!("invalid key file length"))?;
+        Ok(SecretKey::from_bytes(&bytes))
+    } else {
+        let key = SecretKey::generate(&mut rand::rng());
+        if let Some(parent) = path.parent() {
+            // Best-effort: a real failure will surface from fs::write below.
+            std::fs::create_dir_all(parent).ok();
+        }
+        std::fs::write(path, key.to_bytes()).context("writing node key")?;
+        tracing::info!(path = %path.display(), "generated new node identity key");
+        Ok(key)
+    }
+}
+
+/// Parse and pre-connect to bootstrap peers.
+/// K8s mode starts with no bootstrap peers — relies on incoming connections.
+/// Bootstrap mode parses "endpointid@host:port" and initiates connections.
+async fn resolve_bootstrap_peers(
+    cfg: &ClusterConfig,
+    endpoint: &Endpoint,
+) -> Vec<iroh::EndpointId> {
+    match cfg.discovery.method.as_str() {
+        "k8s" => {
+            tracing::info!("k8s discovery mode: waiting for peers to connect");
+            vec![]
+        }
+        "bootstrap" => {
+            let mut peers = Vec::new();
+            for entry in cfg
+                .discovery
+                .bootstrap_peers
+                .as_deref()
+                .unwrap_or_default()
+            {
+                if let Some((id_str, addr_str)) = entry.split_once('@') {
+                    match id_str.parse::<iroh::EndpointId>() {
+                        Ok(id) => {
+                            if let Ok(addr) = addr_str.parse::<SocketAddr>() {
+                                let node_addr = iroh::EndpointAddr::from_parts(
+                                    id,
+                                    [iroh::TransportAddr::Ip(addr)],
+                                );
+                                // Pre-connect so the gossip layer can reach this peer.
+                                match endpoint.connect(node_addr, ALPN).await {
+                                    Ok(conn) => {
+                                        tracing::info!(peer = %id, addr = %addr, "connected to bootstrap peer");
+                                        // Drop the connection — gossip will reuse the underlying QUIC path.
+                                        drop(conn);
+                                    }
+                                    Err(e) => {
+                                        tracing::warn!(peer = %id, addr = %addr, error = %e, "failed to connect to bootstrap peer");
+                                    }
+                                }
+                            }
+                            // The ID is pushed even when the address fails to
+                            // parse or the dial fails — the gossip layer may
+                            // still reach it later via incoming connections.
+                            peers.push(id);
+                        }
+                        Err(e) => {
+                            tracing::warn!(entry, error = %e, "invalid bootstrap peer id");
+                        }
+                    }
+                } else {
+                    tracing::warn!(entry, "invalid bootstrap peer format (expected id@host:port)");
+                }
+            }
+            peers
+        }
+        other => {
+            tracing::warn!(method = other, "unknown discovery method");
+            vec![]
+        }
+    }
+}
+
+/// Run the cluster node. Called from a dedicated OS thread with its own tokio runtime.
+/// Sends the endpoint ID through `ready_tx` once initialization is complete,
+/// then runs event loops until shutdown.
+pub async fn run_cluster(
+    cfg: &ClusterConfig,
+    bandwidth: Arc<BandwidthTracker>,
+    cluster_bandwidth: Arc<ClusterBandwidthState>,
+    meter: Arc<BandwidthMeter>,
+    mut shutdown_rx: watch::Receiver<bool>,
+    ready_tx: tokio::sync::oneshot::Sender<Result<iroh::PublicKey>>,
+) {
+    // Helper macro to send error through ready channel and return early.
+    // Consumes `ready_tx` on the error path, which is why init errors can
+    // only be reported once.
+    macro_rules! try_init {
+        ($expr:expr) => {
+            match $expr {
+                Ok(v) => v,
+                Err(e) => {
+                    let _ = ready_tx.send(Err(e.into()));
+                    return;
+                }
+            }
+        };
+    }
+
+    // 1. Load or generate identity.
+    let key_path = cfg
+        .key_path
+        .as_deref()
+        .unwrap_or("/var/lib/sunbeam/node.key");
+    let secret_key = try_init!(load_or_generate_key(Path::new(key_path)));
+
+    // 2. Create iroh endpoint. Relay is disabled: cluster nodes reach each
+    // other directly on `gossip_port`.
+    let builder = try_init!(Endpoint::builder()
+        .secret_key(secret_key)
+        .relay_mode(RelayMode::Disabled)
+        .alpns(vec![ALPN.to_vec()])
+        .bind_addr(SocketAddr::from(([0, 0, 0, 0], cfg.gossip_port)))
+        .map_err(|e| anyhow::anyhow!("invalid bind address: {e}")));
+    let endpoint = try_init!(builder.bind().await.context("binding iroh endpoint"));
+
+    let my_id = endpoint.id();
+    let my_id_bytes: [u8; 32] = *my_id.as_bytes();
+    tracing::info!(endpoint_id = %my_id, port = cfg.gossip_port, "cluster node started");
+
+    // 3. Create gossip instance and router.
+    let gossip = Gossip::builder().spawn(endpoint.clone());
+    let router = Router::builder(endpoint.clone())
+        .accept(ALPN, gossip.clone())
+        .spawn();
+
+    // 4. Resolve bootstrap peers.
+    let peers = resolve_bootstrap_peers(cfg, &endpoint).await;
+
+    // 5. Derive topics — one per gossip channel.
+    let bandwidth_topic = derive_topic(&cfg.tenant, "bandwidth");
+    let models_topic = derive_topic(&cfg.tenant, "models");
+    let leader_topic = derive_topic(&cfg.tenant, "leader");
+    let license_topic = derive_topic(&cfg.tenant, "license");
+
+    tracing::info!(
+        tenant = %cfg.tenant,
+        bandwidth_topic = ?bandwidth_topic,
+        "subscribing to gossip topics"
+    );
+
+    // 6. Subscribe to topics. Only the bandwidth sender is used yet; the
+    // other senders are bound to `_`-prefixed names (stub channels).
+    let bw_gossip_topic = try_init!(gossip
+        .subscribe(bandwidth_topic, peers.clone())
+        .await
+        .context("subscribing to bandwidth topic"));
+    let (bw_sender, bw_receiver) = bw_gossip_topic.split();
+
+    let models_gossip_topic = try_init!(gossip
+        .subscribe(models_topic, peers.clone())
+        .await
+        .context("subscribing to models topic"));
+    let (_models_sender, models_receiver) = models_gossip_topic.split();
+
+    let leader_gossip_topic = try_init!(gossip
+        .subscribe(leader_topic, peers.clone())
+        .await
+        .context("subscribing to leader topic"));
+    let (_leader_sender, leader_receiver) = leader_gossip_topic.split();
+
+    let license_gossip_topic = try_init!(gossip
+        .subscribe(license_topic, peers)
+        .await
+        .context("subscribing to license topic"));
+    let (_license_sender, license_receiver) = license_gossip_topic.split();
+
+    // Initialization complete — signal the caller with our endpoint ID.
+    let _ = ready_tx.send(Ok(my_id));
+
+    let broadcast_interval = Duration::from_secs(
+        cfg.bandwidth
+            .as_ref()
+            .map(|b| b.broadcast_interval_secs)
+            .unwrap_or(5),
+    );
+    let stale_timeout = Duration::from_secs(
+        cfg.bandwidth
+            .as_ref()
+            .map(|b| b.stale_peer_timeout_secs)
+            .unwrap_or(30),
+    );
+
+    // 7. Bandwidth broadcast loop: every interval, snapshot-and-reset the
+    // local counters and gossip them to the cluster.
+    let bw_tracker = bandwidth.clone();
+    let bw_sender_clone = bw_sender.clone();
+    let meter_broadcast = meter.clone();
+    let my_id_bw = my_id_bytes;
+    let broadcast_task = tokio::spawn(async move {
+        let mut interval = tokio::time::interval(broadcast_interval);
+        loop {
+            interval.tick().await;
+            let snap = bw_tracker.snapshot_and_reset();
+            // Feed local node's delta into the sliding window meter.
+            meter_broadcast.record_sample(snap.bytes_in, snap.bytes_out);
+            let msg = ClusterMessage {
+                version: 1,
+                sender: my_id_bw,
+                payload: Payload::BandwidthReport {
+                    timestamp: SystemTime::now()
+                        .duration_since(UNIX_EPOCH)
+                        .unwrap_or_default()
+                        .as_secs(),
+                    bytes_in: snap.bytes_in,
+                    bytes_out: snap.bytes_out,
+                    request_count: snap.request_count,
+                    cumulative_in: snap.cumulative_in,
+                    cumulative_out: snap.cumulative_out,
+                },
+            };
+            match msg.encode() {
+                Ok(data) => {
+                    if let Err(e) = bw_sender_clone.broadcast(data.into()).await {
+                        tracing::debug!(error = %e, "bandwidth broadcast failed");
+                    }
+                    // NOTE(review): counter increments on encode success even
+                    // when the broadcast itself failed — intentional?
+                    metrics::CLUSTER_GOSSIP_MESSAGES
+                        .with_label_values(&["bandwidth"])
+                        .inc();
+                }
+                Err(e) => tracing::warn!(error = %e, "failed to encode bandwidth report"),
+            }
+        }
+    });
+
+    // 8. Bandwidth receive loop.
+    let cluster_bw = cluster_bandwidth.clone();
+    let meter_recv = meter.clone();
+    let bw_recv_task = tokio::spawn(handle_bandwidth_events(bw_receiver, cluster_bw, meter_recv));
+
+    // 9. Model receive loop (stub).
+    let models_recv_task = tokio::spawn(handle_model_events(models_receiver));
+
+    // 10. Leader topic (stub).
+    let leader_recv_task = tokio::spawn(handle_stub_events(leader_receiver, "leader"));
+
+    // 11. License topic (stub).
+    let license_recv_task = tokio::spawn(handle_stub_events(license_receiver, "license"));
+
+    // 12. Stale peer eviction loop.
+    let cluster_bw_evict = cluster_bandwidth.clone();
+    let eviction_task = tokio::spawn(async move {
+        let mut interval = tokio::time::interval(stale_timeout);
+        loop {
+            interval.tick().await;
+            cluster_bw_evict.evict_stale();
+        }
+    });
+
+    // 13. Aggregate rate metrics updater (every broadcast interval).
+    let meter_metrics = meter;
+    let metrics_task = tokio::spawn(async move {
+        let mut interval = tokio::time::interval(broadcast_interval);
+        loop {
+            interval.tick().await;
+            let rate = meter_metrics.aggregate_rate();
+            metrics::CLUSTER_AGGREGATE_IN_RATE.set(rate.bytes_in_per_sec);
+            metrics::CLUSTER_AGGREGATE_OUT_RATE.set(rate.bytes_out_per_sec);
+            metrics::CLUSTER_AGGREGATE_TOTAL_RATE.set(rate.total_per_sec);
+        }
+    });
+
+    // Wait for shutdown or task failure. Whichever arm fires, run_cluster
+    // returns; the remaining tasks are torn down when the thread-local
+    // runtime is dropped (see spawn_cluster) — presumably intended; confirm.
+    tokio::select! {
+        _ = shutdown_rx.changed() => {
+            tracing::info!("cluster shutdown signal received");
+        }
+        r = broadcast_task => {
+            tracing::error!(result = ?r, "bandwidth broadcast task exited");
+        }
+        r = bw_recv_task => {
+            tracing::error!(result = ?r, "bandwidth receive task exited");
+        }
+        r = models_recv_task => {
+            tracing::error!(result = ?r, "model receive task exited");
+        }
+        r = leader_recv_task => {
+            tracing::error!(result = ?r, "leader receive task exited");
+        }
+        r = license_recv_task => {
+            tracing::error!(result = ?r, "license receive task exited");
+        }
+        r = eviction_task => {
+            tracing::error!(result = ?r, "eviction task exited");
+        }
+        r = metrics_task => {
+            tracing::error!(result = ?r, "metrics task exited");
+        }
+    }
+
+    if let Err(e) = router.shutdown().await {
+        tracing::error!(error = %e, "router shutdown failed");
+    }
+}
+
+/// Consume bandwidth-topic gossip events: fold each peer report into the
+/// cluster state and the sliding-window meter. Exits when the stream ends.
+async fn handle_bandwidth_events(
+    mut receiver: iroh_gossip::api::GossipReceiver,
+    cluster_bw: Arc<ClusterBandwidthState>,
+    meter: Arc<BandwidthMeter>,
+) {
+    while let Some(Ok(event)) = receiver.next().await {
+        if let Event::Received(message) = event {
+            match ClusterMessage::decode(&message.content) {
+                Ok(ClusterMessage {
+                    sender,
+                    payload:
+                        Payload::BandwidthReport {
+                            cumulative_in,
+                            cumulative_out,
+                            bytes_in,
+                            bytes_out,
+                            request_count,
+                            ..
+                        },
+                    ..
+                }) => {
+                    // NOTE(review): `sender` comes from the decoded payload,
+                    // not the transport — a peer could report under another ID.
+                    cluster_bw.update_peer(sender, cumulative_in, cumulative_out);
+                    // Feed remote peer's delta into the sliding window meter.
+                    meter.record_sample(bytes_in, bytes_out);
+                    metrics::CLUSTER_GOSSIP_MESSAGES
+                        .with_label_values(&["bandwidth"])
+                        .inc();
+                    tracing::debug!(
+                        sender = hex::encode(sender),
+                        bytes_in,
+                        bytes_out,
+                        request_count,
+                        "received bandwidth report"
+                    );
+                }
+                Ok(_) => tracing::debug!("unexpected payload on bandwidth topic"),
+                Err(e) => tracing::debug!(error = %e, "failed to decode bandwidth message"),
+            }
+        }
+    }
+}
+
+/// Model-channel handler. Currently log-only: announces and chunks are
+/// decoded and counted but not stored.
+async fn handle_model_events(mut receiver: iroh_gossip::api::GossipReceiver) {
+    while let Some(Ok(event)) = receiver.next().await {
+        if let Event::Received(message) = event {
+            match ClusterMessage::decode(&message.content) {
+                Ok(ClusterMessage {
+                    payload: Payload::ModelAnnounce { model_type, hash, total_size, .. },
+                    ..
+                }) => {
+                    tracing::info!(
                        model_type,
+                        hash = hex::encode(hash),
+                        total_size,
+                        "received model announce (stub — ignoring)"
+                    );
+                    metrics::CLUSTER_MODEL_UPDATES
+                        .with_label_values(&[&model_type, "ignored"])
+                        .inc();
+                }
+                Ok(ClusterMessage {
+                    payload: Payload::ModelChunk { hash, chunk_index, .. },
+                    ..
+                }) => {
+                    tracing::debug!(
+                        hash = hex::encode(hash),
+                        chunk_index,
+                        "received model chunk (stub — ignoring)"
+                    );
+                }
+                Ok(_) => {}
+                Err(e) => tracing::debug!(error = %e, "failed to decode model message"),
+            }
+        }
+    }
+}
+
+/// Generic log-and-count handler for channels that are not implemented yet
+/// (leader, license). `channel` labels the metrics counter.
+async fn handle_stub_events(mut receiver: iroh_gossip::api::GossipReceiver, channel: &str) {
+    while let Some(Ok(event)) = receiver.next().await {
+        if let Event::Received(message) = event {
+            if let Ok(msg) = ClusterMessage::decode(&message.content) {
+                tracing::debug!(?msg, channel, "received stub message");
+                metrics::CLUSTER_GOSSIP_MESSAGES
+                    .with_label_values(&[channel])
+                    .inc();
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn topic_derivation_deterministic() {
+        let t1 = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
+        let t2 = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
+        assert_eq!(t1, t2);
+    }
+
+    #[test]
+    fn topic_derivation_different_channels() {
+        let bw = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
+        let models = derive_topic("550e8400-e29b-41d4-a716-446655440000", "models");
+        assert_ne!(bw, models);
+    }
+
+    #[test]
+    fn topic_derivation_different_tenants() {
+        let t1 = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
+        let t2 = derive_topic("660e8400-e29b-41d4-a716-446655440001", "bandwidth");
+        assert_ne!(t1, t2);
+    }
+}