diff --git a/src/ensemble/ddos.rs b/src/ensemble/ddos.rs new file mode 100644 index 0000000..2b82e1b --- /dev/null +++ b/src/ensemble/ddos.rs @@ -0,0 +1,147 @@ +use crate::ddos::model::DDoSAction; +use super::gen::ddos_weights; +use super::mlp::mlp_predict_32; +use super::tree::{tree_predict, TreeDecision}; + +/// Which path the DDoS ensemble took to reach its verdict. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DDoSEnsemblePath { + TreeBlock, + TreeAllow, + Mlp, +} + +/// Result of the DDoS ensemble inference. +pub struct DDoSEnsembleVerdict { + pub action: DDoSAction, + pub score: f64, + pub reason: &'static str, + pub path: DDoSEnsemblePath, +} + +/// Normalize raw features using trained min/max constants. +#[inline] +fn normalize(raw: &[f32; 14]) -> [f32; 14] { + let mut out = [0.0f32; 14]; + for i in 0..14 { + let range = ddos_weights::NORM_MAXS[i] - ddos_weights::NORM_MINS[i]; + out[i] = if range > 0.0 { + ((raw[i] - ddos_weights::NORM_MINS[i]) / range).clamp(0.0, 1.0) + } else { + 0.0 + }; + } + out +} + +/// Full DDoS ensemble inference: decision tree first, MLP only on `Defer`. +pub fn ddos_ensemble_predict(raw: &[f32; 14]) -> DDoSEnsembleVerdict { + let input = normalize(raw); + + let tree_result = tree_predict(&ddos_weights::TREE_NODES, &input); + match tree_result { + TreeDecision::Block => DDoSEnsembleVerdict { + action: DDoSAction::Block, + score: 1.0, + reason: "ensemble:tree_block", + path: DDoSEnsemblePath::TreeBlock, + }, + TreeDecision::Allow => DDoSEnsembleVerdict { + action: DDoSAction::Allow, + score: 0.0, + reason: "ensemble:tree_allow", + path: DDoSEnsemblePath::TreeAllow, + }, + TreeDecision::Defer => { + let mlp_score = mlp_predict_32::<14>( + &ddos_weights::W1, + &ddos_weights::B1, + &ddos_weights::W2, + ddos_weights::B2, + &input, + ); + let action = if mlp_score > ddos_weights::THRESHOLD { + DDoSAction::Block + } else { + DDoSAction::Allow + }; + DDoSEnsembleVerdict { + action, + score: mlp_score as f64, + reason: "ensemble:mlp", + path: DDoSEnsemblePath::Mlp, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tree_allow_path() { + // All zeros → feature 4 (request_rate) = 0.0 <= 0.70 → left (node 1) + // feature 10 (cookie_ratio) = 0.0 <= 0.30 → left (node 3) → Allow + let raw = [0.0f32; 14]; + let v = ddos_ensemble_predict(&raw); + assert_eq!(v.action, DDoSAction::Allow); + assert_eq!(v.path, DDoSEnsemblePath::TreeAllow); + assert_eq!(v.reason, "ensemble:tree_allow"); + } + + #[test] + fn test_tree_block_path() { + // Need: feature 4 (request_rate) > 0.70 normalized → right (node 2) + // feature 12 (accept_language_ratio) > 0.25 normalized → right (node 6) → Block + // feature 4 max = 500, so raw 400 → normalized 0.8 > 0.70 ✓ + // feature 12 max = 1.0, so raw 0.5 → normalized 0.5 > 0.25 ✓ + let mut raw = [0.0f32; 14]; + raw[4] = 400.0; + raw[12] = 0.5; + let v = ddos_ensemble_predict(&raw); + assert_eq!(v.action, DDoSAction::Block); + assert_eq!(v.path, DDoSEnsemblePath::TreeBlock); + } + + #[test] + fn test_mlp_path() { + // Need: feature 4 > 0.70 normalized → right (node 2) + // feature 12 <= 0.25 normalized → left (node 5) → Defer + // feature 4 max = 500, raw 400 → 0.8 > 0.70 ✓ + // feature 12 max = 1.0, raw 0.1 → 0.1 <= 0.25 ✓ + let mut raw = [0.0f32; 14]; + raw[4] = 400.0; + raw[12] = 0.1; + let v = ddos_ensemble_predict(&raw); + assert_eq!(v.path, DDoSEnsemblePath::Mlp); + assert_eq!(v.reason, "ensemble:mlp"); + assert!(v.score >= 0.0 && v.score <= 1.0); + } + + #[test] + fn test_defer_then_mlp_allow() { + // Same Defer path as above — verify the MLP produces a valid action + let mut raw = [0.0f32; 14]; + raw[4] = 400.0; + raw[12] = 0.1; + let v = ddos_ensemble_predict(&raw); + assert!(matches!(v.action, DDoSAction::Allow | DDoSAction::Block)); + } + + #[test] + fn test_normalize_clamps_high() { + let mut raw = [0.0f32; 14]; + raw[0] = 999.0; // max is 100 + let normed = normalize(&raw); + assert!((normed[0] - 1.0).abs() < f32::EPSILON); + } + + #[test] + fn test_normalize_clamps_low() { + let mut raw = [0.0f32; 14]; + raw[1] = -500.0; // min is 0 + let normed = normalize(&raw); + assert!((normed[1] - 0.0).abs() < f32::EPSILON); + } +} diff --git a/src/ensemble/gen/ddos_weights.rs b/src/ensemble/gen/ddos_weights.rs new file mode 100644 index 0000000..e67b64d --- /dev/null +++ b/src/ensemble/gen/ddos_weights.rs @@ -0,0 +1,71 @@ +//! Auto-generated weights for the ddos ensemble. +//! DO NOT EDIT — regenerate with `cargo run --features training -- train-ddos-mlp`. + +pub const THRESHOLD: f32 = 0.50000000; + +pub const NORM_MINS: [f32; 14] = [ + 0.08778746, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.05001374, 0.02000000, + 0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000 +]; + +pub const NORM_MAXS: [f32; 14] = [ + 1000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49990082, 500.00000000, 1.00000000, + 171240.28125000, 30.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000 +]; + +pub const W1: [[f32; 14]; 32] = [ + [0.57458097, -0.10861993, 0.14037465, -0.23486336, -0.43255216, 0.16347405, 0.71766937, 0.83138502, 0.02852129, 0.56590265, 0.54848498, 0.38098580, 0.82907754, 0.61539698], + [0.06583579, 0.02305713, 0.89706898, 0.42619053, -1.20866120, -0.11974730, 1.70674825, -0.17969023, -0.26867196, 0.60768014, -0.08671998, -0.04825107, 0.58131427, -0.02062579], + [-0.40264100, 0.18836430, 0.08431315, 0.33763552, 0.44880620, -0.40894085, -0.22044741, 0.00533387, -0.61574107, 0.07670992, 0.63528854, 0.48244709, 0.20411402, -1.80697525], + [0.66713083, -0.22220801, 2.11234117, 0.41516641, -0.00165093, 0.65624571, 1.87509167, 0.63406783, -2.54182458, -0.53618753, 2.16407824, -0.61959583, -0.04717547, 0.17551991], + [-0.51027024, -0.60132700, 0.46407551, -0.57346475, -0.30902353, -0.24235034, 0.08087540, -2.14762974, -0.29429656, 0.56257033, -0.26935315, -0.16799171, 0.56852734, 1.93494022], + [-0.24938971, -0.12699288, -0.13746630, 0.64942318, 0.09490766, -0.02158179, 0.72449303, -0.28493983, -0.43053114, -0.01443988, 0.89670080, -0.34539866, -1.47019410, 0.79477930], + [0.62935185, 0.74686801, -0.15527052, -0.06635039, 0.73137009, 0.78417069, -0.06417987, 0.72259408, 0.85131824, 0.00477386, -0.14302900, 0.63481224, 0.92724019, -0.50126070], + [-0.12699059, -0.15016419, -0.48704135, 0.00581611, 0.75824696, 0.84114397, -0.08958503, 0.18609463, 0.56247348, 0.22239330, 0.43324804, 0.82077771, 0.55714250, -0.56955606], + [0.83457869, 0.40054807, 0.23281574, -0.58521581, 1.18067443, -0.49485078, 0.08600014, 0.99104887, -0.65019566, -0.44594154, 0.64507920, -0.61692268, -0.29301512, -0.11314666], + [-0.07868081, -0.18392175, 0.15165123, 0.35139060, 0.13855398, 0.16470867, 0.21025884, 1.57204449, 0.07827333, 0.05895505, 0.00810917, 1.05159700, 0.04605416, 0.38080546], + [1.47428405, -0.21614535, -0.35385504, 0.46582970, 1.26638246, 0.00133375, -3.85603786, 0.39766011, 1.92816520, 0.47828305, -0.16951409, -0.13771342, 0.49451983, 0.41184473], + [-0.23364748, 0.68134952, 0.36342716, 0.02657196, 0.07550839, 0.94861823, -0.52908695, 0.83652318, -0.05639480, 0.26536962, 0.44137934, 1.20957208, -0.60747981, -0.50647283], + [-0.16961956, -0.49570882, -0.33771378, -0.28554109, 0.95865113, -0.49269623, -0.44559151, 1.28568971, 0.79537493, -0.53175420, -3.19015551, 0.52214253, 0.86517984, 0.62523192], + [-0.16956513, -0.61727583, 0.63967121, 0.96406335, -0.28760204, 0.56459671, 0.78585202, -0.03668134, -0.14773002, -0.35764447, 0.84649116, -0.34540027, -0.12314465, -0.10070048], + [-0.34183556, -0.07760386, 0.70894319, 0.92814171, -0.19357866, 0.41449037, 0.54653358, 0.27682835, 0.81471086, 0.56383932, 0.57456553, -0.61491662, 0.92498505, 0.74495614], + [-0.38917324, -0.29217750, 1.43508542, -0.19152534, -0.18823336, 0.45097819, -0.38063127, -0.40419811, 0.56686693, -0.33231607, -0.19567636, -0.02500075, -0.04762971, 0.44703853], + [1.14234805, -0.62868208, -0.21298689, 0.00263968, -0.66115338, -1.12038326, 0.93599045, 0.77646011, -0.22770278, 1.43982041, 0.96078646, 1.15076077, -0.45110813, 0.83090556], + [0.89638984, -0.69683450, -0.29400119, 0.94997799, 0.90305328, -0.80215877, -0.09983492, -0.90757453, -0.03181892, 1.00702441, -0.97962254, -0.89580274, 0.69299418, -0.75975400], + [-0.75832003, -0.07210776, 0.07825917, 1.51633596, 0.44593197, 0.00936707, -0.12142835, -0.09877282, 0.06229200, 1.25678349, 0.25317946, 0.54112315, -0.17941843, 0.93283361], + [0.23085761, 0.53307736, 0.38696140, 0.36798462, 0.38192499, 0.23203450, 0.68225187, 0.47096270, -4.24785280, 0.18062039, 0.60047084, 0.16251479, -0.10811257, 0.48166662], + [0.10870802, 0.01576116, 0.00298645, 0.25878090, -0.16634797, 0.15850464, -0.24267951, 0.87678236, -0.27257833, 0.78637868, -0.00851476, 0.01502728, 0.92175138, -0.81292266], + [-0.74364990, -0.63139439, -0.18314177, -0.36881343, -0.53096825, -0.92442876, -0.05536628, -0.71273297, -0.94937468, -0.03863344, -0.09668982, -1.07886386, 0.58555382, 0.23351164], + [-0.09152136, 0.96538877, -0.11560653, -0.53110164, 0.89070886, 0.05664408, -0.71661353, 0.79684436, -0.00206013, 0.23857179, 0.06074178, -0.67188424, -0.15624331, 0.43436247], + [-0.28189376, -0.00535834, 0.60541785, 0.82968009, -0.21901314, -0.29874969, -0.16872653, 0.45570841, -0.25372767, -0.12359514, -1.10104620, 0.00162374, 0.07622175, 0.60413152], + [-1.13819373, -0.41320390, -5.57348347, 0.40931624, -1.59562767, 0.72510892, 0.03248254, 0.00407641, 0.57557869, 0.53510398, -0.35943517, 0.52707136, 0.61220711, -0.11644226], + [-0.02057049, 0.42545527, 0.24192038, 0.29863021, -0.22839858, -0.25318733, 0.17906551, -0.29471490, -0.04746799, 0.15909556, -0.26826856, -0.06874973, -0.03044286, 0.11770450], + [-0.18060833, -0.06301155, 0.01656315, -0.40476608, -0.35056075, 0.06344713, 0.32273614, -0.04382812, -0.18925793, 0.02124963, -0.23447622, 0.29704437, 0.19138981, -0.04584064], + [0.18248987, 0.05461208, -0.25655189, 0.16673982, 0.03251073, 0.05709980, 0.09135589, 0.06712578, -0.02372392, 0.00487196, -0.11774579, 0.34203079, 0.18477952, 0.09847298], + [-0.08292723, -0.03089223, 0.19555064, -0.18158682, -0.32060555, 0.18836822, -0.14625609, -0.83500093, -0.09893667, 0.02719803, 0.06864946, 0.00156752, 0.04342323, 0.30958080], + [-0.21274266, 0.06035644, 0.27282441, -0.01010289, -0.05599894, 0.27938741, -0.23254848, -0.20086342, -0.06775926, -0.18059292, 0.92534143, 0.09500337, 0.11612320, -0.06473339], + [-0.27279299, 0.96252358, 0.67542273, 0.64720130, 0.15221471, 1.67354584, 0.53074431, 0.65513390, 0.79840666, 0.78613347, 0.34742561, -1.83272552, 0.73313516, 0.09797212], + [-0.08888317, 0.14851266, 1.00953877, 0.19915256, -0.10076691, 0.47210938, 0.04427565, 0.19299655, 0.58729172, 0.17481442, -0.57466495, -0.16196120, 0.06293163, 1.73905540], +]; + +pub const B1: [f32; 32] = [ + -0.80723554, 0.54879200, 0.01237706, -0.22279924, 0.93692911, 0.12226531, -0.54665250, -0.49958101, + -0.20918398, -0.48646352, -0.58741039, -0.50572610, -0.04772990, -0.62962151, -0.46279392, 1.14840722, + -0.04871057, -0.31787100, 1.13966286, 0.69543558, -0.17798270, 0.66968435, -0.07442535, -0.70557600, + 0.79021728, 0.65736526, -0.30761406, 0.63242179, 0.83297908, -0.04573143, -0.18454255, -0.30583009 +]; + +pub const W2: [f32; 32] = [ + 1.09615684, -0.57856798, -0.08730038, -0.06425755, -0.96232760, -2.06290460, 0.70097560, 0.85189444, + -0.10077959, 1.94375157, 0.74497795, 0.88425481, 2.11908054, 0.85526127, 0.61624259, -2.93621016, + 1.52211487, 0.56318259, -3.15219641, -0.55187315, 1.61819077, -0.76258671, -0.09362544, 0.86861998, + -0.79028755, -0.90605170, 0.33475992, -0.79945564, -1.16680586, 0.15120529, 0.17619221, 1.61664009 +]; + +pub const B2: f32 = -0.52729088; + +pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [ + (3, 0.30015790, 1, 2), + (255, 0.00000000, 0, 0), + (255, 1.00000000, 0, 0), +]; diff --git a/src/ensemble/gen/mod.rs b/src/ensemble/gen/mod.rs new file mode 100644 index 0000000..6c07054 --- /dev/null +++ b/src/ensemble/gen/mod.rs @@ -0,0 +1,2 @@ +pub mod ddos_weights; +pub mod scanner_weights; diff --git a/src/ensemble/gen/scanner_weights.rs b/src/ensemble/gen/scanner_weights.rs new file mode 100644 index 0000000..f9030c7 --- /dev/null +++ b/src/ensemble/gen/scanner_weights.rs @@ -0,0 +1,71 @@ +//! Auto-generated weights for the scanner ensemble. +//! DO NOT EDIT — regenerate with `cargo run --features training -- train-scanner-mlp`. + +pub const THRESHOLD: f32 = 0.50000000; + +pub const NORM_MINS: [f32; 12] = [ + 0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, + 0.00000000, 0.00000000, 0.00000000, 0.00000000 +]; + +pub const NORM_MAXS: [f32; 12] = [ + 1.00000000, 10.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, + 1.00000000, 1.00000000, 1.00000000, 1.00000000 +]; + +pub const W1: [[f32; 12]; 32] = [ + [2.25848985, 1.62502551, 1.05068624, -0.23875977, 1.29692984, -1.34665418, -1.29937541, 1.66119707, -1.43897200, 0.07720046, -1.17116165, 1.96821272], + [2.30885172, 0.02477695, -0.23236598, 1.66507626, -1.41407740, 1.88616431, 1.84703696, -1.46395433, 2.03542018, 1.68318951, 2.01550031, 1.94223917], + [2.29420924, 1.86615539, 1.69271469, 1.42137837, 1.43151915, 1.84876072, 1.09228194, 1.73608077, 0.20805965, 0.52542430, -0.02558800, 0.04718366], + [0.36484259, -0.02785611, -0.01155548, 0.08577330, -0.00468449, -0.07848717, 0.05191587, 0.50796396, 0.40799347, -0.14838840, -0.30566201, 0.00758083], + [0.28191370, 0.20945202, 0.07742970, -0.06654347, 0.17395714, 0.00011351, 0.37079588, 0.41817516, 0.56992871, 0.05705916, 0.22339216, 0.11021475], + [0.06522971, 0.64510870, 0.31671444, 0.34980071, 0.03446164, -0.10592904, -0.21302676, -0.04404496, 0.08638768, 0.04217484, 0.43021953, 0.21055792], + [0.31206250, -0.14565454, 0.38078794, 0.00860748, 0.29409558, -0.11273954, -0.02210701, 0.15525217, 0.09696059, 0.13877581, 0.06483351, 0.10946950], + [0.28374705, -0.02963164, 0.27863786, -0.23428085, 0.12715313, 0.09141072, 0.07769041, 0.01915955, -0.20936646, 0.02813511, -0.03910714, 0.30322370], + [-1.19449413, -0.84935474, -0.32267663, -0.08140022, -0.78729230, 1.58759272, 0.88281459, -0.77263606, 1.55394125, 0.10148179, 1.59524822, -0.75499195], + [-0.97152823, -0.12173092, 0.04745778, -0.85466659, 1.57352293, -0.52651149, -0.66270715, 1.32282484, -1.24654925, -0.45822921, -1.10187364, -0.91162699], + [-0.93944395, -0.57891464, -1.12100291, -0.38871467, -0.18780440, -1.11835766, -0.43614236, -1.07918274, -0.09222561, -0.23854440, -0.16720718, 0.03247443], + [0.13319625, 0.87437463, 0.32213065, 0.13902900, 0.64760798, 0.00899744, 0.45325586, -0.14138180, 0.13888212, 0.07780524, -0.12482210, 0.12632932], + [0.57018995, -0.10839911, 0.02787536, 0.16884641, 0.19435850, -0.01189608, 0.13881874, -0.10700739, -0.05463003, 0.01371983, 0.04385772, 0.01100468], + [-0.26600277, -0.11843663, -0.01081531, 0.10785927, -0.18684258, 0.08537511, 0.01054722, -0.01972559, -0.07416820, 0.57192892, 0.37873995, -0.00498434], + [0.72535324, -0.25030360, 0.51470703, -0.16410951, -0.13649474, 0.16246459, -0.27847841, 0.12250750, 0.45576489, -0.18535912, -0.45686084, 0.58293521], + [0.18614589, -0.32835677, -0.08683094, 0.07748202, -0.24785264, -0.16834147, 0.27066526, 0.06058804, 0.01903199, -0.17387865, 0.12752151, -0.03780220], + [-1.22358644, -0.78316134, -0.54068804, -0.07921790, -0.72697675, 1.80127227, 0.14326867, -0.51875746, 1.83125353, -0.02672976, 1.68589675, -0.80162954], + [-0.83690810, -0.12682360, 0.10783038, -0.64648604, 1.50810242, -0.48788729, -0.59418935, 0.94863659, -0.84788662, -0.49779284, -0.96408021, -1.14068258], + [-0.96322638, -0.50503486, -0.87195945, -0.34710455, -0.28645220, -1.10507452, -0.32122782, -0.80753750, -0.00843489, 0.04215550, 0.03197355, 0.05468401], + [-0.17587705, 0.45144933, 0.37954769, -0.15405300, 0.75590396, 0.00346784, 0.62332457, -0.15602241, -0.26471916, -0.19963606, -0.22497311, -0.20784236], + [0.60608941, -0.05316854, 0.03766245, 0.46412235, -0.41121334, -0.01225545, -0.11125158, -0.33533856, -0.04625564, -0.02995013, -0.24979964, -0.35824969], + [0.08163761, 0.04702193, -0.24007457, -0.23439978, 0.27066308, 0.48389259, 0.32692793, -0.23089454, 0.26520243, -0.14099684, 0.06713670, 0.14434725], + [-0.50808382, -0.14518137, -0.23912378, 0.33510539, 0.46566108, 0.09035082, -0.12637842, 0.55245715, -0.19972627, 0.24517706, 0.34291887, 0.01936621], + [0.35826349, 0.21200819, 0.65315312, -0.16792546, 0.41378024, 0.32129642, 0.50814188, 0.48289016, 0.06839173, 0.42079177, 0.52295685, 0.26273951], + [0.24575019, 0.10700949, 0.07041252, -0.09410189, 0.18897925, 0.31616825, -0.01306109, 0.33499330, -0.01866218, 0.06233863, 0.15316568, 0.08370106], + [0.17828286, 0.17363867, -0.10626584, 0.06075979, 0.39465010, 0.19557165, 0.30352867, 0.26720291, 0.40256795, 0.13942246, 0.05869288, 0.08310238], + [-0.04834138, 0.29206491, 0.01330532, 0.07626399, -0.17378819, 0.09515948, 0.02298534, 0.41555724, 0.09492048, 0.39422533, 0.39373979, 0.20463347], + [-0.11641891, -0.06529939, -0.18899654, -0.02157970, -0.03554495, 0.10956290, -0.11688691, 0.04077352, 0.34220406, -0.09558969, 0.16150762, 0.25759667], + [-0.17313123, 0.00591523, 0.29443163, 0.08298909, 0.07761172, 0.19023541, 0.23826212, -0.07167042, 0.08753359, 0.17917964, -0.03248737, 0.28516129], + [0.13091524, 0.21435370, 0.15093684, 0.30902347, 0.44151527, 0.55901742, 0.19933179, 0.06438518, 0.30585650, -0.34089112, 0.26879075, 0.12928906], + [-0.25311065, -0.09963353, -0.50099874, 0.57481062, 0.38744658, -0.13065037, 0.18897361, 0.49376330, -0.15626629, 0.19911517, 0.06437352, -0.09104283], + [0.35787049, -0.04814727, 0.45446551, -0.15264697, 0.36565515, 0.22795495, 0.24630190, 0.16362202, 0.21044184, 0.53882843, 0.42343852, 0.18454899], +]; + +pub const B1: [f32; 32] = [ + 1.12135851, 0.64268047, 0.44761124, -0.28471574, 0.70866716, -0.25293177, -0.19119856, 0.39284116, + -0.20628852, -0.29301032, -0.08837436, 0.92048728, 0.91167349, -0.33615190, -0.06016272, 0.79141164, + -0.43257964, 0.48180589, 0.70891160, -0.24290052, 0.83115542, 0.69964927, 0.97887653, 1.34517038, + 1.10292709, 0.42009205, 1.07155228, 0.61349720, 0.46157768, 1.01911950, 0.51159418, 0.60460496 +]; + +pub const W2: [f32; 32] = [ + 1.55191231, 1.27754235, 0.43588921, 0.10868450, 0.55931729, -1.46911597, -0.54461092, 0.78240824, + -1.25938582, -0.06287600, -1.02053738, 1.07076716, 1.58776867, -0.03168033, -0.11393511, 1.30535436, + -1.46621227, 0.62925971, 0.76781118, -0.74480098, 1.29669034, 0.62078375, 1.64134884, 2.09736991, + 1.52834618, 0.87368065, 1.80090642, 0.89230227, 0.38757962, 1.80718291, 0.64923352, 1.18709576 +]; + +pub const B2: f32 = 0.23270580; + +pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [ + (3, 0.50000000, 1, 2), + (255, 1.00000000, 0, 0), + (255, 0.00000000, 0, 0), +]; diff --git a/src/ensemble/mlp.rs b/src/ensemble/mlp.rs new file mode 100644 index 0000000..fa2f096 --- /dev/null +++ b/src/ensemble/mlp.rs @@ -0,0 +1,132 @@ +/// Two-layer MLP forward pass with a fixed hidden size of 32: +/// +/// hidden = ReLU(W1 @ input + b1) +/// output = sigmoid(w2 · hidden + b2) +/// +/// Zero allocation — the 32-element hidden layer lives on the stack. +#[inline(always)] +pub fn mlp_predict_32( + w1: &[[f32; INPUT]; 32], + b1: &[f32; 32], + w2: &[f32; 32], + b2: f32, + input: &[f32; INPUT], +) -> f32 { + let mut hidden = [0.0f32; 32]; + + // Hidden layer: h_j = ReLU(sum_i(w1[j][i] * input[i]) + b1[j]) + for j in 0..32 { + let mut sum = b1[j]; + for i in 0..INPUT { + sum += w1[j][i] * input[i]; + } + hidden[j] = relu_f32(sum); + } + + // Output layer: sigmoid(sum_j(w2[j] * hidden[j]) + b2) + let mut out = b2; + for j in 0..32 { + out += w2[j] * hidden[j]; + } + sigmoid_f32(out) +} + +#[inline(always)] +fn sigmoid_f32(x: f32) -> f32 { + 1.0 / (1.0 + (-x).exp()) +} + +#[inline(always)] +fn relu_f32(x: f32) -> f32 { + if x > 0.0 { x } else { 0.0 } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sigmoid_boundaries() { + assert!((sigmoid_f32(0.0) - 0.5).abs() < 1e-6); + assert!(sigmoid_f32(10.0) > 0.999); + assert!(sigmoid_f32(-10.0) < 0.001); + } + + #[test] + fn test_relu() { + assert_eq!(relu_f32(0.0), 0.0); + assert_eq!(relu_f32(1.5), 1.5); + assert_eq!(relu_f32(-3.0), 0.0); + } + + #[test] + fn test_mlp_zero_weights() { + // All weights zero, bias2 = 0 → sigmoid(0) = 0.5 + let w1 = [[0.0f32; 2]; 32]; + let b1 = [0.0f32; 32]; + let w2 = [0.0f32; 32]; + let b2 = 0.0f32; + let input = [1.0f32, 2.0]; + let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input); + assert!((result - 0.5).abs() < 1e-6); + } + + #[test] + fn test_mlp_known_output() { + // Single active hidden unit: w1[0] = [1, 0], b1[0] = 0, w2[0] = 2, b2 = -1 + // input = [3, 0] + // hidden[0] = ReLU(1*3 + 0*0 + 0) = 3.0, rest = 0 + // output = sigmoid(2*3 + (-1)) = sigmoid(5) ≈ 0.9933 + let mut w1 = [[0.0f32; 2]; 32]; + w1[0] = [1.0, 0.0]; + let b1 = [0.0f32; 32]; + let mut w2 = [0.0f32; 32]; + w2[0] = 2.0; + let b2 = -1.0f32; + let input = [3.0f32, 0.0]; + let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input); + let expected = sigmoid_f32(5.0); + assert!( + (result - expected).abs() < 1e-6, + "expected {expected}, got {result}" + ); + } + + #[test] + fn test_mlp_relu_clips_negative() { + // w1[0] = [1, 0], b1[0] = -10 → hidden[0] = ReLU(-10 + input) clips to 0 + // Everything zero → sigmoid(b2) = sigmoid(0) = 0.5 + let mut w1 = [[0.0f32; 2]; 32]; + w1[0] = [1.0, 0.0]; + let mut b1 = [0.0f32; 32]; + b1[0] = -10.0; + let w2 = [1.0f32; 32]; // doesn't matter, hidden is all 0 + let b2 = 0.0f32; + let input = [3.0f32, 0.0]; // hidden[0] = ReLU(3-10) = 0 + let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input); + assert!((result - 0.5).abs() < 1e-6); + } +} + +#[cfg(test)] +mod proptests { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn output_bounded_0_1( + x0 in -10.0f32..10.0, + x1 in -10.0f32..10.0, + b2 in -5.0f32..5.0, + ) { + let w1 = [[0.1f32, -0.2]; 32]; + let b1 = [0.0f32; 32]; + let w2 = [0.05f32; 32]; + let input = [x0, x1]; + let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input); + prop_assert!(result >= 0.0 && result <= 1.0, + "output {result} outside [0,1]"); + } + } +} diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs new file mode 100644 index 0000000..5c02b8b --- /dev/null +++ b/src/ensemble/mod.rs @@ -0,0 +1,6 @@ +pub mod ddos; +pub mod gen; +pub mod mlp; +pub mod replay; +pub mod scanner; +pub mod tree; diff --git a/src/ensemble/scanner.rs b/src/ensemble/scanner.rs new file mode 100644 index 0000000..6ad7f1d --- /dev/null +++ b/src/ensemble/scanner.rs @@ -0,0 +1,170 @@ +use crate::scanner::model::{ScannerAction, ScannerVerdict}; +use super::gen::scanner_weights; +use super::mlp::mlp_predict_32; +use super::tree::{tree_predict, TreeDecision}; + +/// Which path the ensemble took to reach its verdict. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EnsemblePath { + TreeBlock, + TreeAllow, + Mlp, +} + +/// Result of the scanner ensemble: action + confidence score + explanation. +pub struct EnsembleVerdict { + pub action: ScannerAction, + pub score: f64, + pub reason: &'static str, + pub path: EnsemblePath, +} + +/// Normalize raw features using trained min/max constants. +#[inline] +fn normalize(raw: &[f32; 12]) -> [f32; 12] { + let mut out = [0.0f32; 12]; + for i in 0..12 { + let range = scanner_weights::NORM_MAXS[i] - scanner_weights::NORM_MINS[i]; + out[i] = if range > 0.0 { + ((raw[i] - scanner_weights::NORM_MINS[i]) / range).clamp(0.0, 1.0) + } else { + 0.0 + }; + } + out +} + +/// Full ensemble inference: decision tree first, MLP only on `Defer`. +/// +/// Returns an [`EnsembleVerdict`] that can be converted into a +/// [`ScannerVerdict`] for the rest of the pipeline. +pub fn scanner_ensemble_predict(raw: &[f32; 12]) -> EnsembleVerdict { + let input = normalize(raw); + + let tree_result = tree_predict(&scanner_weights::TREE_NODES, &input); + match tree_result { + TreeDecision::Block => EnsembleVerdict { + action: ScannerAction::Block, + score: 1.0, + reason: "ensemble:tree_block", + path: EnsemblePath::TreeBlock, + }, + TreeDecision::Allow => EnsembleVerdict { + action: ScannerAction::Allow, + score: 0.0, + reason: "ensemble:tree_allow", + path: EnsemblePath::TreeAllow, + }, + TreeDecision::Defer => { + let mlp_score = mlp_predict_32::<12>( + &scanner_weights::W1, + &scanner_weights::B1, + &scanner_weights::W2, + scanner_weights::B2, + &input, + ); + let action = if mlp_score > scanner_weights::THRESHOLD { + ScannerAction::Block + } else { + ScannerAction::Allow + }; + EnsembleVerdict { + action, + score: mlp_score as f64, + reason: "ensemble:mlp", + path: EnsemblePath::Mlp, + } + } + } +} + +impl From for ScannerVerdict { + fn from(v: EnsembleVerdict) -> Self { + ScannerVerdict { + action: v.action, + score: v.score, + reason: v.reason, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tree_allow_path() { + // All features at zero → feature 3 (suspicious_ua) = 0.0 <= 0.65 → left (node 1) + // feature 0 (path_depth) = 0.0 <= 0.40 → left (node 3) → Allow leaf + let raw = [0.0f32; 12]; + let v = scanner_ensemble_predict(&raw); + assert_eq!(v.action, ScannerAction::Allow); + assert_eq!(v.path, EnsemblePath::TreeAllow); + assert_eq!(v.reason, "ensemble:tree_allow"); + assert!((v.score - 0.0).abs() < f64::EPSILON); + } + + #[test] + fn test_tree_block_path() { + // Need: feature 3 (suspicious_ua) > 0.65 (normalized) → right (node 2) + // feature 7 (payload_entropy) > 0.72 (normalized) → right (node 6) → Block + // feature 3 max = 1.0, so raw 0.8 → normalized 0.8 > 0.65 ✓ + // feature 7 max = 8.0, so raw 6.0 → normalized 0.75 > 0.72 ✓ + let mut raw = [0.0f32; 12]; + raw[3] = 0.8; // suspicious_ua: normalized = 0.8/1.0 = 0.8 > 0.65 + raw[7] = 6.0; // payload_entropy: normalized = 6.0/8.0 = 0.75 > 0.72 + let v = scanner_ensemble_predict(&raw); + assert_eq!(v.action, ScannerAction::Block); + assert_eq!(v.path, EnsemblePath::TreeBlock); + assert_eq!(v.reason, "ensemble:tree_block"); + } + + #[test] + fn test_mlp_path() { + // Need: feature 3 > 0.65 normalized → right (node 2) + // feature 7 <= 0.72 normalized → left (node 5) → Defer + // Then MLP runs on the normalized input. + let mut raw = [0.0f32; 12]; + raw[3] = 0.8; // normalized = 0.8 > 0.65 + raw[7] = 4.0; // normalized = 4.0/8.0 = 0.5 <= 0.72 + // Also need feature 2 (query_param_count) to navigate node 5 correctly + // node 5: split on feature 2, threshold 0.55 → left=9(Defer), right=10 + // normalized feature 2 = 0.0/20.0 = 0.0 <= 0.55 → left (node 9) → Defer + let v = scanner_ensemble_predict(&raw); + assert_eq!(v.path, EnsemblePath::Mlp); + assert_eq!(v.reason, "ensemble:mlp"); + // MLP output is deterministic for these inputs + assert!(v.score >= 0.0 && v.score <= 1.0); + } + + #[test] + fn test_normalize_clamps() { + // Values beyond max should be clamped to 1.0 + let mut raw = [0.0f32; 12]; + raw[0] = 100.0; // max is 10.0 + let normed = normalize(&raw); + assert!((normed[0] - 1.0).abs() < f64::EPSILON as f32); + } + + #[test] + fn test_normalize_negative_clamps() { + let mut raw = [0.0f32; 12]; + raw[0] = -5.0; // min is 0.0 + let normed = normalize(&raw); + assert!((normed[0] - 0.0).abs() < f64::EPSILON as f32); + } + + #[test] + fn test_verdict_into_scanner_verdict() { + let v = EnsembleVerdict { + action: ScannerAction::Block, + score: 0.85, + reason: "ensemble:mlp", + path: EnsemblePath::Mlp, + }; + let sv: ScannerVerdict = v.into(); + assert_eq!(sv.action, ScannerAction::Block); + assert!((sv.score - 0.85).abs() < f64::EPSILON); + assert_eq!(sv.reason, "ensemble:mlp"); + } +} diff --git a/src/ensemble/tree.rs b/src/ensemble/tree.rs new file mode 100644 index 0000000..ff343cd --- /dev/null +++ b/src/ensemble/tree.rs @@ -0,0 +1,156 @@ +/// Decision from a tree leaf node. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TreeDecision { + /// High-confidence block — overrides MLP. + Block, + /// High-confidence allow — overrides MLP. + Allow, + /// Low-confidence — defer to MLP for scoring. + Defer, +} + +/// Packed tree node: `(feature_index, threshold, left_child, right_child)`. +/// +/// Leaf nodes are encoded with `feature_index = 255`. +/// For leaves the threshold encodes the decision: +/// - `< 0.25` → Allow +/// - `> 0.75` → Block +/// - otherwise → Defer +pub type PackedNode = (u8, f32, u16, u16); + +/// Walk a packed decision tree. O(depth), zero allocation. +#[inline(always)] +pub fn tree_predict(nodes: &[PackedNode], input: &[f32]) -> TreeDecision { + let mut idx = 0usize; + loop { + let (feature, threshold, left, right) = nodes[idx]; + if feature == 255 { + return if threshold < 0.25 { + TreeDecision::Allow + } else if threshold > 0.75 { + TreeDecision::Block + } else { + TreeDecision::Defer + }; + } + idx = if input[feature as usize] <= threshold { + left as usize + } else { + right as usize + }; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Simple tree: + // node 0: feature 0, threshold 0.5 → left=1, right=2 + // node 1: leaf Allow (threshold 0.0) + // node 2: leaf Block (threshold 1.0) + const SIMPLE_TREE: [PackedNode; 3] = [ + (0, 0.5, 1, 2), + (255, 0.0, 0, 0), // Allow + (255, 1.0, 0, 0), // Block + ]; + + #[test] + fn test_tree_allow() { + let input = [0.3]; // <= 0.5 → left → Allow + assert_eq!(tree_predict(&SIMPLE_TREE, &input), TreeDecision::Allow); + } + + #[test] + fn test_tree_block() { + let input = [0.8]; // > 0.5 → right → Block + assert_eq!(tree_predict(&SIMPLE_TREE, &input), TreeDecision::Block); + } + + #[test] + fn test_tree_defer() { + // Tree with a Defer leaf + let tree: [PackedNode; 2] = [ + (0, 0.5, 1, 1), + (255, 0.5, 0, 0), // Defer + ]; + let input = [0.3]; + assert_eq!(tree_predict(&tree, &input), TreeDecision::Defer); + } + + #[test] + fn test_tree_boundary_allow() { + // threshold exactly 0.25 is NOT < 0.25, so it should Defer + let tree: [PackedNode; 1] = [(255, 0.25, 0, 0)]; + assert_eq!(tree_predict(&tree, &[]), TreeDecision::Defer); + } + + #[test] + fn test_tree_boundary_block() { + // threshold exactly 0.75 is NOT > 0.75, so it should Defer + let tree: [PackedNode; 1] = [(255, 0.75, 0, 0)]; + assert_eq!(tree_predict(&tree, &[]), TreeDecision::Defer); + } + + #[test] + fn test_deeper_tree() { + // Depth-3 tree with 4 features + let tree: [PackedNode; 7] = [ + (0, 0.5, 1, 2), // root: feature 0 + (1, 0.3, 3, 4), // left: feature 1 + (2, 0.7, 5, 6), // right: feature 2 + (255, 0.0, 0, 0), // Allow + (255, 0.5, 0, 0), // Defer + (255, 1.0, 0, 0), // Block + (255, 0.0, 0, 0), // Allow + ]; + // feature 0=0.2 (<=0.5→left=1), feature 1=0.1 (<=0.3→left=3) → Allow + assert_eq!(tree_predict(&tree, &[0.2, 0.1, 0.0, 0.0]), TreeDecision::Allow); + // feature 0=0.2 (<=0.5→left=1), feature 1=0.5 (>0.3→right=4) → Defer + assert_eq!(tree_predict(&tree, &[0.2, 0.5, 0.0, 0.0]), TreeDecision::Defer); + // feature 0=0.8 (>0.5→right=2), feature 2=0.3 (<=0.7→left=5) → Block + assert_eq!(tree_predict(&tree, &[0.8, 0.0, 0.3, 0.0]), TreeDecision::Block); + // feature 0=0.8 (>0.5→right=2), feature 2=0.9 (>0.7→right=6) → Allow + assert_eq!(tree_predict(&tree, &[0.8, 0.0, 0.9, 0.0]), TreeDecision::Allow); + } +} + +#[cfg(test)] +mod proptests { + use super::*; + use proptest::prelude::*; + + /// Generate a valid packed tree (complete binary tree of given depth). + /// All internal nodes split on feature 0 at threshold 0.5. + /// Leaves cycle through Allow/Block/Defer. + fn make_complete_tree(depth: u32) -> Vec { + let num_nodes = (1u32 << (depth + 1)) - 1; + let num_internal = (1u32 << depth) - 1; + let mut nodes = Vec::with_capacity(num_nodes as usize); + let decisions = [0.0f32, 1.0, 0.5]; // Allow, Block, Defer + for i in 0..num_nodes { + if i < num_internal { + let left = 2 * i + 1; + let right = 2 * i + 2; + nodes.push((0u8, 0.5f32, left as u16, right as u16)); + } else { + let leaf_idx = (i - num_internal) as usize; + nodes.push((255u8, decisions[leaf_idx % 3], 0u16, 0u16)); + } + } + nodes + } + + proptest! { + #[test] + fn tree_always_terminates(val in 0.0f32..1.0, depth in 0u32..5) { + let tree = make_complete_tree(depth); + let input = [val; 16]; // enough features for any tree + let result = tree_predict(&tree, &input); + prop_assert!(matches!( + result, + TreeDecision::Allow | TreeDecision::Block | TreeDecision::Defer + )); + } + } +}