feat(cluster): implement gossip-based cluster subsystem with iroh

Core cluster module with four gossip channels (bandwidth, models,
leader, license) over iroh-gossip HyParView/PlumTree. Includes:
- BandwidthTracker: atomic per-node counters with zero hot-path contention
- ClusterBandwidthState: peer aggregation with stale eviction
- BandwidthMeter: sliding-window aggregate rate (power-of-2 MiB units)
- BandwidthLimiter: runtime-mutable bandwidth cap (default 1 Gbps)
- ClusterHandle/spawn_cluster: dedicated OS thread + tokio runtime
- Bincode-serialized message envelope with versioned payloads
- Bootstrap and k8s peer discovery modes
- Persistent ed25519 identity for stable EndpointId across restarts

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:20 +00:00
parent ad5c7f0afb
commit 5d279f992b
4 changed files with 1120 additions and 0 deletions

437
src/cluster/bandwidth.rs Normal file
View File

@@ -0,0 +1,437 @@
use rustc_hash::FxHashMap;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Per-node atomic bandwidth counters. Zero contention on the hot path
/// (single `fetch_add` per counter per request).
pub struct BandwidthTracker {
/// Bytes received since last broadcast (reset each cycle).
bytes_in: AtomicU64,
/// Bytes sent since last broadcast (reset each cycle).
bytes_out: AtomicU64,
/// Requests since last broadcast (reset each cycle).
request_count: AtomicU64,
/// Monotonic total bytes received (never reset).
cumulative_in: AtomicU64,
/// Monotonic total bytes sent (never reset).
cumulative_out: AtomicU64,
}
impl BandwidthTracker {
pub fn new() -> Self {
Self {
bytes_in: AtomicU64::new(0),
bytes_out: AtomicU64::new(0),
request_count: AtomicU64::new(0),
cumulative_in: AtomicU64::new(0),
cumulative_out: AtomicU64::new(0),
}
}
/// Record a completed request's byte counts.
#[inline]
pub fn record(&self, bytes_in: u64, bytes_out: u64) {
self.bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
self.bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
self.request_count.fetch_add(1, Ordering::Relaxed);
self.cumulative_in.fetch_add(bytes_in, Ordering::Relaxed);
self.cumulative_out.fetch_add(bytes_out, Ordering::Relaxed);
}
/// Take a snapshot and reset the per-interval counters.
pub fn snapshot_and_reset(&self) -> BandwidthSnapshot {
BandwidthSnapshot {
bytes_in: self.bytes_in.swap(0, Ordering::Relaxed),
bytes_out: self.bytes_out.swap(0, Ordering::Relaxed),
request_count: self.request_count.swap(0, Ordering::Relaxed),
cumulative_in: self.cumulative_in.load(Ordering::Relaxed),
cumulative_out: self.cumulative_out.load(Ordering::Relaxed),
}
}
}
#[derive(Debug, Clone)]
pub struct BandwidthSnapshot {
pub bytes_in: u64,
pub bytes_out: u64,
pub request_count: u64,
pub cumulative_in: u64,
pub cumulative_out: u64,
}
/// Aggregated bandwidth state from all cluster peers.
pub struct ClusterBandwidthState {
peers: RwLock<FxHashMap<[u8; 32], PeerEntry>>,
/// Sum of all peers' cumulative bytes in (updated on each report).
pub total_bytes_in: AtomicU64,
/// Sum of all peers' cumulative bytes out.
pub total_bytes_out: AtomicU64,
/// Number of active (non-stale) peers.
pub peer_count: AtomicU64,
/// Stale peer timeout.
stale_timeout_secs: u64,
}
struct PeerEntry {
cumulative_in: u64,
cumulative_out: u64,
last_seen: Instant,
}
impl ClusterBandwidthState {
pub fn new(stale_timeout_secs: u64) -> Self {
Self {
peers: RwLock::new(FxHashMap::default()),
total_bytes_in: AtomicU64::new(0),
total_bytes_out: AtomicU64::new(0),
peer_count: AtomicU64::new(0),
stale_timeout_secs,
}
}
/// Update a peer's bandwidth state from a received report.
pub fn update_peer(&self, peer_id: [u8; 32], cumulative_in: u64, cumulative_out: u64) {
let mut peers = self.peers.write().unwrap();
peers.insert(
peer_id,
PeerEntry {
cumulative_in,
cumulative_out,
last_seen: Instant::now(),
},
);
self.recalculate(&peers);
}
/// Remove peers that haven't reported within the stale timeout.
pub fn evict_stale(&self) {
let mut peers = self.peers.write().unwrap();
let cutoff = Instant::now() - std::time::Duration::from_secs(self.stale_timeout_secs);
peers.retain(|_, entry| entry.last_seen > cutoff);
self.recalculate(&peers);
}
fn recalculate(&self, peers: &FxHashMap<[u8; 32], PeerEntry>) {
let mut total_in = 0u64;
let mut total_out = 0u64;
for entry in peers.values() {
total_in = total_in.saturating_add(entry.cumulative_in);
total_out = total_out.saturating_add(entry.cumulative_out);
}
self.total_bytes_in.store(total_in, Ordering::Relaxed);
self.total_bytes_out.store(total_out, Ordering::Relaxed);
self.peer_count.store(peers.len() as u64, Ordering::Relaxed);
}
}
/// Aggregate bandwidth rate across the entire cluster, computed from a
/// sliding window of samples from all nodes (local + remote).
///
/// Each broadcast cycle produces one sample per node. With a 5s broadcast
/// interval and 30s window, the deque holds ~6 × node_count entries — tiny.
pub struct BandwidthMeter {
samples: RwLock<VecDeque<Sample>>,
window: Duration,
}
struct Sample {
time: Instant,
bytes_in: u64,
bytes_out: u64,
}
/// Snapshot of the aggregate cluster-wide bandwidth rate.
/// All rates are in bytes/sec. Use the `*_mib_per_sec` methods for MiB/s (power-of-2).
#[derive(Debug, Clone, Copy)]
pub struct AggregateRate {
/// Inbound bytes/sec across all nodes.
pub bytes_in_per_sec: f64,
/// Outbound bytes/sec across all nodes.
pub bytes_out_per_sec: f64,
/// Total (in + out) bytes/sec.
pub total_per_sec: f64,
/// Number of samples in the window.
pub sample_count: usize,
}
const BYTES_PER_MIB: f64 = 1_048_576.0; // 1024 * 1024
impl AggregateRate {
/// Inbound rate in MiB/s (power-of-2).
pub fn in_mib_per_sec(&self) -> f64 {
self.bytes_in_per_sec / BYTES_PER_MIB
}
/// Outbound rate in MiB/s (power-of-2).
pub fn out_mib_per_sec(&self) -> f64 {
self.bytes_out_per_sec / BYTES_PER_MIB
}
/// Total rate in MiB/s (power-of-2).
pub fn total_mib_per_sec(&self) -> f64 {
self.total_per_sec / BYTES_PER_MIB
}
}
impl BandwidthMeter {
pub fn new(window_secs: u64) -> Self {
Self {
samples: RwLock::new(VecDeque::new()),
window: Duration::from_secs(window_secs),
}
}
/// Record a bandwidth sample (from local broadcast or remote peer report).
pub fn record_sample(&self, bytes_in: u64, bytes_out: u64) {
let now = Instant::now();
let mut samples = self.samples.write().unwrap();
samples.push_back(Sample {
time: now,
bytes_in,
bytes_out,
});
// Evict samples outside the window.
let cutoff = now - self.window;
while samples.front().is_some_and(|s| s.time < cutoff) {
samples.pop_front();
}
}
/// Compute the aggregate bandwidth rate over the sliding window.
pub fn aggregate_rate(&self) -> AggregateRate {
let now = Instant::now();
let samples = self.samples.read().unwrap();
let cutoff = now - self.window;
let mut total_in = 0u64;
let mut total_out = 0u64;
let mut count = 0usize;
for s in samples.iter() {
if s.time >= cutoff {
total_in = total_in.saturating_add(s.bytes_in);
total_out = total_out.saturating_add(s.bytes_out);
count += 1;
}
}
let window_secs = self.window.as_secs_f64();
let bytes_in_per_sec = total_in as f64 / window_secs;
let bytes_out_per_sec = total_out as f64 / window_secs;
AggregateRate {
bytes_in_per_sec,
bytes_out_per_sec,
total_per_sec: bytes_in_per_sec + bytes_out_per_sec,
sample_count: count,
}
}
}
/// Cluster-wide bandwidth limiter. Compares the aggregate rate from the
/// `BandwidthMeter` against a configurable cap (bytes/sec). The limit is
/// stored as an `AtomicU64` so it can be updated at runtime (e.g. when a
/// license quota changes via gossip).
pub struct BandwidthLimiter {
/// Max total (in + out) bytes/sec across the cluster. 0 = unlimited.
limit_bytes_per_sec: AtomicU64,
meter: std::sync::Arc<BandwidthMeter>,
}
/// Result of a bandwidth limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BandwidthLimitResult {
Allow,
Reject,
}
impl BandwidthLimiter {
pub fn new(meter: std::sync::Arc<BandwidthMeter>, limit_bytes_per_sec: u64) -> Self {
Self {
limit_bytes_per_sec: AtomicU64::new(limit_bytes_per_sec),
meter,
}
}
/// Check whether the cluster is currently over its bandwidth cap.
#[inline]
pub fn check(&self) -> BandwidthLimitResult {
let limit = self.limit_bytes_per_sec.load(Ordering::Relaxed);
if limit == 0 {
return BandwidthLimitResult::Allow;
}
let rate = self.meter.aggregate_rate();
if rate.total_per_sec > limit as f64 {
BandwidthLimitResult::Reject
} else {
BandwidthLimitResult::Allow
}
}
/// Update the bandwidth cap at runtime (e.g. from a license update).
pub fn set_limit(&self, bytes_per_sec: u64) {
self.limit_bytes_per_sec.store(bytes_per_sec, Ordering::Relaxed);
}
/// Current limit in bytes/sec (0 = unlimited).
pub fn limit(&self) -> u64 {
self.limit_bytes_per_sec.load(Ordering::Relaxed)
}
/// Current aggregate rate snapshot.
pub fn current_rate(&self) -> AggregateRate {
self.meter.aggregate_rate()
}
}
/// Convert Gbps (base-10, as used in networking/billing) to bytes/sec.
/// 1 Gbps = 1_000_000_000 bits/sec = 125_000_000 bytes/sec.
pub fn gbps_to_bytes_per_sec(gbps: f64) -> u64 {
(gbps * 125_000_000.0) as u64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tracker_record_and_snapshot() {
let tracker = BandwidthTracker::new();
tracker.record(100, 200);
tracker.record(50, 75);
let snap = tracker.snapshot_and_reset();
assert_eq!(snap.bytes_in, 150);
assert_eq!(snap.bytes_out, 275);
assert_eq!(snap.request_count, 2);
assert_eq!(snap.cumulative_in, 150);
assert_eq!(snap.cumulative_out, 275);
// After reset, interval counters are zero but cumulative persists.
tracker.record(10, 20);
let snap2 = tracker.snapshot_and_reset();
assert_eq!(snap2.bytes_in, 10);
assert_eq!(snap2.bytes_out, 20);
assert_eq!(snap2.request_count, 1);
assert_eq!(snap2.cumulative_in, 160);
assert_eq!(snap2.cumulative_out, 295);
}
#[test]
fn meter_aggregate_rate() {
let meter = BandwidthMeter::new(30);
// Simulate 6 samples over the window (one every 5s).
// In reality they come from multiple nodes; we don't care about source.
meter.record_sample(500_000_000, 100_000_000); // 500MB in, 100MB out
meter.record_sample(50_000_000, 10_000_000); // 50MB in, 10MB out
let rate = meter.aggregate_rate();
assert_eq!(rate.sample_count, 2);
// total_in = 550MB over 30s window = ~18.3 MB/s
let expected_in = 550_000_000.0 / 30.0;
assert!(
(rate.bytes_in_per_sec - expected_in).abs() < 1.0,
"expected ~{expected_in}, got {}",
rate.bytes_in_per_sec
);
let expected_out = 110_000_000.0 / 30.0;
assert!(
(rate.bytes_out_per_sec - expected_out).abs() < 1.0,
"expected ~{expected_out}, got {}",
rate.bytes_out_per_sec
);
assert!(
(rate.total_per_sec - (expected_in + expected_out)).abs() < 1.0,
);
}
#[test]
fn meter_evicts_old_samples() {
// Use a 1-second window so we can test eviction quickly.
let meter = BandwidthMeter::new(1);
meter.record_sample(1000, 2000);
std::thread::sleep(std::time::Duration::from_millis(1100));
// Sample should be evicted.
meter.record_sample(500, 600);
let rate = meter.aggregate_rate();
assert_eq!(rate.sample_count, 1, "old sample should be evicted");
// Only the second sample should be counted.
assert!((rate.bytes_in_per_sec - 500.0).abs() < 1.0);
}
#[test]
fn meter_empty_returns_zero() {
let meter = BandwidthMeter::new(30);
let rate = meter.aggregate_rate();
assert_eq!(rate.sample_count, 0);
assert_eq!(rate.bytes_in_per_sec, 0.0);
assert_eq!(rate.bytes_out_per_sec, 0.0);
assert_eq!(rate.total_per_sec, 0.0);
}
#[test]
fn cluster_state_aggregation() {
let state = ClusterBandwidthState::new(30);
state.update_peer([1u8; 32], 1000, 2000);
state.update_peer([2u8; 32], 3000, 4000);
assert_eq!(state.total_bytes_in.load(Ordering::Relaxed), 4000);
assert_eq!(state.total_bytes_out.load(Ordering::Relaxed), 6000);
assert_eq!(state.peer_count.load(Ordering::Relaxed), 2);
// Update existing peer.
state.update_peer([1u8; 32], 1500, 2500);
assert_eq!(state.total_bytes_in.load(Ordering::Relaxed), 4500);
assert_eq!(state.total_bytes_out.load(Ordering::Relaxed), 6500);
assert_eq!(state.peer_count.load(Ordering::Relaxed), 2);
}
#[test]
fn limiter_allows_when_unlimited() {
let meter = std::sync::Arc::new(BandwidthMeter::new(30));
meter.record_sample(999_999_999, 999_999_999);
let limiter = BandwidthLimiter::new(meter, 0); // 0 = unlimited
assert_eq!(limiter.check(), BandwidthLimitResult::Allow);
}
#[test]
fn limiter_allows_under_cap() {
let meter = std::sync::Arc::new(BandwidthMeter::new(30));
// 1 GiB total over 30s = ~33 MiB/s ≈ ~35 MB/s — well under 1 Gbps
meter.record_sample(500_000_000, 500_000_000);
let limiter = BandwidthLimiter::new(meter, gbps_to_bytes_per_sec(1.0));
assert_eq!(limiter.check(), BandwidthLimitResult::Allow);
}
#[test]
fn limiter_rejects_over_cap() {
let meter = std::sync::Arc::new(BandwidthMeter::new(1)); // 1s window
// 200 MB total in 1s window = 200 MB/s > 125 MB/s (1 Gbps)
meter.record_sample(100_000_000, 100_000_000);
let limiter = BandwidthLimiter::new(meter, gbps_to_bytes_per_sec(1.0));
assert_eq!(limiter.check(), BandwidthLimitResult::Reject);
}
#[test]
fn limiter_set_limit_runtime() {
let meter = std::sync::Arc::new(BandwidthMeter::new(1));
meter.record_sample(100_000_000, 100_000_000); // 200 MB/s
let limiter = BandwidthLimiter::new(meter, gbps_to_bytes_per_sec(1.0));
assert_eq!(limiter.check(), BandwidthLimitResult::Reject);
// Raise the limit to 10 Gbps → should now allow.
limiter.set_limit(gbps_to_bytes_per_sec(10.0));
assert_eq!(limiter.check(), BandwidthLimitResult::Allow);
assert_eq!(limiter.limit(), gbps_to_bytes_per_sec(10.0));
}
#[test]
fn gbps_conversion() {
assert_eq!(gbps_to_bytes_per_sec(1.0), 125_000_000);
assert_eq!(gbps_to_bytes_per_sec(10.0), 1_250_000_000);
assert_eq!(gbps_to_bytes_per_sec(0.0), 0);
}
}

