diff --git a/src/config.rs b/src/config.rs index ee3f374..8e139c6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -60,7 +60,8 @@ fn default_config_configmap() -> String { "pingora-config".to_string() } #[derive(Debug, Deserialize, Clone)] pub struct DDoSConfig { - pub model_path: String, + #[serde(default)] + pub model_path: Option, #[serde(default = "default_k")] pub k: usize, #[serde(default = "default_threshold")] @@ -73,6 +74,8 @@ pub struct DDoSConfig { pub min_events: usize, #[serde(default = "default_enabled")] pub enabled: bool, + #[serde(default = "default_use_ensemble")] + pub use_ensemble: bool, } #[derive(Debug, Deserialize, Clone)] @@ -97,7 +100,8 @@ pub struct BucketConfig { #[derive(Debug, Deserialize, Clone)] pub struct ScannerConfig { - pub model_path: String, + #[serde(default)] + pub model_path: Option, #[serde(default = "default_scanner_threshold")] pub threshold: f64, #[serde(default = "default_scanner_enabled")] @@ -111,6 +115,8 @@ pub struct ScannerConfig { /// TTL (seconds) for verified bot IP cache entries. #[serde(default = "default_bot_cache_ttl")] pub bot_cache_ttl_secs: u64, + #[serde(default = "default_use_ensemble")] + pub use_ensemble: bool, } #[derive(Debug, Deserialize, Clone)] @@ -130,6 +136,7 @@ pub struct BotAllowlistRule { } fn default_bot_cache_ttl() -> u64 { 86400 } // 24h +fn default_use_ensemble() -> bool { true } fn default_scanner_threshold() -> f64 { 0.5 } fn default_scanner_enabled() -> bool { true } diff --git a/src/ddos/detector.rs b/src/ddos/detector.rs index 711893e..ffc5183 100644 --- a/src/ddos/detector.rs +++ b/src/ddos/detector.rs @@ -15,6 +15,7 @@ pub struct DDoSDetector { window_secs: u64, window_capacity: usize, min_events: usize, + use_ensemble: bool, } fn shard_index(ip: &IpAddr) -> usize { @@ -34,6 +35,24 @@ impl DDoSDetector { window_secs: config.window_secs, window_capacity: config.window_capacity, min_events: config.min_events, + use_ensemble: false, + } + } + + /// Create a detector that uses the ensemble (decision tree + MLP) path. + /// A dummy model is still needed for fallback, but ensemble inference + /// takes priority when `use_ensemble` is true. + pub fn new_ensemble(model: TrainedModel, config: &DDoSConfig) -> Self { + let shards = (0..NUM_SHARDS) + .map(|_| RwLock::new(FxHashMap::default())) + .collect(); + Self { + model, + shards, + window_secs: config.window_secs, + window_capacity: config.window_capacity, + min_events: config.min_events, + use_ensemble: true, } } @@ -79,6 +98,24 @@ impl DDoSDetector { } let features = state.extract_features(self.window_secs); + + if self.use_ensemble { + // Cast f64 features to f32 array for ensemble inference. + let mut f32_features = [0.0f32; 14]; + for (i, &v) in features.iter().enumerate().take(14) { + f32_features[i] = v as f32; + } + let ev = crate::ensemble::ddos::ddos_ensemble_predict(&f32_features); + crate::metrics::DDOS_ENSEMBLE_PATH + .with_label_values(&[match ev.path { + crate::ensemble::ddos::DDoSEnsemblePath::TreeBlock => "tree_block", + crate::ensemble::ddos::DDoSEnsemblePath::TreeAllow => "tree_allow", + crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp", + }]) + .inc(); + return ev.action; + } + self.model.classify(&features) } diff --git a/src/ddos/model.rs b/src/ddos/model.rs index a5c5496..e655915 100644 --- a/src/ddos/model.rs +++ b/src/ddos/model.rs @@ -48,6 +48,21 @@ impl TrainedModel { }) } + /// Create an empty model (no training points). Used when the ensemble + /// path is active and the KNN model is not needed. + pub fn empty(k: usize, threshold: f64) -> Self { + Self { + points: vec![], + labels: vec![], + norm_params: NormParams { + mins: [0.0; NUM_FEATURES], + maxs: [1.0; NUM_FEATURES], + }, + k, + threshold, + } + } + pub fn from_serialized(model: SerializedModel) -> Self { Self { points: model.points, diff --git a/src/lib.rs b/src/lib.rs index daa5a61..a799ce9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,14 +2,19 @@ // integration tests in tests/ can construct and drive a SunbeamProxy // without going through the binary entry point. pub mod acme; +pub mod autotune; pub mod cache; pub mod cluster; pub mod config; +pub mod dataset; pub mod ddos; pub mod dual_stack; +pub mod ensemble; pub mod metrics; pub mod proxy; pub mod rate_limit; pub mod scanner; pub mod ssh; pub mod static_files; +#[cfg(feature = "training")] +pub mod training; diff --git a/src/metrics.rs b/src/metrics.rs index 5c12004..d129807 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -194,6 +194,24 @@ pub static CLUSTER_MODEL_UPDATES: LazyLock = LazyLock::new(|| { c }); +pub static SCANNER_ENSEMBLE_PATH: LazyLock = LazyLock::new(|| { + let c = IntCounterVec::new( + Opts::new("sunbeam_scanner_ensemble_path_total", "Scanner ensemble decision path"), + &["path"], + ).unwrap(); + REGISTRY.register(Box::new(c.clone())).unwrap(); + c +}); + +pub static DDOS_ENSEMBLE_PATH: LazyLock = LazyLock::new(|| { + let c = IntCounterVec::new( + Opts::new("sunbeam_ddos_ensemble_path_total", "DDoS ensemble decision path"), + &["path"], + ).unwrap(); + REGISTRY.register(Box::new(c.clone())).unwrap(); + c +}); + /// Spawn a lightweight HTTP server on `port` serving `/metrics` and `/health`. /// Returns immediately; the server runs in the background on the tokio runtime. /// Port 0 = disabled. diff --git a/src/scanner/detector.rs b/src/scanner/detector.rs index 63251f2..aeb3baa 100644 --- a/src/scanner/detector.rs +++ b/src/scanner/detector.rs @@ -15,6 +15,7 @@ pub struct ScannerDetector { weights: [f64; NUM_SCANNER_WEIGHTS], threshold: f64, norm_params: ScannerNormParams, + use_ensemble: bool, } impl ScannerDetector { @@ -42,6 +43,39 @@ impl ScannerDetector { weights: model.weights, threshold: model.threshold, norm_params: model.norm_params.clone(), + use_ensemble: false, + } + } + + /// Create a detector that uses the ensemble (decision tree + MLP) path + /// instead of the linear model. No model file needed — weights are compiled in. + pub fn new_ensemble(routes: &[RouteConfig]) -> Self { + let fragment_hashes: FxHashSet = crate::scanner::train::DEFAULT_FRAGMENTS + .iter() + .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) + .collect(); + + let extension_hashes: FxHashSet = SUSPICIOUS_EXTENSIONS_LIST + .iter() + .map(|e| fx_hash_bytes(e.as_bytes())) + .collect(); + + let configured_hosts: FxHashSet = routes + .iter() + .map(|r| fx_hash_bytes(r.host_prefix.as_bytes())) + .collect(); + + Self { + fragment_hashes, + extension_hashes, + configured_hosts, + weights: [0.0; NUM_SCANNER_WEIGHTS], + threshold: 0.5, + norm_params: ScannerNormParams { + mins: [0.0; NUM_SCANNER_FEATURES], + maxs: [1.0; NUM_SCANNER_FEATURES], + }, + use_ensemble: true, } } @@ -87,6 +121,25 @@ impl ScannerDetector { }; } + if self.use_ensemble { + // Ensemble path: extract f32 features → decision tree + MLP. + let raw_f32 = features::extract_features_f32( + method, path, host_prefix, + has_cookies, has_referer, has_accept_language, + accept, user_agent, content_length, + &self.fragment_hashes, &self.extension_hashes, &self.configured_hosts, + ); + let ev = crate::ensemble::scanner::scanner_ensemble_predict(&raw_f32); + crate::metrics::SCANNER_ENSEMBLE_PATH + .with_label_values(&[match ev.path { + crate::ensemble::scanner::EnsemblePath::TreeBlock => "tree_block", + crate::ensemble::scanner::EnsemblePath::TreeAllow => "tree_allow", + crate::ensemble::scanner::EnsemblePath::Mlp => "mlp", + }]) + .inc(); + return ev.into(); + } + // 1. Extract 12 features let raw = features::extract_features( method, diff --git a/src/scanner/features.rs b/src/scanner/features.rs index 3a7b528..9f2c9d0 100644 --- a/src/scanner/features.rs +++ b/src/scanner/features.rs @@ -167,6 +167,30 @@ fn path_has_traversal(path: &str) -> f64 { 0.0 } +pub const NUM_SCANNER_FEATURES_F32: usize = NUM_SCANNER_FEATURES; +pub type ScannerFeatureVectorF32 = [f32; NUM_SCANNER_FEATURES]; + +/// Same as `extract_features` but returns f32 for ensemble inference. +#[allow(clippy::too_many_arguments)] +pub fn extract_features_f32( + method: &str, path: &str, host_prefix: &str, + has_cookies: bool, has_referer: bool, has_accept_language: bool, + accept: &str, user_agent: &str, content_length: u64, + fragment_hashes: &FxHashSet, + extension_hashes: &FxHashSet, + configured_hosts: &FxHashSet, +) -> ScannerFeatureVectorF32 { + let f64_features = extract_features( + method, path, host_prefix, has_cookies, has_referer, has_accept_language, + accept, user_agent, content_length, fragment_hashes, extension_hashes, configured_hosts, + ); + let mut out = [0.0f32; NUM_SCANNER_FEATURES]; + for i in 0..NUM_SCANNER_FEATURES { + out[i] = f64_features[i] as f32; + } + out +} + pub fn fx_hash_bytes(bytes: &[u8]) -> u64 { use std::hash::{Hash, Hasher}; let mut h = rustc_hash::FxHasher::default();