feat(ensemble): wire ensemble into scanner and DDoS detectors

Add use_ensemble config flag (default true) to both DDoSConfig and
ScannerConfig. When enabled, detectors call compiled-in ensemble weights
instead of loading model files. Also adds ensemble decision metrics and
makes model_path optional in config.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:21 +00:00
parent 597362faa2
commit a9f1fd83bd
7 changed files with 161 additions and 2 deletions

View File

@@ -60,7 +60,8 @@ fn default_config_configmap() -> String { "pingora-config".to_string() }
#[derive(Debug, Deserialize, Clone)]
pub struct DDoSConfig {
pub model_path: String,
#[serde(default)]
pub model_path: Option<String>,
#[serde(default = "default_k")]
pub k: usize,
#[serde(default = "default_threshold")]
@@ -73,6 +74,8 @@ pub struct DDoSConfig {
pub min_events: usize,
#[serde(default = "default_enabled")]
pub enabled: bool,
#[serde(default = "default_use_ensemble")]
pub use_ensemble: bool,
}
#[derive(Debug, Deserialize, Clone)]
@@ -97,7 +100,8 @@ pub struct BucketConfig {
#[derive(Debug, Deserialize, Clone)]
pub struct ScannerConfig {
pub model_path: String,
#[serde(default)]
pub model_path: Option<String>,
#[serde(default = "default_scanner_threshold")]
pub threshold: f64,
#[serde(default = "default_scanner_enabled")]
@@ -111,6 +115,8 @@ pub struct ScannerConfig {
/// TTL (seconds) for verified bot IP cache entries.
#[serde(default = "default_bot_cache_ttl")]
pub bot_cache_ttl_secs: u64,
#[serde(default = "default_use_ensemble")]
pub use_ensemble: bool,
}
#[derive(Debug, Deserialize, Clone)]
@@ -130,6 +136,7 @@ pub struct BotAllowlistRule {
}
/// Serde default for `bot_cache_ttl_secs`: 24 hours, expressed in seconds.
fn default_bot_cache_ttl() -> u64 {
    24 * 60 * 60
}

/// Serde default for the `use_ensemble` flags: the ensemble path is on
/// unless the config turns it off.
fn default_use_ensemble() -> bool {
    true
}

/// Serde default for the scanner `threshold`.
fn default_scanner_threshold() -> f64 {
    0.5
}

/// Serde default referenced by `ScannerConfig` as `default_scanner_enabled`.
fn default_scanner_enabled() -> bool {
    true
}

View File

@@ -15,6 +15,7 @@ pub struct DDoSDetector {
window_secs: u64,
window_capacity: usize,
min_events: usize,
use_ensemble: bool,
}
fn shard_index(ip: &IpAddr) -> usize {
@@ -34,6 +35,24 @@ impl DDoSDetector {
window_secs: config.window_secs,
window_capacity: config.window_capacity,
min_events: config.min_events,
use_ensemble: false,
}
}
/// Build a detector whose classification goes through the ensemble
/// (decision tree + MLP) path.
///
/// The `model` is still stored in the detector, but with `use_ensemble`
/// set to `true` the ensemble branch returns its verdict before the
/// model-based classification is reached.
pub fn new_ensemble(model: TrainedModel, config: &DDoSConfig) -> Self {
    // One independently locked, initially empty map per shard.
    let empty_shard = || RwLock::new(FxHashMap::default());
    let shards = std::iter::repeat_with(empty_shard)
        .take(NUM_SHARDS)
        .collect();

    let window_secs = config.window_secs;
    let window_capacity = config.window_capacity;
    let min_events = config.min_events;

    Self {
        model,
        shards,
        window_secs,
        window_capacity,
        min_events,
        use_ensemble: true,
    }
}
@@ -79,6 +98,24 @@ impl DDoSDetector {
}
let features = state.extract_features(self.window_secs);
if self.use_ensemble {
// Cast f64 features to f32 array for ensemble inference.
let mut f32_features = [0.0f32; 14];
for (i, &v) in features.iter().enumerate().take(14) {
f32_features[i] = v as f32;
}
let ev = crate::ensemble::ddos::ddos_ensemble_predict(&f32_features);
crate::metrics::DDOS_ENSEMBLE_PATH
.with_label_values(&[match ev.path {
crate::ensemble::ddos::DDoSEnsemblePath::TreeBlock => "tree_block",
crate::ensemble::ddos::DDoSEnsemblePath::TreeAllow => "tree_allow",
crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp",
}])
.inc();
return ev.action;
}
self.model.classify(&features)
}

View File

@@ -48,6 +48,21 @@ impl TrainedModel {
})
}
/// Build a model with no training points and no labels.
///
/// Intended for when the ensemble path is active and the KNN model is not
/// needed. Normalisation bounds are set to `0.0` (min) and `1.0` (max) for
/// every feature; `k` and `threshold` are stored as given.
pub fn empty(k: usize, threshold: f64) -> Self {
    let norm_params = NormParams {
        mins: [0.0; NUM_FEATURES],
        maxs: [1.0; NUM_FEATURES],
    };

    Self {
        points: Vec::new(),
        labels: Vec::new(),
        norm_params,
        k,
        threshold,
    }
}
pub fn from_serialized(model: SerializedModel) -> Self {
Self {
points: model.points,

View File

@@ -2,14 +2,19 @@
// integration tests in tests/ can construct and drive a SunbeamProxy
// without going through the binary entry point.
pub mod acme;
pub mod autotune;
pub mod cache;
pub mod cluster;
pub mod config;
pub mod dataset;
pub mod ddos;
pub mod dual_stack;
pub mod ensemble;
pub mod metrics;
pub mod proxy;
pub mod rate_limit;
pub mod scanner;
pub mod ssh;
pub mod static_files;
#[cfg(feature = "training")]
pub mod training;

View File

@@ -194,6 +194,24 @@ pub static CLUSTER_MODEL_UPDATES: LazyLock<IntCounterVec> = LazyLock::new(|| {
c
});
/// Counter of scanner ensemble decisions, partitioned by the `path` label.
///
/// Created and registered with `REGISTRY` on first access; panics at that
/// point if construction or registration returns an error.
pub static SCANNER_ENSEMBLE_PATH: LazyLock<IntCounterVec> = LazyLock::new(|| {
    let opts = Opts::new(
        "sunbeam_scanner_ensemble_path_total",
        "Scanner ensemble decision path",
    );
    let counter = IntCounterVec::new(opts, &["path"]).unwrap();
    // The registry gets its own boxed clone; the static keeps `counter`.
    REGISTRY.register(Box::new(counter.clone())).unwrap();
    counter
});
/// Counter of DDoS ensemble decisions, partitioned by the `path` label
/// (the detector records `tree_block`, `tree_allow` or `mlp`).
///
/// Created and registered with `REGISTRY` on first access; panics at that
/// point if construction or registration returns an error.
pub static DDOS_ENSEMBLE_PATH: LazyLock<IntCounterVec> = LazyLock::new(|| {
    let opts = Opts::new(
        "sunbeam_ddos_ensemble_path_total",
        "DDoS ensemble decision path",
    );
    let counter = IntCounterVec::new(opts, &["path"]).unwrap();
    // The registry gets its own boxed clone; the static keeps `counter`.
    REGISTRY.register(Box::new(counter.clone())).unwrap();
    counter
});
/// Spawn a lightweight HTTP server on `port` serving `/metrics` and `/health`.
/// Returns immediately; the server runs in the background on the tokio runtime.
/// Port 0 = disabled.

View File

@@ -15,6 +15,7 @@ pub struct ScannerDetector {
weights: [f64; NUM_SCANNER_WEIGHTS],
threshold: f64,
norm_params: ScannerNormParams,
use_ensemble: bool,
}
impl ScannerDetector {
@@ -42,6 +43,39 @@ impl ScannerDetector {
weights: model.weights,
threshold: model.threshold,
norm_params: model.norm_params.clone(),
use_ensemble: false,
}
}
/// Create a detector that uses the ensemble (decision tree + MLP) path
/// instead of the linear model. No model file needed — weights are compiled in.
pub fn new_ensemble(routes: &[RouteConfig]) -> Self {
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
.iter()
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
.collect();
let extension_hashes: FxHashSet<u64> = SUSPICIOUS_EXTENSIONS_LIST
.iter()
.map(|e| fx_hash_bytes(e.as_bytes()))
.collect();
let configured_hosts: FxHashSet<u64> = routes
.iter()
.map(|r| fx_hash_bytes(r.host_prefix.as_bytes()))
.collect();
Self {
fragment_hashes,
extension_hashes,
configured_hosts,
weights: [0.0; NUM_SCANNER_WEIGHTS],
threshold: 0.5,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
use_ensemble: true,
}
}
@@ -87,6 +121,25 @@ impl ScannerDetector {
};
}
if self.use_ensemble {
// Ensemble path: extract f32 features → decision tree + MLP.
let raw_f32 = features::extract_features_f32(
method, path, host_prefix,
has_cookies, has_referer, has_accept_language,
accept, user_agent, content_length,
&self.fragment_hashes, &self.extension_hashes, &self.configured_hosts,
);
let ev = crate::ensemble::scanner::scanner_ensemble_predict(&raw_f32);
crate::metrics::SCANNER_ENSEMBLE_PATH
.with_label_values(&[match ev.path {
crate::ensemble::scanner::EnsemblePath::TreeBlock => "tree_block",
crate::ensemble::scanner::EnsemblePath::TreeAllow => "tree_allow",
crate::ensemble::scanner::EnsemblePath::Mlp => "mlp",
}])
.inc();
return ev.into();
}
// 1. Extract 12 features
let raw = features::extract_features(
method,

View File

@@ -167,6 +167,30 @@ fn path_has_traversal(path: &str) -> f64 {
0.0
}
/// Length of the `f32` scanner feature vector; identical to `NUM_SCANNER_FEATURES`.
pub const NUM_SCANNER_FEATURES_F32: usize = NUM_SCANNER_FEATURES;
/// Fixed-size `f32` feature array, the return type of `extract_features_f32`.
pub type ScannerFeatureVectorF32 = [f32; NUM_SCANNER_FEATURES];
/// `f32` variant of `extract_features`, used for ensemble inference.
///
/// Forwards every argument unchanged to `extract_features` and narrows each
/// of the resulting `NUM_SCANNER_FEATURES` values from `f64` to `f32`.
#[allow(clippy::too_many_arguments)]
pub fn extract_features_f32(
    method: &str,
    path: &str,
    host_prefix: &str,
    has_cookies: bool,
    has_referer: bool,
    has_accept_language: bool,
    accept: &str,
    user_agent: &str,
    content_length: u64,
    fragment_hashes: &FxHashSet<u64>,
    extension_hashes: &FxHashSet<u64>,
    configured_hosts: &FxHashSet<u64>,
) -> ScannerFeatureVectorF32 {
    let wide = extract_features(
        method,
        path,
        host_prefix,
        has_cookies,
        has_referer,
        has_accept_language,
        accept,
        user_agent,
        content_length,
        fragment_hashes,
        extension_hashes,
        configured_hosts,
    );
    // Element-wise narrowing; the array length comes from the return type.
    std::array::from_fn(|idx| wide[idx] as f32)
}
pub fn fx_hash_bytes(bytes: &[u8]) -> u64 {
use std::hash::{Hash, Hasher};
let mut h = rustc_hash::FxHasher::default();