feat(ensemble): wire ensemble into scanner and DDoS detectors
Add use_ensemble config flag (default true) to both DDoSConfig and ScannerConfig. When enabled, detectors call compiled-in ensemble weights instead of loading model files. Also adds ensemble decision metrics and makes model_path optional in config. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
@@ -60,7 +60,8 @@ fn default_config_configmap() -> String { "pingora-config".to_string() }
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct DDoSConfig {
|
pub struct DDoSConfig {
|
||||||
pub model_path: String,
|
#[serde(default)]
|
||||||
|
pub model_path: Option<String>,
|
||||||
#[serde(default = "default_k")]
|
#[serde(default = "default_k")]
|
||||||
pub k: usize,
|
pub k: usize,
|
||||||
#[serde(default = "default_threshold")]
|
#[serde(default = "default_threshold")]
|
||||||
@@ -73,6 +74,8 @@ pub struct DDoSConfig {
|
|||||||
pub min_events: usize,
|
pub min_events: usize,
|
||||||
#[serde(default = "default_enabled")]
|
#[serde(default = "default_enabled")]
|
||||||
pub enabled: bool,
|
pub enabled: bool,
|
||||||
|
#[serde(default = "default_use_ensemble")]
|
||||||
|
pub use_ensemble: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
@@ -97,7 +100,8 @@ pub struct BucketConfig {
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct ScannerConfig {
|
pub struct ScannerConfig {
|
||||||
pub model_path: String,
|
#[serde(default)]
|
||||||
|
pub model_path: Option<String>,
|
||||||
#[serde(default = "default_scanner_threshold")]
|
#[serde(default = "default_scanner_threshold")]
|
||||||
pub threshold: f64,
|
pub threshold: f64,
|
||||||
#[serde(default = "default_scanner_enabled")]
|
#[serde(default = "default_scanner_enabled")]
|
||||||
@@ -111,6 +115,8 @@ pub struct ScannerConfig {
|
|||||||
/// TTL (seconds) for verified bot IP cache entries.
|
/// TTL (seconds) for verified bot IP cache entries.
|
||||||
#[serde(default = "default_bot_cache_ttl")]
|
#[serde(default = "default_bot_cache_ttl")]
|
||||||
pub bot_cache_ttl_secs: u64,
|
pub bot_cache_ttl_secs: u64,
|
||||||
|
#[serde(default = "default_use_ensemble")]
|
||||||
|
pub use_ensemble: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
@@ -130,6 +136,7 @@ pub struct BotAllowlistRule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn default_bot_cache_ttl() -> u64 { 86400 } // 24h
|
fn default_bot_cache_ttl() -> u64 { 86400 } // 24h
|
||||||
|
fn default_use_ensemble() -> bool { true }
|
||||||
|
|
||||||
fn default_scanner_threshold() -> f64 { 0.5 }
|
fn default_scanner_threshold() -> f64 { 0.5 }
|
||||||
fn default_scanner_enabled() -> bool { true }
|
fn default_scanner_enabled() -> bool { true }
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ pub struct DDoSDetector {
|
|||||||
window_secs: u64,
|
window_secs: u64,
|
||||||
window_capacity: usize,
|
window_capacity: usize,
|
||||||
min_events: usize,
|
min_events: usize,
|
||||||
|
use_ensemble: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn shard_index(ip: &IpAddr) -> usize {
|
fn shard_index(ip: &IpAddr) -> usize {
|
||||||
@@ -34,6 +35,24 @@ impl DDoSDetector {
|
|||||||
window_secs: config.window_secs,
|
window_secs: config.window_secs,
|
||||||
window_capacity: config.window_capacity,
|
window_capacity: config.window_capacity,
|
||||||
min_events: config.min_events,
|
min_events: config.min_events,
|
||||||
|
use_ensemble: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a detector that uses the ensemble (decision tree + MLP) path.
|
||||||
|
/// A dummy model is still needed for fallback, but ensemble inference
|
||||||
|
/// takes priority when `use_ensemble` is true.
|
||||||
|
pub fn new_ensemble(model: TrainedModel, config: &DDoSConfig) -> Self {
|
||||||
|
let shards = (0..NUM_SHARDS)
|
||||||
|
.map(|_| RwLock::new(FxHashMap::default()))
|
||||||
|
.collect();
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
shards,
|
||||||
|
window_secs: config.window_secs,
|
||||||
|
window_capacity: config.window_capacity,
|
||||||
|
min_events: config.min_events,
|
||||||
|
use_ensemble: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,6 +98,24 @@ impl DDoSDetector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let features = state.extract_features(self.window_secs);
|
let features = state.extract_features(self.window_secs);
|
||||||
|
|
||||||
|
if self.use_ensemble {
|
||||||
|
// Cast f64 features to f32 array for ensemble inference.
|
||||||
|
let mut f32_features = [0.0f32; 14];
|
||||||
|
for (i, &v) in features.iter().enumerate().take(14) {
|
||||||
|
f32_features[i] = v as f32;
|
||||||
|
}
|
||||||
|
let ev = crate::ensemble::ddos::ddos_ensemble_predict(&f32_features);
|
||||||
|
crate::metrics::DDOS_ENSEMBLE_PATH
|
||||||
|
.with_label_values(&[match ev.path {
|
||||||
|
crate::ensemble::ddos::DDoSEnsemblePath::TreeBlock => "tree_block",
|
||||||
|
crate::ensemble::ddos::DDoSEnsemblePath::TreeAllow => "tree_allow",
|
||||||
|
crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp",
|
||||||
|
}])
|
||||||
|
.inc();
|
||||||
|
return ev.action;
|
||||||
|
}
|
||||||
|
|
||||||
self.model.classify(&features)
|
self.model.classify(&features)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,21 @@ impl TrainedModel {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an empty model (no training points). Used when the ensemble
|
||||||
|
/// path is active and the KNN model is not needed.
|
||||||
|
pub fn empty(k: usize, threshold: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
points: vec![],
|
||||||
|
labels: vec![],
|
||||||
|
norm_params: NormParams {
|
||||||
|
mins: [0.0; NUM_FEATURES],
|
||||||
|
maxs: [1.0; NUM_FEATURES],
|
||||||
|
},
|
||||||
|
k,
|
||||||
|
threshold,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn from_serialized(model: SerializedModel) -> Self {
|
pub fn from_serialized(model: SerializedModel) -> Self {
|
||||||
Self {
|
Self {
|
||||||
points: model.points,
|
points: model.points,
|
||||||
|
|||||||
@@ -2,14 +2,19 @@
|
|||||||
// integration tests in tests/ can construct and drive a SunbeamProxy
|
// integration tests in tests/ can construct and drive a SunbeamProxy
|
||||||
// without going through the binary entry point.
|
// without going through the binary entry point.
|
||||||
pub mod acme;
|
pub mod acme;
|
||||||
|
pub mod autotune;
|
||||||
pub mod cache;
|
pub mod cache;
|
||||||
pub mod cluster;
|
pub mod cluster;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod dataset;
|
||||||
pub mod ddos;
|
pub mod ddos;
|
||||||
pub mod dual_stack;
|
pub mod dual_stack;
|
||||||
|
pub mod ensemble;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
pub mod proxy;
|
pub mod proxy;
|
||||||
pub mod rate_limit;
|
pub mod rate_limit;
|
||||||
pub mod scanner;
|
pub mod scanner;
|
||||||
pub mod ssh;
|
pub mod ssh;
|
||||||
pub mod static_files;
|
pub mod static_files;
|
||||||
|
#[cfg(feature = "training")]
|
||||||
|
pub mod training;
|
||||||
|
|||||||
@@ -194,6 +194,24 @@ pub static CLUSTER_MODEL_UPDATES: LazyLock<IntCounterVec> = LazyLock::new(|| {
|
|||||||
c
|
c
|
||||||
});
|
});
|
||||||
|
|
||||||
|
pub static SCANNER_ENSEMBLE_PATH: LazyLock<IntCounterVec> = LazyLock::new(|| {
|
||||||
|
let c = IntCounterVec::new(
|
||||||
|
Opts::new("sunbeam_scanner_ensemble_path_total", "Scanner ensemble decision path"),
|
||||||
|
&["path"],
|
||||||
|
).unwrap();
|
||||||
|
REGISTRY.register(Box::new(c.clone())).unwrap();
|
||||||
|
c
|
||||||
|
});
|
||||||
|
|
||||||
|
pub static DDOS_ENSEMBLE_PATH: LazyLock<IntCounterVec> = LazyLock::new(|| {
|
||||||
|
let c = IntCounterVec::new(
|
||||||
|
Opts::new("sunbeam_ddos_ensemble_path_total", "DDoS ensemble decision path"),
|
||||||
|
&["path"],
|
||||||
|
).unwrap();
|
||||||
|
REGISTRY.register(Box::new(c.clone())).unwrap();
|
||||||
|
c
|
||||||
|
});
|
||||||
|
|
||||||
/// Spawn a lightweight HTTP server on `port` serving `/metrics` and `/health`.
|
/// Spawn a lightweight HTTP server on `port` serving `/metrics` and `/health`.
|
||||||
/// Returns immediately; the server runs in the background on the tokio runtime.
|
/// Returns immediately; the server runs in the background on the tokio runtime.
|
||||||
/// Port 0 = disabled.
|
/// Port 0 = disabled.
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ pub struct ScannerDetector {
|
|||||||
weights: [f64; NUM_SCANNER_WEIGHTS],
|
weights: [f64; NUM_SCANNER_WEIGHTS],
|
||||||
threshold: f64,
|
threshold: f64,
|
||||||
norm_params: ScannerNormParams,
|
norm_params: ScannerNormParams,
|
||||||
|
use_ensemble: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ScannerDetector {
|
impl ScannerDetector {
|
||||||
@@ -42,6 +43,39 @@ impl ScannerDetector {
|
|||||||
weights: model.weights,
|
weights: model.weights,
|
||||||
threshold: model.threshold,
|
threshold: model.threshold,
|
||||||
norm_params: model.norm_params.clone(),
|
norm_params: model.norm_params.clone(),
|
||||||
|
use_ensemble: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a detector that uses the ensemble (decision tree + MLP) path
|
||||||
|
/// instead of the linear model. No model file needed — weights are compiled in.
|
||||||
|
pub fn new_ensemble(routes: &[RouteConfig]) -> Self {
|
||||||
|
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
||||||
|
.iter()
|
||||||
|
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let extension_hashes: FxHashSet<u64> = SUSPICIOUS_EXTENSIONS_LIST
|
||||||
|
.iter()
|
||||||
|
.map(|e| fx_hash_bytes(e.as_bytes()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let configured_hosts: FxHashSet<u64> = routes
|
||||||
|
.iter()
|
||||||
|
.map(|r| fx_hash_bytes(r.host_prefix.as_bytes()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
fragment_hashes,
|
||||||
|
extension_hashes,
|
||||||
|
configured_hosts,
|
||||||
|
weights: [0.0; NUM_SCANNER_WEIGHTS],
|
||||||
|
threshold: 0.5,
|
||||||
|
norm_params: ScannerNormParams {
|
||||||
|
mins: [0.0; NUM_SCANNER_FEATURES],
|
||||||
|
maxs: [1.0; NUM_SCANNER_FEATURES],
|
||||||
|
},
|
||||||
|
use_ensemble: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,6 +121,25 @@ impl ScannerDetector {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.use_ensemble {
|
||||||
|
// Ensemble path: extract f32 features → decision tree + MLP.
|
||||||
|
let raw_f32 = features::extract_features_f32(
|
||||||
|
method, path, host_prefix,
|
||||||
|
has_cookies, has_referer, has_accept_language,
|
||||||
|
accept, user_agent, content_length,
|
||||||
|
&self.fragment_hashes, &self.extension_hashes, &self.configured_hosts,
|
||||||
|
);
|
||||||
|
let ev = crate::ensemble::scanner::scanner_ensemble_predict(&raw_f32);
|
||||||
|
crate::metrics::SCANNER_ENSEMBLE_PATH
|
||||||
|
.with_label_values(&[match ev.path {
|
||||||
|
crate::ensemble::scanner::EnsemblePath::TreeBlock => "tree_block",
|
||||||
|
crate::ensemble::scanner::EnsemblePath::TreeAllow => "tree_allow",
|
||||||
|
crate::ensemble::scanner::EnsemblePath::Mlp => "mlp",
|
||||||
|
}])
|
||||||
|
.inc();
|
||||||
|
return ev.into();
|
||||||
|
}
|
||||||
|
|
||||||
// 1. Extract 12 features
|
// 1. Extract 12 features
|
||||||
let raw = features::extract_features(
|
let raw = features::extract_features(
|
||||||
method,
|
method,
|
||||||
|
|||||||
@@ -167,6 +167,30 @@ fn path_has_traversal(path: &str) -> f64 {
|
|||||||
0.0
|
0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const NUM_SCANNER_FEATURES_F32: usize = NUM_SCANNER_FEATURES;
|
||||||
|
pub type ScannerFeatureVectorF32 = [f32; NUM_SCANNER_FEATURES];
|
||||||
|
|
||||||
|
/// Same as `extract_features` but returns f32 for ensemble inference.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn extract_features_f32(
|
||||||
|
method: &str, path: &str, host_prefix: &str,
|
||||||
|
has_cookies: bool, has_referer: bool, has_accept_language: bool,
|
||||||
|
accept: &str, user_agent: &str, content_length: u64,
|
||||||
|
fragment_hashes: &FxHashSet<u64>,
|
||||||
|
extension_hashes: &FxHashSet<u64>,
|
||||||
|
configured_hosts: &FxHashSet<u64>,
|
||||||
|
) -> ScannerFeatureVectorF32 {
|
||||||
|
let f64_features = extract_features(
|
||||||
|
method, path, host_prefix, has_cookies, has_referer, has_accept_language,
|
||||||
|
accept, user_agent, content_length, fragment_hashes, extension_hashes, configured_hosts,
|
||||||
|
);
|
||||||
|
let mut out = [0.0f32; NUM_SCANNER_FEATURES];
|
||||||
|
for i in 0..NUM_SCANNER_FEATURES {
|
||||||
|
out[i] = f64_features[i] as f32;
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
pub fn fx_hash_bytes(bytes: &[u8]) -> u64 {
|
pub fn fx_hash_bytes(bytes: &[u8]) -> u64 {
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
let mut h = rustc_hash::FxHasher::default();
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
|
|||||||
Reference in New Issue
Block a user