144
src/cluster/messages.rs Normal file
View File

@@ -0,0 +1,144 @@
use serde::{Deserialize, Serialize};
/// Envelope for all cluster gossip messages.
/// Serialized with bincode before broadcast.
#[derive(Debug, Serialize, Deserialize)]
pub struct ClusterMessage {
pub version: u8,
pub sender: [u8; 32],
pub payload: Payload,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Payload {
BandwidthReport {
timestamp: u64,
bytes_in: u64,
bytes_out: u64,
request_count: u64,
cumulative_in: u64,
cumulative_out: u64,
},
ModelAnnounce {
model_type: String,
hash: [u8; 32],
total_size: u64,
chunk_count: u32,
},
ModelChunk {
hash: [u8; 32],
chunk_index: u32,
data: Vec<u8>,
},
LeaderHeartbeat {
term: u64,
leader_id: [u8; 32],
},
LicenseQuota {
max_bytes: u64,
current_bytes: u64,
},
}
impl ClusterMessage {
pub fn encode(&self) -> Result<Vec<u8>, bincode::Error> {
bincode::serialize(self)
}
pub fn decode(data: &[u8]) -> Result<Self, bincode::Error> {
bincode::deserialize(data)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn roundtrip_bandwidth_report() {
let msg = ClusterMessage {
version: 1,
sender: [42u8; 32],
payload: Payload::BandwidthReport {
timestamp: 1234567890,
bytes_in: 1000,
bytes_out: 2000,
request_count: 50,
cumulative_in: 100_000,
cumulative_out: 200_000,
},
};
let encoded = msg.encode().unwrap();
let decoded = ClusterMessage::decode(&encoded).unwrap();
assert_eq!(decoded.version, 1);
assert_eq!(decoded.sender, [42u8; 32]);
match decoded.payload {
Payload::BandwidthReport {
timestamp,
bytes_in,
bytes_out,
request_count,
cumulative_in,
cumulative_out,
} => {
assert_eq!(timestamp, 1234567890);
assert_eq!(bytes_in, 1000);
assert_eq!(bytes_out, 2000);
assert_eq!(request_count, 50);
assert_eq!(cumulative_in, 100_000);
assert_eq!(cumulative_out, 200_000);
}
_ => panic!("wrong payload variant"),
}
}
#[test]
fn roundtrip_model_announce() {
let msg = ClusterMessage {
version: 1,
sender: [1u8; 32],
payload: Payload::ModelAnnounce {
model_type: "scanner".to_string(),
hash: [0xAA; 32],
total_size: 1_000_000,
chunk_count: 16,
},
};
let encoded = msg.encode().unwrap();
let decoded = ClusterMessage::decode(&encoded).unwrap();
match decoded.payload {
Payload::ModelAnnounce {
model_type,
total_size,
chunk_count,
..
} => {
assert_eq!(model_type, "scanner");
assert_eq!(total_size, 1_000_000);
assert_eq!(chunk_count, 16);
}
_ => panic!("wrong payload variant"),
}
}
#[test]
fn roundtrip_leader_heartbeat() {
let msg = ClusterMessage {
version: 1,
sender: [7u8; 32],
payload: Payload::LeaderHeartbeat {
term: 3,
leader_id: [7u8; 32],
},
};
let encoded = msg.encode().unwrap();
let decoded = ClusterMessage::decode(&encoded).unwrap();
match decoded.payload {
Payload::LeaderHeartbeat { term, leader_id } => {
assert_eq!(term, 3);
assert_eq!(leader_id, [7u8; 32]);
}
_ => panic!("wrong payload variant"),
}
}
}

104
src/cluster/mod.rs Normal file
View File

@@ -0,0 +1,104 @@
pub mod bandwidth;
pub mod messages;
pub mod node;
use std::sync::Arc;
use anyhow::Result;
use tokio::sync::watch;
use crate::config::ClusterConfig;
use bandwidth::{
gbps_to_bytes_per_sec, BandwidthLimiter, BandwidthMeter, BandwidthTracker,
ClusterBandwidthState,
};
pub struct ClusterHandle {
pub bandwidth: Arc<BandwidthTracker>,
pub cluster_bandwidth: Arc<ClusterBandwidthState>,
/// Sliding-window aggregate bandwidth rate across the cluster.
pub meter: Arc<BandwidthMeter>,
/// Cluster-wide bandwidth limiter (0 = unlimited).
pub limiter: Arc<BandwidthLimiter>,
pub endpoint_id: iroh::PublicKey,
shutdown_tx: watch::Sender<bool>,
}
impl ClusterHandle {
pub fn shutdown(&self) {
let _ = self.shutdown_tx.send(true);
}
}
impl Drop for ClusterHandle {
fn drop(&mut self) {
self.shutdown();
}
}
/// Spawn the cluster subsystem on a dedicated OS thread with its own tokio runtime.
/// Returns a handle for bandwidth recording and cluster state queries.
pub fn spawn_cluster(cfg: &ClusterConfig) -> Result<ClusterHandle> {
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let stale_timeout = cfg
.bandwidth
.as_ref()
.map(|b| b.stale_peer_timeout_secs)
.unwrap_or(30);
let meter_window = cfg
.bandwidth
.as_ref()
.map(|b| b.meter_window_secs)
.unwrap_or(30);
// Default: 1 Gbps cap. Updated at runtime via license gossip.
let limit_bytes_per_sec = gbps_to_bytes_per_sec(1.0);
let bandwidth = Arc::new(BandwidthTracker::new());
let cluster_bandwidth = Arc::new(ClusterBandwidthState::new(stale_timeout));
let meter = Arc::new(BandwidthMeter::new(meter_window));
let limiter = Arc::new(BandwidthLimiter::new(meter.clone(), limit_bytes_per_sec));
let bw = bandwidth.clone();
let cbw = cluster_bandwidth.clone();
let m = meter.clone();
let cluster_cfg = cfg.clone();
let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
std::thread::Builder::new()
.name("cluster".into())
.spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.thread_name("cluster-worker")
.build()
.expect("cluster runtime");
rt.block_on(node::run_cluster(
&cluster_cfg,
bw,
cbw,
m,
shutdown_rx,
ready_tx,
));
})?;
// Wait for the cluster to initialize (or fail).
let endpoint_id = ready_rx
.blocking_recv()
.map_err(|_| anyhow::anyhow!("cluster thread exited before initialization"))??;
Ok(ClusterHandle {
bandwidth,
cluster_bandwidth,
meter,
limiter,
endpoint_id,
shutdown_tx,
})
}

