feat(ensemble): add decision tree + MLP inference engine

Compile-time ensemble for both scanner (12 features) and DDoS (14 features)
detection. Decision tree runs first; uncertain leaves defer to a single-hidden-
layer MLP (32 units, ReLU, sigmoid output). Weights are embedded as const
arrays for zero-overhead inference with no file I/O at runtime.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:21 +00:00
parent 9db2b1655f
commit 597362faa2
8 changed files with 755 additions and 0 deletions

147
src/ensemble/ddos.rs Normal file
View File

@@ -0,0 +1,147 @@
use crate::ddos::model::DDoSAction;
use super::gen::ddos_weights;
use super::mlp::mlp_predict_32;
use super::tree::{tree_predict, TreeDecision};
/// Which path the DDoS ensemble took to reach its verdict.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DDoSEnsemblePath {
TreeBlock,
TreeAllow,
Mlp,
}
/// Result of the DDoS ensemble inference.
pub struct DDoSEnsembleVerdict {
pub action: DDoSAction,
pub score: f64,
pub reason: &'static str,
pub path: DDoSEnsemblePath,
}
/// Normalize raw features using trained min/max constants.
#[inline]
fn normalize(raw: &[f32; 14]) -> [f32; 14] {
let mut out = [0.0f32; 14];
for i in 0..14 {
let range = ddos_weights::NORM_MAXS[i] - ddos_weights::NORM_MINS[i];
out[i] = if range > 0.0 {
((raw[i] - ddos_weights::NORM_MINS[i]) / range).clamp(0.0, 1.0)
} else {
0.0
};
}
out
}
/// Full DDoS ensemble inference: decision tree first, MLP only on `Defer`.
pub fn ddos_ensemble_predict(raw: &[f32; 14]) -> DDoSEnsembleVerdict {
let input = normalize(raw);
let tree_result = tree_predict(&ddos_weights::TREE_NODES, &input);
match tree_result {
TreeDecision::Block => DDoSEnsembleVerdict {
action: DDoSAction::Block,
score: 1.0,
reason: "ensemble:tree_block",
path: DDoSEnsemblePath::TreeBlock,
},
TreeDecision::Allow => DDoSEnsembleVerdict {
action: DDoSAction::Allow,
score: 0.0,
reason: "ensemble:tree_allow",
path: DDoSEnsemblePath::TreeAllow,
},
TreeDecision::Defer => {
let mlp_score = mlp_predict_32::<14>(
&ddos_weights::W1,
&ddos_weights::B1,
&ddos_weights::W2,
ddos_weights::B2,
&input,
);
let action = if mlp_score > ddos_weights::THRESHOLD {
DDoSAction::Block
} else {
DDoSAction::Allow
};
DDoSEnsembleVerdict {
action,
score: mlp_score as f64,
reason: "ensemble:mlp",
path: DDoSEnsemblePath::Mlp,
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tree_allow_path() {
// All zeros → feature 4 (request_rate) = 0.0 <= 0.70 → left (node 1)
// feature 10 (cookie_ratio) = 0.0 <= 0.30 → left (node 3) → Allow
let raw = [0.0f32; 14];
let v = ddos_ensemble_predict(&raw);
assert_eq!(v.action, DDoSAction::Allow);
assert_eq!(v.path, DDoSEnsemblePath::TreeAllow);
assert_eq!(v.reason, "ensemble:tree_allow");
}
#[test]
fn test_tree_block_path() {
// Need: feature 4 (request_rate) > 0.70 normalized → right (node 2)
// feature 12 (accept_language_ratio) > 0.25 normalized → right (node 6) → Block
// feature 4 max = 500, so raw 400 → normalized 0.8 > 0.70 ✓
// feature 12 max = 1.0, so raw 0.5 → normalized 0.5 > 0.25 ✓
let mut raw = [0.0f32; 14];
raw[4] = 400.0;
raw[12] = 0.5;
let v = ddos_ensemble_predict(&raw);
assert_eq!(v.action, DDoSAction::Block);
assert_eq!(v.path, DDoSEnsemblePath::TreeBlock);
}
#[test]
fn test_mlp_path() {
// Need: feature 4 > 0.70 normalized → right (node 2)
// feature 12 <= 0.25 normalized → left (node 5) → Defer
// feature 4 max = 500, raw 400 → 0.8 > 0.70 ✓
// feature 12 max = 1.0, raw 0.1 → 0.1 <= 0.25 ✓
let mut raw = [0.0f32; 14];
raw[4] = 400.0;
raw[12] = 0.1;
let v = ddos_ensemble_predict(&raw);
assert_eq!(v.path, DDoSEnsemblePath::Mlp);
assert_eq!(v.reason, "ensemble:mlp");
assert!(v.score >= 0.0 && v.score <= 1.0);
}
#[test]
fn test_defer_then_mlp_allow() {
// Same Defer path as above — verify the MLP produces a valid action
let mut raw = [0.0f32; 14];
raw[4] = 400.0;
raw[12] = 0.1;
let v = ddos_ensemble_predict(&raw);
assert!(matches!(v.action, DDoSAction::Allow | DDoSAction::Block));
}
#[test]
fn test_normalize_clamps_high() {
let mut raw = [0.0f32; 14];
raw[0] = 999.0; // max is 100
let normed = normalize(&raw);
assert!((normed[0] - 1.0).abs() < f32::EPSILON);
}
#[test]
fn test_normalize_clamps_low() {
let mut raw = [0.0f32; 14];
raw[1] = -500.0; // min is 0
let normed = normalize(&raw);
assert!((normed[1] - 0.0).abs() < f32::EPSILON);
}
}

View File

@@ -0,0 +1,71 @@
//! Auto-generated weights for the ddos ensemble.
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-ddos-mlp`.
pub const THRESHOLD: f32 = 0.50000000;
pub const NORM_MINS: [f32; 14] = [
0.08778746, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.05001374, 0.02000000,
0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000
];
pub const NORM_MAXS: [f32; 14] = [
1000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49990082, 500.00000000, 1.00000000,
171240.28125000, 30.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000
];
pub const W1: [[f32; 14]; 32] = [
[0.57458097, -0.10861993, 0.14037465, -0.23486336, -0.43255216, 0.16347405, 0.71766937, 0.83138502, 0.02852129, 0.56590265, 0.54848498, 0.38098580, 0.82907754, 0.61539698],
[0.06583579, 0.02305713, 0.89706898, 0.42619053, -1.20866120, -0.11974730, 1.70674825, -0.17969023, -0.26867196, 0.60768014, -0.08671998, -0.04825107, 0.58131427, -0.02062579],
[-0.40264100, 0.18836430, 0.08431315, 0.33763552, 0.44880620, -0.40894085, -0.22044741, 0.00533387, -0.61574107, 0.07670992, 0.63528854, 0.48244709, 0.20411402, -1.80697525],
[0.66713083, -0.22220801, 2.11234117, 0.41516641, -0.00165093, 0.65624571, 1.87509167, 0.63406783, -2.54182458, -0.53618753, 2.16407824, -0.61959583, -0.04717547, 0.17551991],
[-0.51027024, -0.60132700, 0.46407551, -0.57346475, -0.30902353, -0.24235034, 0.08087540, -2.14762974, -0.29429656, 0.56257033, -0.26935315, -0.16799171, 0.56852734, 1.93494022],
[-0.24938971, -0.12699288, -0.13746630, 0.64942318, 0.09490766, -0.02158179, 0.72449303, -0.28493983, -0.43053114, -0.01443988, 0.89670080, -0.34539866, -1.47019410, 0.79477930],
[0.62935185, 0.74686801, -0.15527052, -0.06635039, 0.73137009, 0.78417069, -0.06417987, 0.72259408, 0.85131824, 0.00477386, -0.14302900, 0.63481224, 0.92724019, -0.50126070],
[-0.12699059, -0.15016419, -0.48704135, 0.00581611, 0.75824696, 0.84114397, -0.08958503, 0.18609463, 0.56247348, 0.22239330, 0.43324804, 0.82077771, 0.55714250, -0.56955606],
[0.83457869, 0.40054807, 0.23281574, -0.58521581, 1.18067443, -0.49485078, 0.08600014, 0.99104887, -0.65019566, -0.44594154, 0.64507920, -0.61692268, -0.29301512, -0.11314666],
[-0.07868081, -0.18392175, 0.15165123, 0.35139060, 0.13855398, 0.16470867, 0.21025884, 1.57204449, 0.07827333, 0.05895505, 0.00810917, 1.05159700, 0.04605416, 0.38080546],
[1.47428405, -0.21614535, -0.35385504, 0.46582970, 1.26638246, 0.00133375, -3.85603786, 0.39766011, 1.92816520, 0.47828305, -0.16951409, -0.13771342, 0.49451983, 0.41184473],
[-0.23364748, 0.68134952, 0.36342716, 0.02657196, 0.07550839, 0.94861823, -0.52908695, 0.83652318, -0.05639480, 0.26536962, 0.44137934, 1.20957208, -0.60747981, -0.50647283],
[-0.16961956, -0.49570882, -0.33771378, -0.28554109, 0.95865113, -0.49269623, -0.44559151, 1.28568971, 0.79537493, -0.53175420, -3.19015551, 0.52214253, 0.86517984, 0.62523192],
[-0.16956513, -0.61727583, 0.63967121, 0.96406335, -0.28760204, 0.56459671, 0.78585202, -0.03668134, -0.14773002, -0.35764447, 0.84649116, -0.34540027, -0.12314465, -0.10070048],
[-0.34183556, -0.07760386, 0.70894319, 0.92814171, -0.19357866, 0.41449037, 0.54653358, 0.27682835, 0.81471086, 0.56383932, 0.57456553, -0.61491662, 0.92498505, 0.74495614],
[-0.38917324, -0.29217750, 1.43508542, -0.19152534, -0.18823336, 0.45097819, -0.38063127, -0.40419811, 0.56686693, -0.33231607, -0.19567636, -0.02500075, -0.04762971, 0.44703853],
[1.14234805, -0.62868208, -0.21298689, 0.00263968, -0.66115338, -1.12038326, 0.93599045, 0.77646011, -0.22770278, 1.43982041, 0.96078646, 1.15076077, -0.45110813, 0.83090556],
[0.89638984, -0.69683450, -0.29400119, 0.94997799, 0.90305328, -0.80215877, -0.09983492, -0.90757453, -0.03181892, 1.00702441, -0.97962254, -0.89580274, 0.69299418, -0.75975400],
[-0.75832003, -0.07210776, 0.07825917, 1.51633596, 0.44593197, 0.00936707, -0.12142835, -0.09877282, 0.06229200, 1.25678349, 0.25317946, 0.54112315, -0.17941843, 0.93283361],
[0.23085761, 0.53307736, 0.38696140, 0.36798462, 0.38192499, 0.23203450, 0.68225187, 0.47096270, -4.24785280, 0.18062039, 0.60047084, 0.16251479, -0.10811257, 0.48166662],
[0.10870802, 0.01576116, 0.00298645, 0.25878090, -0.16634797, 0.15850464, -0.24267951, 0.87678236, -0.27257833, 0.78637868, -0.00851476, 0.01502728, 0.92175138, -0.81292266],
[-0.74364990, -0.63139439, -0.18314177, -0.36881343, -0.53096825, -0.92442876, -0.05536628, -0.71273297, -0.94937468, -0.03863344, -0.09668982, -1.07886386, 0.58555382, 0.23351164],
[-0.09152136, 0.96538877, -0.11560653, -0.53110164, 0.89070886, 0.05664408, -0.71661353, 0.79684436, -0.00206013, 0.23857179, 0.06074178, -0.67188424, -0.15624331, 0.43436247],
[-0.28189376, -0.00535834, 0.60541785, 0.82968009, -0.21901314, -0.29874969, -0.16872653, 0.45570841, -0.25372767, -0.12359514, -1.10104620, 0.00162374, 0.07622175, 0.60413152],
[-1.13819373, -0.41320390, -5.57348347, 0.40931624, -1.59562767, 0.72510892, 0.03248254, 0.00407641, 0.57557869, 0.53510398, -0.35943517, 0.52707136, 0.61220711, -0.11644226],
[-0.02057049, 0.42545527, 0.24192038, 0.29863021, -0.22839858, -0.25318733, 0.17906551, -0.29471490, -0.04746799, 0.15909556, -0.26826856, -0.06874973, -0.03044286, 0.11770450],
[-0.18060833, -0.06301155, 0.01656315, -0.40476608, -0.35056075, 0.06344713, 0.32273614, -0.04382812, -0.18925793, 0.02124963, -0.23447622, 0.29704437, 0.19138981, -0.04584064],
[0.18248987, 0.05461208, -0.25655189, 0.16673982, 0.03251073, 0.05709980, 0.09135589, 0.06712578, -0.02372392, 0.00487196, -0.11774579, 0.34203079, 0.18477952, 0.09847298],
[-0.08292723, -0.03089223, 0.19555064, -0.18158682, -0.32060555, 0.18836822, -0.14625609, -0.83500093, -0.09893667, 0.02719803, 0.06864946, 0.00156752, 0.04342323, 0.30958080],
[-0.21274266, 0.06035644, 0.27282441, -0.01010289, -0.05599894, 0.27938741, -0.23254848, -0.20086342, -0.06775926, -0.18059292, 0.92534143, 0.09500337, 0.11612320, -0.06473339],
[-0.27279299, 0.96252358, 0.67542273, 0.64720130, 0.15221471, 1.67354584, 0.53074431, 0.65513390, 0.79840666, 0.78613347, 0.34742561, -1.83272552, 0.73313516, 0.09797212],
[-0.08888317, 0.14851266, 1.00953877, 0.19915256, -0.10076691, 0.47210938, 0.04427565, 0.19299655, 0.58729172, 0.17481442, -0.57466495, -0.16196120, 0.06293163, 1.73905540],
];
pub const B1: [f32; 32] = [
-0.80723554, 0.54879200, 0.01237706, -0.22279924, 0.93692911, 0.12226531, -0.54665250, -0.49958101,
-0.20918398, -0.48646352, -0.58741039, -0.50572610, -0.04772990, -0.62962151, -0.46279392, 1.14840722,
-0.04871057, -0.31787100, 1.13966286, 0.69543558, -0.17798270, 0.66968435, -0.07442535, -0.70557600,
0.79021728, 0.65736526, -0.30761406, 0.63242179, 0.83297908, -0.04573143, -0.18454255, -0.30583009
];
pub const W2: [f32; 32] = [
1.09615684, -0.57856798, -0.08730038, -0.06425755, -0.96232760, -2.06290460, 0.70097560, 0.85189444,
-0.10077959, 1.94375157, 0.74497795, 0.88425481, 2.11908054, 0.85526127, 0.61624259, -2.93621016,
1.52211487, 0.56318259, -3.15219641, -0.55187315, 1.61819077, -0.76258671, -0.09362544, 0.86861998,
-0.79028755, -0.90605170, 0.33475992, -0.79945564, -1.16680586, 0.15120529, 0.17619221, 1.61664009
];
pub const B2: f32 = -0.52729088;
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
(3, 0.30015790, 1, 2),
(255, 0.00000000, 0, 0),
(255, 1.00000000, 0, 0),
];

2
src/ensemble/gen/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod ddos_weights;
pub mod scanner_weights;

View File

@@ -0,0 +1,71 @@
//! Auto-generated weights for the scanner ensemble.
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-scanner-mlp`.
pub const THRESHOLD: f32 = 0.50000000;
pub const NORM_MINS: [f32; 12] = [
0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000
];
pub const NORM_MAXS: [f32; 12] = [
1.00000000, 10.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000,
1.00000000, 1.00000000, 1.00000000, 1.00000000
];
pub const W1: [[f32; 12]; 32] = [
[2.25848985, 1.62502551, 1.05068624, -0.23875977, 1.29692984, -1.34665418, -1.29937541, 1.66119707, -1.43897200, 0.07720046, -1.17116165, 1.96821272],
[2.30885172, 0.02477695, -0.23236598, 1.66507626, -1.41407740, 1.88616431, 1.84703696, -1.46395433, 2.03542018, 1.68318951, 2.01550031, 1.94223917],
[2.29420924, 1.86615539, 1.69271469, 1.42137837, 1.43151915, 1.84876072, 1.09228194, 1.73608077, 0.20805965, 0.52542430, -0.02558800, 0.04718366],
[0.36484259, -0.02785611, -0.01155548, 0.08577330, -0.00468449, -0.07848717, 0.05191587, 0.50796396, 0.40799347, -0.14838840, -0.30566201, 0.00758083],
[0.28191370, 0.20945202, 0.07742970, -0.06654347, 0.17395714, 0.00011351, 0.37079588, 0.41817516, 0.56992871, 0.05705916, 0.22339216, 0.11021475],
[0.06522971, 0.64510870, 0.31671444, 0.34980071, 0.03446164, -0.10592904, -0.21302676, -0.04404496, 0.08638768, 0.04217484, 0.43021953, 0.21055792],
[0.31206250, -0.14565454, 0.38078794, 0.00860748, 0.29409558, -0.11273954, -0.02210701, 0.15525217, 0.09696059, 0.13877581, 0.06483351, 0.10946950],
[0.28374705, -0.02963164, 0.27863786, -0.23428085, 0.12715313, 0.09141072, 0.07769041, 0.01915955, -0.20936646, 0.02813511, -0.03910714, 0.30322370],
[-1.19449413, -0.84935474, -0.32267663, -0.08140022, -0.78729230, 1.58759272, 0.88281459, -0.77263606, 1.55394125, 0.10148179, 1.59524822, -0.75499195],
[-0.97152823, -0.12173092, 0.04745778, -0.85466659, 1.57352293, -0.52651149, -0.66270715, 1.32282484, -1.24654925, -0.45822921, -1.10187364, -0.91162699],
[-0.93944395, -0.57891464, -1.12100291, -0.38871467, -0.18780440, -1.11835766, -0.43614236, -1.07918274, -0.09222561, -0.23854440, -0.16720718, 0.03247443],
[0.13319625, 0.87437463, 0.32213065, 0.13902900, 0.64760798, 0.00899744, 0.45325586, -0.14138180, 0.13888212, 0.07780524, -0.12482210, 0.12632932],
[0.57018995, -0.10839911, 0.02787536, 0.16884641, 0.19435850, -0.01189608, 0.13881874, -0.10700739, -0.05463003, 0.01371983, 0.04385772, 0.01100468],
[-0.26600277, -0.11843663, -0.01081531, 0.10785927, -0.18684258, 0.08537511, 0.01054722, -0.01972559, -0.07416820, 0.57192892, 0.37873995, -0.00498434],
[0.72535324, -0.25030360, 0.51470703, -0.16410951, -0.13649474, 0.16246459, -0.27847841, 0.12250750, 0.45576489, -0.18535912, -0.45686084, 0.58293521],
[0.18614589, -0.32835677, -0.08683094, 0.07748202, -0.24785264, -0.16834147, 0.27066526, 0.06058804, 0.01903199, -0.17387865, 0.12752151, -0.03780220],
[-1.22358644, -0.78316134, -0.54068804, -0.07921790, -0.72697675, 1.80127227, 0.14326867, -0.51875746, 1.83125353, -0.02672976, 1.68589675, -0.80162954],
[-0.83690810, -0.12682360, 0.10783038, -0.64648604, 1.50810242, -0.48788729, -0.59418935, 0.94863659, -0.84788662, -0.49779284, -0.96408021, -1.14068258],
[-0.96322638, -0.50503486, -0.87195945, -0.34710455, -0.28645220, -1.10507452, -0.32122782, -0.80753750, -0.00843489, 0.04215550, 0.03197355, 0.05468401],
[-0.17587705, 0.45144933, 0.37954769, -0.15405300, 0.75590396, 0.00346784, 0.62332457, -0.15602241, -0.26471916, -0.19963606, -0.22497311, -0.20784236],
[0.60608941, -0.05316854, 0.03766245, 0.46412235, -0.41121334, -0.01225545, -0.11125158, -0.33533856, -0.04625564, -0.02995013, -0.24979964, -0.35824969],
[0.08163761, 0.04702193, -0.24007457, -0.23439978, 0.27066308, 0.48389259, 0.32692793, -0.23089454, 0.26520243, -0.14099684, 0.06713670, 0.14434725],
[-0.50808382, -0.14518137, -0.23912378, 0.33510539, 0.46566108, 0.09035082, -0.12637842, 0.55245715, -0.19972627, 0.24517706, 0.34291887, 0.01936621],
[0.35826349, 0.21200819, 0.65315312, -0.16792546, 0.41378024, 0.32129642, 0.50814188, 0.48289016, 0.06839173, 0.42079177, 0.52295685, 0.26273951],
[0.24575019, 0.10700949, 0.07041252, -0.09410189, 0.18897925, 0.31616825, -0.01306109, 0.33499330, -0.01866218, 0.06233863, 0.15316568, 0.08370106],
[0.17828286, 0.17363867, -0.10626584, 0.06075979, 0.39465010, 0.19557165, 0.30352867, 0.26720291, 0.40256795, 0.13942246, 0.05869288, 0.08310238],
[-0.04834138, 0.29206491, 0.01330532, 0.07626399, -0.17378819, 0.09515948, 0.02298534, 0.41555724, 0.09492048, 0.39422533, 0.39373979, 0.20463347],
[-0.11641891, -0.06529939, -0.18899654, -0.02157970, -0.03554495, 0.10956290, -0.11688691, 0.04077352, 0.34220406, -0.09558969, 0.16150762, 0.25759667],
[-0.17313123, 0.00591523, 0.29443163, 0.08298909, 0.07761172, 0.19023541, 0.23826212, -0.07167042, 0.08753359, 0.17917964, -0.03248737, 0.28516129],
[0.13091524, 0.21435370, 0.15093684, 0.30902347, 0.44151527, 0.55901742, 0.19933179, 0.06438518, 0.30585650, -0.34089112, 0.26879075, 0.12928906],
[-0.25311065, -0.09963353, -0.50099874, 0.57481062, 0.38744658, -0.13065037, 0.18897361, 0.49376330, -0.15626629, 0.19911517, 0.06437352, -0.09104283],
[0.35787049, -0.04814727, 0.45446551, -0.15264697, 0.36565515, 0.22795495, 0.24630190, 0.16362202, 0.21044184, 0.53882843, 0.42343852, 0.18454899],
];
pub const B1: [f32; 32] = [
1.12135851, 0.64268047, 0.44761124, -0.28471574, 0.70866716, -0.25293177, -0.19119856, 0.39284116,
-0.20628852, -0.29301032, -0.08837436, 0.92048728, 0.91167349, -0.33615190, -0.06016272, 0.79141164,
-0.43257964, 0.48180589, 0.70891160, -0.24290052, 0.83115542, 0.69964927, 0.97887653, 1.34517038,
1.10292709, 0.42009205, 1.07155228, 0.61349720, 0.46157768, 1.01911950, 0.51159418, 0.60460496
];
pub const W2: [f32; 32] = [
1.55191231, 1.27754235, 0.43588921, 0.10868450, 0.55931729, -1.46911597, -0.54461092, 0.78240824,
-1.25938582, -0.06287600, -1.02053738, 1.07076716, 1.58776867, -0.03168033, -0.11393511, 1.30535436,
-1.46621227, 0.62925971, 0.76781118, -0.74480098, 1.29669034, 0.62078375, 1.64134884, 2.09736991,
1.52834618, 0.87368065, 1.80090642, 0.89230227, 0.38757962, 1.80718291, 0.64923352, 1.18709576
];
pub const B2: f32 = 0.23270580;
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
(3, 0.50000000, 1, 2),
(255, 1.00000000, 0, 0),
(255, 0.00000000, 0, 0),
];

132
src/ensemble/mlp.rs Normal file
View File

@@ -0,0 +1,132 @@
/// Two-layer MLP forward pass with a fixed hidden size of 32:
///
/// hidden = ReLU(W1 @ input + b1)
/// output = sigmoid(w2 · hidden + b2)
///
/// Zero allocation — the 32-element hidden layer lives on the stack.
#[inline(always)]
pub fn mlp_predict_32<const INPUT: usize>(
w1: &[[f32; INPUT]; 32],
b1: &[f32; 32],
w2: &[f32; 32],
b2: f32,
input: &[f32; INPUT],
) -> f32 {
let mut hidden = [0.0f32; 32];
// Hidden layer: h_j = ReLU(sum_i(w1[j][i] * input[i]) + b1[j])
for j in 0..32 {
let mut sum = b1[j];
for i in 0..INPUT {
sum += w1[j][i] * input[i];
}
hidden[j] = relu_f32(sum);
}
// Output layer: sigmoid(sum_j(w2[j] * hidden[j]) + b2)
let mut out = b2;
for j in 0..32 {
out += w2[j] * hidden[j];
}
sigmoid_f32(out)
}
#[inline(always)]
fn sigmoid_f32(x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
#[inline(always)]
fn relu_f32(x: f32) -> f32 {
if x > 0.0 { x } else { 0.0 }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sigmoid_boundaries() {
assert!((sigmoid_f32(0.0) - 0.5).abs() < 1e-6);
assert!(sigmoid_f32(10.0) > 0.999);
assert!(sigmoid_f32(-10.0) < 0.001);
}
#[test]
fn test_relu() {
assert_eq!(relu_f32(0.0), 0.0);
assert_eq!(relu_f32(1.5), 1.5);
assert_eq!(relu_f32(-3.0), 0.0);
}
#[test]
fn test_mlp_zero_weights() {
// All weights zero, bias2 = 0 → sigmoid(0) = 0.5
let w1 = [[0.0f32; 2]; 32];
let b1 = [0.0f32; 32];
let w2 = [0.0f32; 32];
let b2 = 0.0f32;
let input = [1.0f32, 2.0];
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
assert!((result - 0.5).abs() < 1e-6);
}
#[test]
fn test_mlp_known_output() {
// Single active hidden unit: w1[0] = [1, 0], b1[0] = 0, w2[0] = 2, b2 = -1
// input = [3, 0]
// hidden[0] = ReLU(1*3 + 0*0 + 0) = 3.0, rest = 0
// output = sigmoid(2*3 + (-1)) = sigmoid(5) ≈ 0.9933
let mut w1 = [[0.0f32; 2]; 32];
w1[0] = [1.0, 0.0];
let b1 = [0.0f32; 32];
let mut w2 = [0.0f32; 32];
w2[0] = 2.0;
let b2 = -1.0f32;
let input = [3.0f32, 0.0];
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
let expected = sigmoid_f32(5.0);
assert!(
(result - expected).abs() < 1e-6,
"expected {expected}, got {result}"
);
}
#[test]
fn test_mlp_relu_clips_negative() {
// w1[0] = [1, 0], b1[0] = -10 → hidden[0] = ReLU(-10 + input) clips to 0
// Everything zero → sigmoid(b2) = sigmoid(0) = 0.5
let mut w1 = [[0.0f32; 2]; 32];
w1[0] = [1.0, 0.0];
let mut b1 = [0.0f32; 32];
b1[0] = -10.0;
let w2 = [1.0f32; 32]; // doesn't matter, hidden is all 0
let b2 = 0.0f32;
let input = [3.0f32, 0.0]; // hidden[0] = ReLU(3-10) = 0
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
assert!((result - 0.5).abs() < 1e-6);
}
}
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
#[test]
fn output_bounded_0_1(
x0 in -10.0f32..10.0,
x1 in -10.0f32..10.0,
b2 in -5.0f32..5.0,
) {
let w1 = [[0.1f32, -0.2]; 32];
let b1 = [0.0f32; 32];
let w2 = [0.05f32; 32];
let input = [x0, x1];
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
prop_assert!(result >= 0.0 && result <= 1.0,
"output {result} outside [0,1]");
}
}
}

6
src/ensemble/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod ddos;
pub mod gen;
pub mod mlp;
pub mod replay;
pub mod scanner;
pub mod tree;

170
src/ensemble/scanner.rs Normal file
View File

@@ -0,0 +1,170 @@
use crate::scanner::model::{ScannerAction, ScannerVerdict};
use super::gen::scanner_weights;
use super::mlp::mlp_predict_32;
use super::tree::{tree_predict, TreeDecision};
/// Which path the ensemble took to reach its verdict.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnsemblePath {
TreeBlock,
TreeAllow,
Mlp,
}
/// Result of the scanner ensemble: action + confidence score + explanation.
pub struct EnsembleVerdict {
pub action: ScannerAction,
pub score: f64,
pub reason: &'static str,
pub path: EnsemblePath,
}
/// Normalize raw features using trained min/max constants.
#[inline]
fn normalize(raw: &[f32; 12]) -> [f32; 12] {
let mut out = [0.0f32; 12];
for i in 0..12 {
let range = scanner_weights::NORM_MAXS[i] - scanner_weights::NORM_MINS[i];
out[i] = if range > 0.0 {
((raw[i] - scanner_weights::NORM_MINS[i]) / range).clamp(0.0, 1.0)
} else {
0.0
};
}
out
}
/// Full ensemble inference: decision tree first, MLP only on `Defer`.
///
/// Returns an [`EnsembleVerdict`] that can be converted into a
/// [`ScannerVerdict`] for the rest of the pipeline.
pub fn scanner_ensemble_predict(raw: &[f32; 12]) -> EnsembleVerdict {
let input = normalize(raw);
let tree_result = tree_predict(&scanner_weights::TREE_NODES, &input);
match tree_result {
TreeDecision::Block => EnsembleVerdict {
action: ScannerAction::Block,
score: 1.0,
reason: "ensemble:tree_block",
path: EnsemblePath::TreeBlock,
},
TreeDecision::Allow => EnsembleVerdict {
action: ScannerAction::Allow,
score: 0.0,
reason: "ensemble:tree_allow",
path: EnsemblePath::TreeAllow,
},
TreeDecision::Defer => {
let mlp_score = mlp_predict_32::<12>(
&scanner_weights::W1,
&scanner_weights::B1,
&scanner_weights::W2,
scanner_weights::B2,
&input,
);
let action = if mlp_score > scanner_weights::THRESHOLD {
ScannerAction::Block
} else {
ScannerAction::Allow
};
EnsembleVerdict {
action,
score: mlp_score as f64,
reason: "ensemble:mlp",
path: EnsemblePath::Mlp,
}
}
}
}
impl From<EnsembleVerdict> for ScannerVerdict {
fn from(v: EnsembleVerdict) -> Self {
ScannerVerdict {
action: v.action,
score: v.score,
reason: v.reason,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tree_allow_path() {
// All features at zero → feature 3 (suspicious_ua) = 0.0 <= 0.65 → left (node 1)
// feature 0 (path_depth) = 0.0 <= 0.40 → left (node 3) → Allow leaf
let raw = [0.0f32; 12];
let v = scanner_ensemble_predict(&raw);
assert_eq!(v.action, ScannerAction::Allow);
assert_eq!(v.path, EnsemblePath::TreeAllow);
assert_eq!(v.reason, "ensemble:tree_allow");
assert!((v.score - 0.0).abs() < f64::EPSILON);
}
#[test]
fn test_tree_block_path() {
// Need: feature 3 (suspicious_ua) > 0.65 (normalized) → right (node 2)
// feature 7 (payload_entropy) > 0.72 (normalized) → right (node 6) → Block
// feature 3 max = 1.0, so raw 0.8 → normalized 0.8 > 0.65 ✓
// feature 7 max = 8.0, so raw 6.0 → normalized 0.75 > 0.72 ✓
let mut raw = [0.0f32; 12];
raw[3] = 0.8; // suspicious_ua: normalized = 0.8/1.0 = 0.8 > 0.65
raw[7] = 6.0; // payload_entropy: normalized = 6.0/8.0 = 0.75 > 0.72
let v = scanner_ensemble_predict(&raw);
assert_eq!(v.action, ScannerAction::Block);
assert_eq!(v.path, EnsemblePath::TreeBlock);
assert_eq!(v.reason, "ensemble:tree_block");
}
#[test]
fn test_mlp_path() {
// Need: feature 3 > 0.65 normalized → right (node 2)
// feature 7 <= 0.72 normalized → left (node 5) → Defer
// Then MLP runs on the normalized input.
let mut raw = [0.0f32; 12];
raw[3] = 0.8; // normalized = 0.8 > 0.65
raw[7] = 4.0; // normalized = 4.0/8.0 = 0.5 <= 0.72
// Also need feature 2 (query_param_count) to navigate node 5 correctly
// node 5: split on feature 2, threshold 0.55 → left=9(Defer), right=10
// normalized feature 2 = 0.0/20.0 = 0.0 <= 0.55 → left (node 9) → Defer
let v = scanner_ensemble_predict(&raw);
assert_eq!(v.path, EnsemblePath::Mlp);
assert_eq!(v.reason, "ensemble:mlp");
// MLP output is deterministic for these inputs
assert!(v.score >= 0.0 && v.score <= 1.0);
}
#[test]
fn test_normalize_clamps() {
// Values beyond max should be clamped to 1.0
let mut raw = [0.0f32; 12];
raw[0] = 100.0; // max is 10.0
let normed = normalize(&raw);
assert!((normed[0] - 1.0).abs() < f64::EPSILON as f32);
}
#[test]
fn test_normalize_negative_clamps() {
let mut raw = [0.0f32; 12];
raw[0] = -5.0; // min is 0.0
let normed = normalize(&raw);
assert!((normed[0] - 0.0).abs() < f64::EPSILON as f32);
}
#[test]
fn test_verdict_into_scanner_verdict() {
let v = EnsembleVerdict {
action: ScannerAction::Block,
score: 0.85,
reason: "ensemble:mlp",
path: EnsemblePath::Mlp,
};
let sv: ScannerVerdict = v.into();
assert_eq!(sv.action, ScannerAction::Block);
assert!((sv.score - 0.85).abs() < f64::EPSILON);
assert_eq!(sv.reason, "ensemble:mlp");
}
}

156
src/ensemble/tree.rs Normal file
View File

@@ -0,0 +1,156 @@
/// Decision from a tree leaf node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeDecision {
/// High-confidence block — overrides MLP.
Block,
/// High-confidence allow — overrides MLP.
Allow,
/// Low-confidence — defer to MLP for scoring.
Defer,
}
/// Packed tree node: `(feature_index, threshold, left_child, right_child)`.
///
/// Leaf nodes are encoded with `feature_index = 255`.
/// For leaves the threshold encodes the decision:
/// - `< 0.25` → Allow
/// - `> 0.75` → Block
/// - otherwise → Defer
pub type PackedNode = (u8, f32, u16, u16);
/// Walk a packed decision tree. O(depth), zero allocation.
#[inline(always)]
pub fn tree_predict(nodes: &[PackedNode], input: &[f32]) -> TreeDecision {
let mut idx = 0usize;
loop {
let (feature, threshold, left, right) = nodes[idx];
if feature == 255 {
return if threshold < 0.25 {
TreeDecision::Allow
} else if threshold > 0.75 {
TreeDecision::Block
} else {
TreeDecision::Defer
};
}
idx = if input[feature as usize] <= threshold {
left as usize
} else {
right as usize
};
}
}
#[cfg(test)]
mod tests {
use super::*;
// Simple tree:
// node 0: feature 0, threshold 0.5 → left=1, right=2
// node 1: leaf Allow (threshold 0.0)
// node 2: leaf Block (threshold 1.0)
const SIMPLE_TREE: [PackedNode; 3] = [
(0, 0.5, 1, 2),
(255, 0.0, 0, 0), // Allow
(255, 1.0, 0, 0), // Block
];
#[test]
fn test_tree_allow() {
let input = [0.3]; // <= 0.5 → left → Allow
assert_eq!(tree_predict(&SIMPLE_TREE, &input), TreeDecision::Allow);
}
#[test]
fn test_tree_block() {
let input = [0.8]; // > 0.5 → right → Block
assert_eq!(tree_predict(&SIMPLE_TREE, &input), TreeDecision::Block);
}
#[test]
fn test_tree_defer() {
// Tree with a Defer leaf
let tree: [PackedNode; 2] = [
(0, 0.5, 1, 1),
(255, 0.5, 0, 0), // Defer
];
let input = [0.3];
assert_eq!(tree_predict(&tree, &input), TreeDecision::Defer);
}
#[test]
fn test_tree_boundary_allow() {
// threshold exactly 0.25 is NOT < 0.25, so it should Defer
let tree: [PackedNode; 1] = [(255, 0.25, 0, 0)];
assert_eq!(tree_predict(&tree, &[]), TreeDecision::Defer);
}
#[test]
fn test_tree_boundary_block() {
// threshold exactly 0.75 is NOT > 0.75, so it should Defer
let tree: [PackedNode; 1] = [(255, 0.75, 0, 0)];
assert_eq!(tree_predict(&tree, &[]), TreeDecision::Defer);
}
#[test]
fn test_deeper_tree() {
// Depth-3 tree with 4 features
let tree: [PackedNode; 7] = [
(0, 0.5, 1, 2), // root: feature 0
(1, 0.3, 3, 4), // left: feature 1
(2, 0.7, 5, 6), // right: feature 2
(255, 0.0, 0, 0), // Allow
(255, 0.5, 0, 0), // Defer
(255, 1.0, 0, 0), // Block
(255, 0.0, 0, 0), // Allow
];
// feature 0=0.2 (<=0.5→left=1), feature 1=0.1 (<=0.3→left=3) → Allow
assert_eq!(tree_predict(&tree, &[0.2, 0.1, 0.0, 0.0]), TreeDecision::Allow);
// feature 0=0.2 (<=0.5→left=1), feature 1=0.5 (>0.3→right=4) → Defer
assert_eq!(tree_predict(&tree, &[0.2, 0.5, 0.0, 0.0]), TreeDecision::Defer);
// feature 0=0.8 (>0.5→right=2), feature 2=0.3 (<=0.7→left=5) → Block
assert_eq!(tree_predict(&tree, &[0.8, 0.0, 0.3, 0.0]), TreeDecision::Block);
// feature 0=0.8 (>0.5→right=2), feature 2=0.9 (>0.7→right=6) → Allow
assert_eq!(tree_predict(&tree, &[0.8, 0.0, 0.9, 0.0]), TreeDecision::Allow);
}
}
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
/// Generate a valid packed tree (complete binary tree of given depth).
/// All internal nodes split on feature 0 at threshold 0.5.
/// Leaves cycle through Allow/Block/Defer.
fn make_complete_tree(depth: u32) -> Vec<PackedNode> {
let num_nodes = (1u32 << (depth + 1)) - 1;
let num_internal = (1u32 << depth) - 1;
let mut nodes = Vec::with_capacity(num_nodes as usize);
let decisions = [0.0f32, 1.0, 0.5]; // Allow, Block, Defer
for i in 0..num_nodes {
if i < num_internal {
let left = 2 * i + 1;
let right = 2 * i + 2;
nodes.push((0u8, 0.5f32, left as u16, right as u16));
} else {
let leaf_idx = (i - num_internal) as usize;
nodes.push((255u8, decisions[leaf_idx % 3], 0u16, 0u16));
}
}
nodes
}
proptest! {
#[test]
fn tree_always_terminates(val in 0.0f32..1.0, depth in 0u32..5) {
let tree = make_complete_tree(depth);
let input = [val; 16]; // enough features for any tree
let result = tree_predict(&tree, &input);
prop_assert!(matches!(
result,
TreeDecision::Allow | TreeDecision::Block | TreeDecision::Defer
));
}
}
}