435
src/cluster/node.rs Normal file
View File

@@ -0,0 +1,435 @@
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use anyhow::{Context, Result};
use futures::stream::StreamExt;
use iroh::protocol::Router;
use iroh::{Endpoint, RelayMode, SecretKey};
use iroh_gossip::net::Gossip;
use iroh_gossip::{api::Event, proto::TopicId, ALPN};
use tokio::sync::watch;
use crate::cluster::bandwidth::{BandwidthMeter, BandwidthTracker, ClusterBandwidthState};
use crate::cluster::messages::{ClusterMessage, Payload};
use crate::config::ClusterConfig;
use crate::metrics;
/// Derive a deterministic TopicId from tenant UUID and channel name.
pub fn derive_topic(tenant: &str, channel: &str) -> TopicId {
let input = format!("/sunbeam-proxy/1.0/{tenant}/{channel}");
let hash = blake3::hash(input.as_bytes());
TopicId::from_bytes(*hash.as_bytes())
}
/// Load or generate a persistent ed25519 identity key.
fn load_or_generate_key(path: &Path) -> Result<SecretKey> {
if path.exists() {
let data = std::fs::read(path).context("reading node key")?;
let bytes: [u8; 32] = data
.try_into()
.map_err(|_| anyhow::anyhow!("invalid key file length"))?;
Ok(SecretKey::from_bytes(&bytes))
} else {
let key = SecretKey::generate(&mut rand::rng());
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).ok();
}
std::fs::write(path, key.to_bytes()).context("writing node key")?;
tracing::info!(path = %path.display(), "generated new node identity key");
Ok(key)
}
}
/// Parse and pre-connect to bootstrap peers.
/// K8s mode starts with no bootstrap peers — relies on incoming connections.
/// Bootstrap mode parses "endpointid@host:port" and initiates connections.
async fn resolve_bootstrap_peers(
cfg: &ClusterConfig,
endpoint: &Endpoint,
) -> Vec<iroh::PublicKey> {
match cfg.discovery.method.as_str() {
"k8s" => {
tracing::info!("k8s discovery mode: waiting for peers to connect");
vec![]
}
"bootstrap" => {
let mut peers = Vec::new();
for entry in cfg
.discovery
.bootstrap_peers
.as_deref()
.unwrap_or_default()
{
if let Some((id_str, addr_str)) = entry.split_once('@') {
match id_str.parse::<iroh::PublicKey>() {
Ok(id) => {
if let Ok(addr) = addr_str.parse::<SocketAddr>() {
let node_addr = iroh::EndpointAddr::from_parts(
id,
[iroh::TransportAddr::Ip(addr)],
);
// Pre-connect so the gossip layer can reach this peer.
match endpoint.connect(node_addr, ALPN).await {
Ok(conn) => {
tracing::info!(peer = %id, addr = %addr, "connected to bootstrap peer");
// Drop the connection — gossip will reuse the underlying QUIC path.
drop(conn);
}
Err(e) => {
tracing::warn!(peer = %id, addr = %addr, error = %e, "failed to connect to bootstrap peer");
}
}
}
peers.push(id);
}
Err(e) => {
tracing::warn!(entry, error = %e, "invalid bootstrap peer id");
}
}
} else {
tracing::warn!(entry, "invalid bootstrap peer format (expected id@host:port)");
}
}
peers
}
other => {
tracing::warn!(method = other, "unknown discovery method");
vec![]
}
}
}
/// Run the cluster node. Called from a dedicated OS thread with its own tokio runtime.
/// Sends the endpoint ID through `ready_tx` once initialization is complete,
/// then runs event loops until shutdown.
pub async fn run_cluster(
cfg: &ClusterConfig,
bandwidth: Arc<BandwidthTracker>,
cluster_bandwidth: Arc<ClusterBandwidthState>,
meter: Arc<BandwidthMeter>,
mut shutdown_rx: watch::Receiver<bool>,
ready_tx: tokio::sync::oneshot::Sender<Result<iroh::PublicKey>>,
) {
// Helper macro to send error through ready channel and return early.
macro_rules! try_init {
($expr:expr) => {
match $expr {
Ok(v) => v,
Err(e) => {
let _ = ready_tx.send(Err(e.into()));
return;
}
}
};
}
// 1. Load or generate identity.
let key_path = cfg
.key_path
.as_deref()
.unwrap_or("/var/lib/sunbeam/node.key");
let secret_key = try_init!(load_or_generate_key(Path::new(key_path)));
// 2. Create iroh endpoint.
let builder = try_init!(Endpoint::builder()
.secret_key(secret_key)
.relay_mode(RelayMode::Disabled)
.alpns(vec![ALPN.to_vec()])
.bind_addr(SocketAddr::from(([0, 0, 0, 0], cfg.gossip_port)))
.map_err(|e| anyhow::anyhow!("invalid bind address: {e}")));
let endpoint = try_init!(builder.bind().await.context("binding iroh endpoint"));
let my_id = endpoint.id();
let my_id_bytes: [u8; 32] = *my_id.as_bytes();
tracing::info!(endpoint_id = %my_id, port = cfg.gossip_port, "cluster node started");
// 3. Create gossip instance and router.
let gossip = Gossip::builder().spawn(endpoint.clone());
let router = Router::builder(endpoint.clone())
.accept(ALPN, gossip.clone())
.spawn();
// 4. Resolve bootstrap peers.
let peers = resolve_bootstrap_peers(cfg, &endpoint).await;
// 5. Derive topics.
let bandwidth_topic = derive_topic(&cfg.tenant, "bandwidth");
let models_topic = derive_topic(&cfg.tenant, "models");
let leader_topic = derive_topic(&cfg.tenant, "leader");
let license_topic = derive_topic(&cfg.tenant, "license");
tracing::info!(
tenant = %cfg.tenant,
bandwidth_topic = ?bandwidth_topic,
"subscribing to gossip topics"
);
// 6. Subscribe to topics.
let bw_gossip_topic = try_init!(gossip
.subscribe(bandwidth_topic, peers.clone())
.await
.context("subscribing to bandwidth topic"));
let (bw_sender, bw_receiver) = bw_gossip_topic.split();
let models_gossip_topic = try_init!(gossip
.subscribe(models_topic, peers.clone())
.await
.context("subscribing to models topic"));
let (_models_sender, models_receiver) = models_gossip_topic.split();
let leader_gossip_topic = try_init!(gossip
.subscribe(leader_topic, peers.clone())
.await
.context("subscribing to leader topic"));
let (_leader_sender, leader_receiver) = leader_gossip_topic.split();
let license_gossip_topic = try_init!(gossip
.subscribe(license_topic, peers)
.await
.context("subscribing to license topic"));
let (_license_sender, license_receiver) = license_gossip_topic.split();
// Initialization complete — signal the caller with our endpoint ID.
let _ = ready_tx.send(Ok(my_id));
let broadcast_interval = Duration::from_secs(
cfg.bandwidth
.as_ref()
.map(|b| b.broadcast_interval_secs)
.unwrap_or(5),
);
let stale_timeout = Duration::from_secs(
cfg.bandwidth
.as_ref()
.map(|b| b.stale_peer_timeout_secs)
.unwrap_or(30),
);
// 7. Bandwidth broadcast loop.
let bw_tracker = bandwidth.clone();
let bw_sender_clone = bw_sender.clone();
let meter_broadcast = meter.clone();
let my_id_bw = my_id_bytes;
let broadcast_task = tokio::spawn(async move {
let mut interval = tokio::time::interval(broadcast_interval);
loop {
interval.tick().await;
let snap = bw_tracker.snapshot_and_reset();
// Feed local node's delta into the sliding window meter.
meter_broadcast.record_sample(snap.bytes_in, snap.bytes_out);
let msg = ClusterMessage {
version: 1,
sender: my_id_bw,
payload: Payload::BandwidthReport {
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
bytes_in: snap.bytes_in,
bytes_out: snap.bytes_out,
request_count: snap.request_count,
cumulative_in: snap.cumulative_in,
cumulative_out: snap.cumulative_out,
},
};
match msg.encode() {
Ok(data) => {
if let Err(e) = bw_sender_clone.broadcast(data.into()).await {
tracing::debug!(error = %e, "bandwidth broadcast failed");
}
metrics::CLUSTER_GOSSIP_MESSAGES
.with_label_values(&["bandwidth"])
.inc();
}
Err(e) => tracing::warn!(error = %e, "failed to encode bandwidth report"),
}
}
});
// 8. Bandwidth receive loop.
let cluster_bw = cluster_bandwidth.clone();
let meter_recv = meter.clone();
let bw_recv_task = tokio::spawn(handle_bandwidth_events(bw_receiver, cluster_bw, meter_recv));
// 9. Model receive loop (stub).
let models_recv_task = tokio::spawn(handle_model_events(models_receiver));
// 10. Leader topic (stub).
let leader_recv_task = tokio::spawn(handle_stub_events(leader_receiver, "leader"));
// 11. License topic (stub).
let license_recv_task = tokio::spawn(handle_stub_events(license_receiver, "license"));
// 12. Stale peer eviction loop.
let cluster_bw_evict = cluster_bandwidth.clone();
let eviction_task = tokio::spawn(async move {
let mut interval = tokio::time::interval(stale_timeout);
loop {
interval.tick().await;
cluster_bw_evict.evict_stale();
}
});
// 13. Aggregate rate metrics updater (every broadcast interval).
let meter_metrics = meter;
let metrics_task = tokio::spawn(async move {
let mut interval = tokio::time::interval(broadcast_interval);
loop {
interval.tick().await;
let rate = meter_metrics.aggregate_rate();
metrics::CLUSTER_AGGREGATE_IN_RATE.set(rate.bytes_in_per_sec);
metrics::CLUSTER_AGGREGATE_OUT_RATE.set(rate.bytes_out_per_sec);
metrics::CLUSTER_AGGREGATE_TOTAL_RATE.set(rate.total_per_sec);
}
});
// Wait for shutdown or task failure.
tokio::select! {
_ = shutdown_rx.changed() => {
tracing::info!("cluster shutdown signal received");
}
r = broadcast_task => {
tracing::error!(result = ?r, "bandwidth broadcast task exited");
}
r = bw_recv_task => {
tracing::error!(result = ?r, "bandwidth receive task exited");
}
r = models_recv_task => {
tracing::error!(result = ?r, "model receive task exited");
}
r = leader_recv_task => {
tracing::error!(result = ?r, "leader receive task exited");
}
r = license_recv_task => {
tracing::error!(result = ?r, "license receive task exited");
}
r = eviction_task => {
tracing::error!(result = ?r, "eviction task exited");
}
r = metrics_task => {
tracing::error!(result = ?r, "metrics task exited");
}
}
if let Err(e) = router.shutdown().await {
tracing::error!(error = %e, "router shutdown failed");
}
}
async fn handle_bandwidth_events(
mut receiver: iroh_gossip::api::GossipReceiver,
cluster_bw: Arc<ClusterBandwidthState>,
meter: Arc<BandwidthMeter>,
) {
while let Some(Ok(event)) = receiver.next().await {
if let Event::Received(message) = event {
match ClusterMessage::decode(&message.content) {
Ok(ClusterMessage {
sender,
payload:
Payload::BandwidthReport {
cumulative_in,
cumulative_out,
bytes_in,
bytes_out,
request_count,
..
},
..
}) => {
cluster_bw.update_peer(sender, cumulative_in, cumulative_out);
// Feed remote peer's delta into the sliding window meter.
meter.record_sample(bytes_in, bytes_out);
metrics::CLUSTER_GOSSIP_MESSAGES
.with_label_values(&["bandwidth"])
.inc();
tracing::debug!(
sender = hex::encode(sender),
bytes_in,
bytes_out,
request_count,
"received bandwidth report"
);
}
Ok(_) => tracing::debug!("unexpected payload on bandwidth topic"),
Err(e) => tracing::debug!(error = %e, "failed to decode bandwidth message"),
}
}
}
}
async fn handle_model_events(mut receiver: iroh_gossip::api::GossipReceiver) {
while let Some(Ok(event)) = receiver.next().await {
if let Event::Received(message) = event {
match ClusterMessage::decode(&message.content) {
Ok(ClusterMessage {
payload: Payload::ModelAnnounce { model_type, hash, total_size, .. },
..
}) => {
tracing::info!(
model_type,
hash = hex::encode(hash),
total_size,
"received model announce (stub — ignoring)"
);
metrics::CLUSTER_MODEL_UPDATES
.with_label_values(&[&model_type, "ignored"])
.inc();
}
Ok(ClusterMessage {
payload: Payload::ModelChunk { hash, chunk_index, .. },
..
}) => {
tracing::debug!(
hash = hex::encode(hash),
chunk_index,
"received model chunk (stub — ignoring)"
);
}
Ok(_) => {}
Err(e) => tracing::debug!(error = %e, "failed to decode model message"),
}
}
}
}
async fn handle_stub_events(mut receiver: iroh_gossip::api::GossipReceiver, channel: &str) {
while let Some(Ok(event)) = receiver.next().await {
if let Event::Received(message) = event {
if let Ok(msg) = ClusterMessage::decode(&message.content) {
tracing::debug!(?msg, channel, "received stub message");
metrics::CLUSTER_GOSSIP_MESSAGES
.with_label_values(&[channel])
.inc();
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn topic_derivation_deterministic() {
let t1 = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
let t2 = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
assert_eq!(t1, t2);
}
#[test]
fn topic_derivation_different_channels() {
let bw = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
let models = derive_topic("550e8400-e29b-41d4-a716-446655440000", "models");
assert_ne!(bw, models);
}
#[test]
fn topic_derivation_different_tenants() {
let t1 = derive_topic("550e8400-e29b-41d4-a716-446655440000", "bandwidth");
let t2 = derive_topic("660e8400-e29b-41d4-a716-446655440001", "bandwidth");
assert_ne!(t1, t2);
}
}