feat(ensemble): add decision tree + MLP inference engine
Compile-time ensemble for both scanner (12 features) and DDoS (14 features) detection. Decision tree runs first; uncertain leaves defer to a single-hidden- layer MLP (32 units, ReLU, sigmoid output). Weights are embedded as const arrays for zero-overhead inference with no file I/O at runtime. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
147
src/ensemble/ddos.rs
Normal file
147
src/ensemble/ddos.rs
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
use crate::ddos::model::DDoSAction;
|
||||||
|
use super::gen::ddos_weights;
|
||||||
|
use super::mlp::mlp_predict_32;
|
||||||
|
use super::tree::{tree_predict, TreeDecision};
|
||||||
|
|
||||||
|
/// Which path the DDoS ensemble took to reach its verdict.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum DDoSEnsemblePath {
|
||||||
|
TreeBlock,
|
||||||
|
TreeAllow,
|
||||||
|
Mlp,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of the DDoS ensemble inference.
|
||||||
|
pub struct DDoSEnsembleVerdict {
|
||||||
|
pub action: DDoSAction,
|
||||||
|
pub score: f64,
|
||||||
|
pub reason: &'static str,
|
||||||
|
pub path: DDoSEnsemblePath,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalize raw features using trained min/max constants.
|
||||||
|
#[inline]
|
||||||
|
fn normalize(raw: &[f32; 14]) -> [f32; 14] {
|
||||||
|
let mut out = [0.0f32; 14];
|
||||||
|
for i in 0..14 {
|
||||||
|
let range = ddos_weights::NORM_MAXS[i] - ddos_weights::NORM_MINS[i];
|
||||||
|
out[i] = if range > 0.0 {
|
||||||
|
((raw[i] - ddos_weights::NORM_MINS[i]) / range).clamp(0.0, 1.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Full DDoS ensemble inference: decision tree first, MLP only on `Defer`.
|
||||||
|
pub fn ddos_ensemble_predict(raw: &[f32; 14]) -> DDoSEnsembleVerdict {
|
||||||
|
let input = normalize(raw);
|
||||||
|
|
||||||
|
let tree_result = tree_predict(&ddos_weights::TREE_NODES, &input);
|
||||||
|
match tree_result {
|
||||||
|
TreeDecision::Block => DDoSEnsembleVerdict {
|
||||||
|
action: DDoSAction::Block,
|
||||||
|
score: 1.0,
|
||||||
|
reason: "ensemble:tree_block",
|
||||||
|
path: DDoSEnsemblePath::TreeBlock,
|
||||||
|
},
|
||||||
|
TreeDecision::Allow => DDoSEnsembleVerdict {
|
||||||
|
action: DDoSAction::Allow,
|
||||||
|
score: 0.0,
|
||||||
|
reason: "ensemble:tree_allow",
|
||||||
|
path: DDoSEnsemblePath::TreeAllow,
|
||||||
|
},
|
||||||
|
TreeDecision::Defer => {
|
||||||
|
let mlp_score = mlp_predict_32::<14>(
|
||||||
|
&ddos_weights::W1,
|
||||||
|
&ddos_weights::B1,
|
||||||
|
&ddos_weights::W2,
|
||||||
|
ddos_weights::B2,
|
||||||
|
&input,
|
||||||
|
);
|
||||||
|
let action = if mlp_score > ddos_weights::THRESHOLD {
|
||||||
|
DDoSAction::Block
|
||||||
|
} else {
|
||||||
|
DDoSAction::Allow
|
||||||
|
};
|
||||||
|
DDoSEnsembleVerdict {
|
||||||
|
action,
|
||||||
|
score: mlp_score as f64,
|
||||||
|
reason: "ensemble:mlp",
|
||||||
|
path: DDoSEnsemblePath::Mlp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_allow_path() {
|
||||||
|
// All zeros → feature 4 (request_rate) = 0.0 <= 0.70 → left (node 1)
|
||||||
|
// feature 10 (cookie_ratio) = 0.0 <= 0.30 → left (node 3) → Allow
|
||||||
|
let raw = [0.0f32; 14];
|
||||||
|
let v = ddos_ensemble_predict(&raw);
|
||||||
|
assert_eq!(v.action, DDoSAction::Allow);
|
||||||
|
assert_eq!(v.path, DDoSEnsemblePath::TreeAllow);
|
||||||
|
assert_eq!(v.reason, "ensemble:tree_allow");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_block_path() {
|
||||||
|
// Need: feature 4 (request_rate) > 0.70 normalized → right (node 2)
|
||||||
|
// feature 12 (accept_language_ratio) > 0.25 normalized → right (node 6) → Block
|
||||||
|
// feature 4 max = 500, so raw 400 → normalized 0.8 > 0.70 ✓
|
||||||
|
// feature 12 max = 1.0, so raw 0.5 → normalized 0.5 > 0.25 ✓
|
||||||
|
let mut raw = [0.0f32; 14];
|
||||||
|
raw[4] = 400.0;
|
||||||
|
raw[12] = 0.5;
|
||||||
|
let v = ddos_ensemble_predict(&raw);
|
||||||
|
assert_eq!(v.action, DDoSAction::Block);
|
||||||
|
assert_eq!(v.path, DDoSEnsemblePath::TreeBlock);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mlp_path() {
|
||||||
|
// Need: feature 4 > 0.70 normalized → right (node 2)
|
||||||
|
// feature 12 <= 0.25 normalized → left (node 5) → Defer
|
||||||
|
// feature 4 max = 500, raw 400 → 0.8 > 0.70 ✓
|
||||||
|
// feature 12 max = 1.0, raw 0.1 → 0.1 <= 0.25 ✓
|
||||||
|
let mut raw = [0.0f32; 14];
|
||||||
|
raw[4] = 400.0;
|
||||||
|
raw[12] = 0.1;
|
||||||
|
let v = ddos_ensemble_predict(&raw);
|
||||||
|
assert_eq!(v.path, DDoSEnsemblePath::Mlp);
|
||||||
|
assert_eq!(v.reason, "ensemble:mlp");
|
||||||
|
assert!(v.score >= 0.0 && v.score <= 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_defer_then_mlp_allow() {
|
||||||
|
// Same Defer path as above — verify the MLP produces a valid action
|
||||||
|
let mut raw = [0.0f32; 14];
|
||||||
|
raw[4] = 400.0;
|
||||||
|
raw[12] = 0.1;
|
||||||
|
let v = ddos_ensemble_predict(&raw);
|
||||||
|
assert!(matches!(v.action, DDoSAction::Allow | DDoSAction::Block));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_clamps_high() {
|
||||||
|
let mut raw = [0.0f32; 14];
|
||||||
|
raw[0] = 999.0; // max is 100
|
||||||
|
let normed = normalize(&raw);
|
||||||
|
assert!((normed[0] - 1.0).abs() < f32::EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_clamps_low() {
|
||||||
|
let mut raw = [0.0f32; 14];
|
||||||
|
raw[1] = -500.0; // min is 0
|
||||||
|
let normed = normalize(&raw);
|
||||||
|
assert!((normed[1] - 0.0).abs() < f32::EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
71
src/ensemble/gen/ddos_weights.rs
Normal file
71
src/ensemble/gen/ddos_weights.rs
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
//! Auto-generated weights for the ddos ensemble.
|
||||||
|
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-ddos-mlp`.
|
||||||
|
|
||||||
|
pub const THRESHOLD: f32 = 0.50000000;
|
||||||
|
|
||||||
|
pub const NORM_MINS: [f32; 14] = [
|
||||||
|
0.08778746, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.05001374, 0.02000000,
|
||||||
|
0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const NORM_MAXS: [f32; 14] = [
|
||||||
|
1000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49990082, 500.00000000, 1.00000000,
|
||||||
|
171240.28125000, 30.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const W1: [[f32; 14]; 32] = [
|
||||||
|
[0.57458097, -0.10861993, 0.14037465, -0.23486336, -0.43255216, 0.16347405, 0.71766937, 0.83138502, 0.02852129, 0.56590265, 0.54848498, 0.38098580, 0.82907754, 0.61539698],
|
||||||
|
[0.06583579, 0.02305713, 0.89706898, 0.42619053, -1.20866120, -0.11974730, 1.70674825, -0.17969023, -0.26867196, 0.60768014, -0.08671998, -0.04825107, 0.58131427, -0.02062579],
|
||||||
|
[-0.40264100, 0.18836430, 0.08431315, 0.33763552, 0.44880620, -0.40894085, -0.22044741, 0.00533387, -0.61574107, 0.07670992, 0.63528854, 0.48244709, 0.20411402, -1.80697525],
|
||||||
|
[0.66713083, -0.22220801, 2.11234117, 0.41516641, -0.00165093, 0.65624571, 1.87509167, 0.63406783, -2.54182458, -0.53618753, 2.16407824, -0.61959583, -0.04717547, 0.17551991],
|
||||||
|
[-0.51027024, -0.60132700, 0.46407551, -0.57346475, -0.30902353, -0.24235034, 0.08087540, -2.14762974, -0.29429656, 0.56257033, -0.26935315, -0.16799171, 0.56852734, 1.93494022],
|
||||||
|
[-0.24938971, -0.12699288, -0.13746630, 0.64942318, 0.09490766, -0.02158179, 0.72449303, -0.28493983, -0.43053114, -0.01443988, 0.89670080, -0.34539866, -1.47019410, 0.79477930],
|
||||||
|
[0.62935185, 0.74686801, -0.15527052, -0.06635039, 0.73137009, 0.78417069, -0.06417987, 0.72259408, 0.85131824, 0.00477386, -0.14302900, 0.63481224, 0.92724019, -0.50126070],
|
||||||
|
[-0.12699059, -0.15016419, -0.48704135, 0.00581611, 0.75824696, 0.84114397, -0.08958503, 0.18609463, 0.56247348, 0.22239330, 0.43324804, 0.82077771, 0.55714250, -0.56955606],
|
||||||
|
[0.83457869, 0.40054807, 0.23281574, -0.58521581, 1.18067443, -0.49485078, 0.08600014, 0.99104887, -0.65019566, -0.44594154, 0.64507920, -0.61692268, -0.29301512, -0.11314666],
|
||||||
|
[-0.07868081, -0.18392175, 0.15165123, 0.35139060, 0.13855398, 0.16470867, 0.21025884, 1.57204449, 0.07827333, 0.05895505, 0.00810917, 1.05159700, 0.04605416, 0.38080546],
|
||||||
|
[1.47428405, -0.21614535, -0.35385504, 0.46582970, 1.26638246, 0.00133375, -3.85603786, 0.39766011, 1.92816520, 0.47828305, -0.16951409, -0.13771342, 0.49451983, 0.41184473],
|
||||||
|
[-0.23364748, 0.68134952, 0.36342716, 0.02657196, 0.07550839, 0.94861823, -0.52908695, 0.83652318, -0.05639480, 0.26536962, 0.44137934, 1.20957208, -0.60747981, -0.50647283],
|
||||||
|
[-0.16961956, -0.49570882, -0.33771378, -0.28554109, 0.95865113, -0.49269623, -0.44559151, 1.28568971, 0.79537493, -0.53175420, -3.19015551, 0.52214253, 0.86517984, 0.62523192],
|
||||||
|
[-0.16956513, -0.61727583, 0.63967121, 0.96406335, -0.28760204, 0.56459671, 0.78585202, -0.03668134, -0.14773002, -0.35764447, 0.84649116, -0.34540027, -0.12314465, -0.10070048],
|
||||||
|
[-0.34183556, -0.07760386, 0.70894319, 0.92814171, -0.19357866, 0.41449037, 0.54653358, 0.27682835, 0.81471086, 0.56383932, 0.57456553, -0.61491662, 0.92498505, 0.74495614],
|
||||||
|
[-0.38917324, -0.29217750, 1.43508542, -0.19152534, -0.18823336, 0.45097819, -0.38063127, -0.40419811, 0.56686693, -0.33231607, -0.19567636, -0.02500075, -0.04762971, 0.44703853],
|
||||||
|
[1.14234805, -0.62868208, -0.21298689, 0.00263968, -0.66115338, -1.12038326, 0.93599045, 0.77646011, -0.22770278, 1.43982041, 0.96078646, 1.15076077, -0.45110813, 0.83090556],
|
||||||
|
[0.89638984, -0.69683450, -0.29400119, 0.94997799, 0.90305328, -0.80215877, -0.09983492, -0.90757453, -0.03181892, 1.00702441, -0.97962254, -0.89580274, 0.69299418, -0.75975400],
|
||||||
|
[-0.75832003, -0.07210776, 0.07825917, 1.51633596, 0.44593197, 0.00936707, -0.12142835, -0.09877282, 0.06229200, 1.25678349, 0.25317946, 0.54112315, -0.17941843, 0.93283361],
|
||||||
|
[0.23085761, 0.53307736, 0.38696140, 0.36798462, 0.38192499, 0.23203450, 0.68225187, 0.47096270, -4.24785280, 0.18062039, 0.60047084, 0.16251479, -0.10811257, 0.48166662],
|
||||||
|
[0.10870802, 0.01576116, 0.00298645, 0.25878090, -0.16634797, 0.15850464, -0.24267951, 0.87678236, -0.27257833, 0.78637868, -0.00851476, 0.01502728, 0.92175138, -0.81292266],
|
||||||
|
[-0.74364990, -0.63139439, -0.18314177, -0.36881343, -0.53096825, -0.92442876, -0.05536628, -0.71273297, -0.94937468, -0.03863344, -0.09668982, -1.07886386, 0.58555382, 0.23351164],
|
||||||
|
[-0.09152136, 0.96538877, -0.11560653, -0.53110164, 0.89070886, 0.05664408, -0.71661353, 0.79684436, -0.00206013, 0.23857179, 0.06074178, -0.67188424, -0.15624331, 0.43436247],
|
||||||
|
[-0.28189376, -0.00535834, 0.60541785, 0.82968009, -0.21901314, -0.29874969, -0.16872653, 0.45570841, -0.25372767, -0.12359514, -1.10104620, 0.00162374, 0.07622175, 0.60413152],
|
||||||
|
[-1.13819373, -0.41320390, -5.57348347, 0.40931624, -1.59562767, 0.72510892, 0.03248254, 0.00407641, 0.57557869, 0.53510398, -0.35943517, 0.52707136, 0.61220711, -0.11644226],
|
||||||
|
[-0.02057049, 0.42545527, 0.24192038, 0.29863021, -0.22839858, -0.25318733, 0.17906551, -0.29471490, -0.04746799, 0.15909556, -0.26826856, -0.06874973, -0.03044286, 0.11770450],
|
||||||
|
[-0.18060833, -0.06301155, 0.01656315, -0.40476608, -0.35056075, 0.06344713, 0.32273614, -0.04382812, -0.18925793, 0.02124963, -0.23447622, 0.29704437, 0.19138981, -0.04584064],
|
||||||
|
[0.18248987, 0.05461208, -0.25655189, 0.16673982, 0.03251073, 0.05709980, 0.09135589, 0.06712578, -0.02372392, 0.00487196, -0.11774579, 0.34203079, 0.18477952, 0.09847298],
|
||||||
|
[-0.08292723, -0.03089223, 0.19555064, -0.18158682, -0.32060555, 0.18836822, -0.14625609, -0.83500093, -0.09893667, 0.02719803, 0.06864946, 0.00156752, 0.04342323, 0.30958080],
|
||||||
|
[-0.21274266, 0.06035644, 0.27282441, -0.01010289, -0.05599894, 0.27938741, -0.23254848, -0.20086342, -0.06775926, -0.18059292, 0.92534143, 0.09500337, 0.11612320, -0.06473339],
|
||||||
|
[-0.27279299, 0.96252358, 0.67542273, 0.64720130, 0.15221471, 1.67354584, 0.53074431, 0.65513390, 0.79840666, 0.78613347, 0.34742561, -1.83272552, 0.73313516, 0.09797212],
|
||||||
|
[-0.08888317, 0.14851266, 1.00953877, 0.19915256, -0.10076691, 0.47210938, 0.04427565, 0.19299655, 0.58729172, 0.17481442, -0.57466495, -0.16196120, 0.06293163, 1.73905540],
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const B1: [f32; 32] = [
|
||||||
|
-0.80723554, 0.54879200, 0.01237706, -0.22279924, 0.93692911, 0.12226531, -0.54665250, -0.49958101,
|
||||||
|
-0.20918398, -0.48646352, -0.58741039, -0.50572610, -0.04772990, -0.62962151, -0.46279392, 1.14840722,
|
||||||
|
-0.04871057, -0.31787100, 1.13966286, 0.69543558, -0.17798270, 0.66968435, -0.07442535, -0.70557600,
|
||||||
|
0.79021728, 0.65736526, -0.30761406, 0.63242179, 0.83297908, -0.04573143, -0.18454255, -0.30583009
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const W2: [f32; 32] = [
|
||||||
|
1.09615684, -0.57856798, -0.08730038, -0.06425755, -0.96232760, -2.06290460, 0.70097560, 0.85189444,
|
||||||
|
-0.10077959, 1.94375157, 0.74497795, 0.88425481, 2.11908054, 0.85526127, 0.61624259, -2.93621016,
|
||||||
|
1.52211487, 0.56318259, -3.15219641, -0.55187315, 1.61819077, -0.76258671, -0.09362544, 0.86861998,
|
||||||
|
-0.79028755, -0.90605170, 0.33475992, -0.79945564, -1.16680586, 0.15120529, 0.17619221, 1.61664009
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const B2: f32 = -0.52729088;
|
||||||
|
|
||||||
|
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
|
||||||
|
(3, 0.30015790, 1, 2),
|
||||||
|
(255, 0.00000000, 0, 0),
|
||||||
|
(255, 1.00000000, 0, 0),
|
||||||
|
];
|
||||||
2
src/ensemble/gen/mod.rs
Normal file
2
src/ensemble/gen/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
pub mod ddos_weights;
|
||||||
|
pub mod scanner_weights;
|
||||||
71
src/ensemble/gen/scanner_weights.rs
Normal file
71
src/ensemble/gen/scanner_weights.rs
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
//! Auto-generated weights for the scanner ensemble.
|
||||||
|
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-scanner-mlp`.
|
||||||
|
|
||||||
|
pub const THRESHOLD: f32 = 0.50000000;
|
||||||
|
|
||||||
|
pub const NORM_MINS: [f32; 12] = [
|
||||||
|
0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
|
||||||
|
0.00000000, 0.00000000, 0.00000000, 0.00000000
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const NORM_MAXS: [f32; 12] = [
|
||||||
|
1.00000000, 10.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000,
|
||||||
|
1.00000000, 1.00000000, 1.00000000, 1.00000000
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const W1: [[f32; 12]; 32] = [
|
||||||
|
[2.25848985, 1.62502551, 1.05068624, -0.23875977, 1.29692984, -1.34665418, -1.29937541, 1.66119707, -1.43897200, 0.07720046, -1.17116165, 1.96821272],
|
||||||
|
[2.30885172, 0.02477695, -0.23236598, 1.66507626, -1.41407740, 1.88616431, 1.84703696, -1.46395433, 2.03542018, 1.68318951, 2.01550031, 1.94223917],
|
||||||
|
[2.29420924, 1.86615539, 1.69271469, 1.42137837, 1.43151915, 1.84876072, 1.09228194, 1.73608077, 0.20805965, 0.52542430, -0.02558800, 0.04718366],
|
||||||
|
[0.36484259, -0.02785611, -0.01155548, 0.08577330, -0.00468449, -0.07848717, 0.05191587, 0.50796396, 0.40799347, -0.14838840, -0.30566201, 0.00758083],
|
||||||
|
[0.28191370, 0.20945202, 0.07742970, -0.06654347, 0.17395714, 0.00011351, 0.37079588, 0.41817516, 0.56992871, 0.05705916, 0.22339216, 0.11021475],
|
||||||
|
[0.06522971, 0.64510870, 0.31671444, 0.34980071, 0.03446164, -0.10592904, -0.21302676, -0.04404496, 0.08638768, 0.04217484, 0.43021953, 0.21055792],
|
||||||
|
[0.31206250, -0.14565454, 0.38078794, 0.00860748, 0.29409558, -0.11273954, -0.02210701, 0.15525217, 0.09696059, 0.13877581, 0.06483351, 0.10946950],
|
||||||
|
[0.28374705, -0.02963164, 0.27863786, -0.23428085, 0.12715313, 0.09141072, 0.07769041, 0.01915955, -0.20936646, 0.02813511, -0.03910714, 0.30322370],
|
||||||
|
[-1.19449413, -0.84935474, -0.32267663, -0.08140022, -0.78729230, 1.58759272, 0.88281459, -0.77263606, 1.55394125, 0.10148179, 1.59524822, -0.75499195],
|
||||||
|
[-0.97152823, -0.12173092, 0.04745778, -0.85466659, 1.57352293, -0.52651149, -0.66270715, 1.32282484, -1.24654925, -0.45822921, -1.10187364, -0.91162699],
|
||||||
|
[-0.93944395, -0.57891464, -1.12100291, -0.38871467, -0.18780440, -1.11835766, -0.43614236, -1.07918274, -0.09222561, -0.23854440, -0.16720718, 0.03247443],
|
||||||
|
[0.13319625, 0.87437463, 0.32213065, 0.13902900, 0.64760798, 0.00899744, 0.45325586, -0.14138180, 0.13888212, 0.07780524, -0.12482210, 0.12632932],
|
||||||
|
[0.57018995, -0.10839911, 0.02787536, 0.16884641, 0.19435850, -0.01189608, 0.13881874, -0.10700739, -0.05463003, 0.01371983, 0.04385772, 0.01100468],
|
||||||
|
[-0.26600277, -0.11843663, -0.01081531, 0.10785927, -0.18684258, 0.08537511, 0.01054722, -0.01972559, -0.07416820, 0.57192892, 0.37873995, -0.00498434],
|
||||||
|
[0.72535324, -0.25030360, 0.51470703, -0.16410951, -0.13649474, 0.16246459, -0.27847841, 0.12250750, 0.45576489, -0.18535912, -0.45686084, 0.58293521],
|
||||||
|
[0.18614589, -0.32835677, -0.08683094, 0.07748202, -0.24785264, -0.16834147, 0.27066526, 0.06058804, 0.01903199, -0.17387865, 0.12752151, -0.03780220],
|
||||||
|
[-1.22358644, -0.78316134, -0.54068804, -0.07921790, -0.72697675, 1.80127227, 0.14326867, -0.51875746, 1.83125353, -0.02672976, 1.68589675, -0.80162954],
|
||||||
|
[-0.83690810, -0.12682360, 0.10783038, -0.64648604, 1.50810242, -0.48788729, -0.59418935, 0.94863659, -0.84788662, -0.49779284, -0.96408021, -1.14068258],
|
||||||
|
[-0.96322638, -0.50503486, -0.87195945, -0.34710455, -0.28645220, -1.10507452, -0.32122782, -0.80753750, -0.00843489, 0.04215550, 0.03197355, 0.05468401],
|
||||||
|
[-0.17587705, 0.45144933, 0.37954769, -0.15405300, 0.75590396, 0.00346784, 0.62332457, -0.15602241, -0.26471916, -0.19963606, -0.22497311, -0.20784236],
|
||||||
|
[0.60608941, -0.05316854, 0.03766245, 0.46412235, -0.41121334, -0.01225545, -0.11125158, -0.33533856, -0.04625564, -0.02995013, -0.24979964, -0.35824969],
|
||||||
|
[0.08163761, 0.04702193, -0.24007457, -0.23439978, 0.27066308, 0.48389259, 0.32692793, -0.23089454, 0.26520243, -0.14099684, 0.06713670, 0.14434725],
|
||||||
|
[-0.50808382, -0.14518137, -0.23912378, 0.33510539, 0.46566108, 0.09035082, -0.12637842, 0.55245715, -0.19972627, 0.24517706, 0.34291887, 0.01936621],
|
||||||
|
[0.35826349, 0.21200819, 0.65315312, -0.16792546, 0.41378024, 0.32129642, 0.50814188, 0.48289016, 0.06839173, 0.42079177, 0.52295685, 0.26273951],
|
||||||
|
[0.24575019, 0.10700949, 0.07041252, -0.09410189, 0.18897925, 0.31616825, -0.01306109, 0.33499330, -0.01866218, 0.06233863, 0.15316568, 0.08370106],
|
||||||
|
[0.17828286, 0.17363867, -0.10626584, 0.06075979, 0.39465010, 0.19557165, 0.30352867, 0.26720291, 0.40256795, 0.13942246, 0.05869288, 0.08310238],
|
||||||
|
[-0.04834138, 0.29206491, 0.01330532, 0.07626399, -0.17378819, 0.09515948, 0.02298534, 0.41555724, 0.09492048, 0.39422533, 0.39373979, 0.20463347],
|
||||||
|
[-0.11641891, -0.06529939, -0.18899654, -0.02157970, -0.03554495, 0.10956290, -0.11688691, 0.04077352, 0.34220406, -0.09558969, 0.16150762, 0.25759667],
|
||||||
|
[-0.17313123, 0.00591523, 0.29443163, 0.08298909, 0.07761172, 0.19023541, 0.23826212, -0.07167042, 0.08753359, 0.17917964, -0.03248737, 0.28516129],
|
||||||
|
[0.13091524, 0.21435370, 0.15093684, 0.30902347, 0.44151527, 0.55901742, 0.19933179, 0.06438518, 0.30585650, -0.34089112, 0.26879075, 0.12928906],
|
||||||
|
[-0.25311065, -0.09963353, -0.50099874, 0.57481062, 0.38744658, -0.13065037, 0.18897361, 0.49376330, -0.15626629, 0.19911517, 0.06437352, -0.09104283],
|
||||||
|
[0.35787049, -0.04814727, 0.45446551, -0.15264697, 0.36565515, 0.22795495, 0.24630190, 0.16362202, 0.21044184, 0.53882843, 0.42343852, 0.18454899],
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const B1: [f32; 32] = [
|
||||||
|
1.12135851, 0.64268047, 0.44761124, -0.28471574, 0.70866716, -0.25293177, -0.19119856, 0.39284116,
|
||||||
|
-0.20628852, -0.29301032, -0.08837436, 0.92048728, 0.91167349, -0.33615190, -0.06016272, 0.79141164,
|
||||||
|
-0.43257964, 0.48180589, 0.70891160, -0.24290052, 0.83115542, 0.69964927, 0.97887653, 1.34517038,
|
||||||
|
1.10292709, 0.42009205, 1.07155228, 0.61349720, 0.46157768, 1.01911950, 0.51159418, 0.60460496
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const W2: [f32; 32] = [
|
||||||
|
1.55191231, 1.27754235, 0.43588921, 0.10868450, 0.55931729, -1.46911597, -0.54461092, 0.78240824,
|
||||||
|
-1.25938582, -0.06287600, -1.02053738, 1.07076716, 1.58776867, -0.03168033, -0.11393511, 1.30535436,
|
||||||
|
-1.46621227, 0.62925971, 0.76781118, -0.74480098, 1.29669034, 0.62078375, 1.64134884, 2.09736991,
|
||||||
|
1.52834618, 0.87368065, 1.80090642, 0.89230227, 0.38757962, 1.80718291, 0.64923352, 1.18709576
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const B2: f32 = 0.23270580;
|
||||||
|
|
||||||
|
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
|
||||||
|
(3, 0.50000000, 1, 2),
|
||||||
|
(255, 1.00000000, 0, 0),
|
||||||
|
(255, 0.00000000, 0, 0),
|
||||||
|
];
|
||||||
132
src/ensemble/mlp.rs
Normal file
132
src/ensemble/mlp.rs
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
/// Two-layer MLP forward pass with a fixed hidden size of 32:
|
||||||
|
///
|
||||||
|
/// hidden = ReLU(W1 @ input + b1)
|
||||||
|
/// output = sigmoid(w2 · hidden + b2)
|
||||||
|
///
|
||||||
|
/// Zero allocation — the 32-element hidden layer lives on the stack.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn mlp_predict_32<const INPUT: usize>(
|
||||||
|
w1: &[[f32; INPUT]; 32],
|
||||||
|
b1: &[f32; 32],
|
||||||
|
w2: &[f32; 32],
|
||||||
|
b2: f32,
|
||||||
|
input: &[f32; INPUT],
|
||||||
|
) -> f32 {
|
||||||
|
let mut hidden = [0.0f32; 32];
|
||||||
|
|
||||||
|
// Hidden layer: h_j = ReLU(sum_i(w1[j][i] * input[i]) + b1[j])
|
||||||
|
for j in 0..32 {
|
||||||
|
let mut sum = b1[j];
|
||||||
|
for i in 0..INPUT {
|
||||||
|
sum += w1[j][i] * input[i];
|
||||||
|
}
|
||||||
|
hidden[j] = relu_f32(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output layer: sigmoid(sum_j(w2[j] * hidden[j]) + b2)
|
||||||
|
let mut out = b2;
|
||||||
|
for j in 0..32 {
|
||||||
|
out += w2[j] * hidden[j];
|
||||||
|
}
|
||||||
|
sigmoid_f32(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn sigmoid_f32(x: f32) -> f32 {
|
||||||
|
1.0 / (1.0 + (-x).exp())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn relu_f32(x: f32) -> f32 {
|
||||||
|
if x > 0.0 { x } else { 0.0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sigmoid_boundaries() {
|
||||||
|
assert!((sigmoid_f32(0.0) - 0.5).abs() < 1e-6);
|
||||||
|
assert!(sigmoid_f32(10.0) > 0.999);
|
||||||
|
assert!(sigmoid_f32(-10.0) < 0.001);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_relu() {
|
||||||
|
assert_eq!(relu_f32(0.0), 0.0);
|
||||||
|
assert_eq!(relu_f32(1.5), 1.5);
|
||||||
|
assert_eq!(relu_f32(-3.0), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mlp_zero_weights() {
|
||||||
|
// All weights zero, bias2 = 0 → sigmoid(0) = 0.5
|
||||||
|
let w1 = [[0.0f32; 2]; 32];
|
||||||
|
let b1 = [0.0f32; 32];
|
||||||
|
let w2 = [0.0f32; 32];
|
||||||
|
let b2 = 0.0f32;
|
||||||
|
let input = [1.0f32, 2.0];
|
||||||
|
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
|
||||||
|
assert!((result - 0.5).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mlp_known_output() {
|
||||||
|
// Single active hidden unit: w1[0] = [1, 0], b1[0] = 0, w2[0] = 2, b2 = -1
|
||||||
|
// input = [3, 0]
|
||||||
|
// hidden[0] = ReLU(1*3 + 0*0 + 0) = 3.0, rest = 0
|
||||||
|
// output = sigmoid(2*3 + (-1)) = sigmoid(5) ≈ 0.9933
|
||||||
|
let mut w1 = [[0.0f32; 2]; 32];
|
||||||
|
w1[0] = [1.0, 0.0];
|
||||||
|
let b1 = [0.0f32; 32];
|
||||||
|
let mut w2 = [0.0f32; 32];
|
||||||
|
w2[0] = 2.0;
|
||||||
|
let b2 = -1.0f32;
|
||||||
|
let input = [3.0f32, 0.0];
|
||||||
|
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
|
||||||
|
let expected = sigmoid_f32(5.0);
|
||||||
|
assert!(
|
||||||
|
(result - expected).abs() < 1e-6,
|
||||||
|
"expected {expected}, got {result}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mlp_relu_clips_negative() {
|
||||||
|
// w1[0] = [1, 0], b1[0] = -10 → hidden[0] = ReLU(-10 + input) clips to 0
|
||||||
|
// Everything zero → sigmoid(b2) = sigmoid(0) = 0.5
|
||||||
|
let mut w1 = [[0.0f32; 2]; 32];
|
||||||
|
w1[0] = [1.0, 0.0];
|
||||||
|
let mut b1 = [0.0f32; 32];
|
||||||
|
b1[0] = -10.0;
|
||||||
|
let w2 = [1.0f32; 32]; // doesn't matter, hidden is all 0
|
||||||
|
let b2 = 0.0f32;
|
||||||
|
let input = [3.0f32, 0.0]; // hidden[0] = ReLU(3-10) = 0
|
||||||
|
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
|
||||||
|
assert!((result - 0.5).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod proptests {
|
||||||
|
use super::*;
|
||||||
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn output_bounded_0_1(
|
||||||
|
x0 in -10.0f32..10.0,
|
||||||
|
x1 in -10.0f32..10.0,
|
||||||
|
b2 in -5.0f32..5.0,
|
||||||
|
) {
|
||||||
|
let w1 = [[0.1f32, -0.2]; 32];
|
||||||
|
let b1 = [0.0f32; 32];
|
||||||
|
let w2 = [0.05f32; 32];
|
||||||
|
let input = [x0, x1];
|
||||||
|
let result = mlp_predict_32::<2>(&w1, &b1, &w2, b2, &input);
|
||||||
|
prop_assert!(result >= 0.0 && result <= 1.0,
|
||||||
|
"output {result} outside [0,1]");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
6
src/ensemble/mod.rs
Normal file
6
src/ensemble/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
pub mod ddos;
|
||||||
|
pub mod gen;
|
||||||
|
pub mod mlp;
|
||||||
|
pub mod replay;
|
||||||
|
pub mod scanner;
|
||||||
|
pub mod tree;
|
||||||
170
src/ensemble/scanner.rs
Normal file
170
src/ensemble/scanner.rs
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
use crate::scanner::model::{ScannerAction, ScannerVerdict};
|
||||||
|
use super::gen::scanner_weights;
|
||||||
|
use super::mlp::mlp_predict_32;
|
||||||
|
use super::tree::{tree_predict, TreeDecision};
|
||||||
|
|
||||||
|
/// Which path the ensemble took to reach its verdict.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum EnsemblePath {
|
||||||
|
TreeBlock,
|
||||||
|
TreeAllow,
|
||||||
|
Mlp,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of the scanner ensemble: action + confidence score + explanation.
|
||||||
|
pub struct EnsembleVerdict {
|
||||||
|
pub action: ScannerAction,
|
||||||
|
pub score: f64,
|
||||||
|
pub reason: &'static str,
|
||||||
|
pub path: EnsemblePath,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalize raw features using trained min/max constants.
|
||||||
|
#[inline]
|
||||||
|
fn normalize(raw: &[f32; 12]) -> [f32; 12] {
|
||||||
|
let mut out = [0.0f32; 12];
|
||||||
|
for i in 0..12 {
|
||||||
|
let range = scanner_weights::NORM_MAXS[i] - scanner_weights::NORM_MINS[i];
|
||||||
|
out[i] = if range > 0.0 {
|
||||||
|
((raw[i] - scanner_weights::NORM_MINS[i]) / range).clamp(0.0, 1.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Full ensemble inference: decision tree first, MLP only on `Defer`.
|
||||||
|
///
|
||||||
|
/// Returns an [`EnsembleVerdict`] that can be converted into a
|
||||||
|
/// [`ScannerVerdict`] for the rest of the pipeline.
|
||||||
|
pub fn scanner_ensemble_predict(raw: &[f32; 12]) -> EnsembleVerdict {
|
||||||
|
let input = normalize(raw);
|
||||||
|
|
||||||
|
let tree_result = tree_predict(&scanner_weights::TREE_NODES, &input);
|
||||||
|
match tree_result {
|
||||||
|
TreeDecision::Block => EnsembleVerdict {
|
||||||
|
action: ScannerAction::Block,
|
||||||
|
score: 1.0,
|
||||||
|
reason: "ensemble:tree_block",
|
||||||
|
path: EnsemblePath::TreeBlock,
|
||||||
|
},
|
||||||
|
TreeDecision::Allow => EnsembleVerdict {
|
||||||
|
action: ScannerAction::Allow,
|
||||||
|
score: 0.0,
|
||||||
|
reason: "ensemble:tree_allow",
|
||||||
|
path: EnsemblePath::TreeAllow,
|
||||||
|
},
|
||||||
|
TreeDecision::Defer => {
|
||||||
|
let mlp_score = mlp_predict_32::<12>(
|
||||||
|
&scanner_weights::W1,
|
||||||
|
&scanner_weights::B1,
|
||||||
|
&scanner_weights::W2,
|
||||||
|
scanner_weights::B2,
|
||||||
|
&input,
|
||||||
|
);
|
||||||
|
let action = if mlp_score > scanner_weights::THRESHOLD {
|
||||||
|
ScannerAction::Block
|
||||||
|
} else {
|
||||||
|
ScannerAction::Allow
|
||||||
|
};
|
||||||
|
EnsembleVerdict {
|
||||||
|
action,
|
||||||
|
score: mlp_score as f64,
|
||||||
|
reason: "ensemble:mlp",
|
||||||
|
path: EnsemblePath::Mlp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<EnsembleVerdict> for ScannerVerdict {
|
||||||
|
fn from(v: EnsembleVerdict) -> Self {
|
||||||
|
ScannerVerdict {
|
||||||
|
action: v.action,
|
||||||
|
score: v.score,
|
||||||
|
reason: v.reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_allow_path() {
|
||||||
|
// All features at zero → feature 3 (suspicious_ua) = 0.0 <= 0.65 → left (node 1)
|
||||||
|
// feature 0 (path_depth) = 0.0 <= 0.40 → left (node 3) → Allow leaf
|
||||||
|
let raw = [0.0f32; 12];
|
||||||
|
let v = scanner_ensemble_predict(&raw);
|
||||||
|
assert_eq!(v.action, ScannerAction::Allow);
|
||||||
|
assert_eq!(v.path, EnsemblePath::TreeAllow);
|
||||||
|
assert_eq!(v.reason, "ensemble:tree_allow");
|
||||||
|
assert!((v.score - 0.0).abs() < f64::EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_block_path() {
|
||||||
|
// Need: feature 3 (suspicious_ua) > 0.65 (normalized) → right (node 2)
|
||||||
|
// feature 7 (payload_entropy) > 0.72 (normalized) → right (node 6) → Block
|
||||||
|
// feature 3 max = 1.0, so raw 0.8 → normalized 0.8 > 0.65 ✓
|
||||||
|
// feature 7 max = 8.0, so raw 6.0 → normalized 0.75 > 0.72 ✓
|
||||||
|
let mut raw = [0.0f32; 12];
|
||||||
|
raw[3] = 0.8; // suspicious_ua: normalized = 0.8/1.0 = 0.8 > 0.65
|
||||||
|
raw[7] = 6.0; // payload_entropy: normalized = 6.0/8.0 = 0.75 > 0.72
|
||||||
|
let v = scanner_ensemble_predict(&raw);
|
||||||
|
assert_eq!(v.action, ScannerAction::Block);
|
||||||
|
assert_eq!(v.path, EnsemblePath::TreeBlock);
|
||||||
|
assert_eq!(v.reason, "ensemble:tree_block");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mlp_path() {
|
||||||
|
// Need: feature 3 > 0.65 normalized → right (node 2)
|
||||||
|
// feature 7 <= 0.72 normalized → left (node 5) → Defer
|
||||||
|
// Then MLP runs on the normalized input.
|
||||||
|
let mut raw = [0.0f32; 12];
|
||||||
|
raw[3] = 0.8; // normalized = 0.8 > 0.65
|
||||||
|
raw[7] = 4.0; // normalized = 4.0/8.0 = 0.5 <= 0.72
|
||||||
|
// Also need feature 2 (query_param_count) to navigate node 5 correctly
|
||||||
|
// node 5: split on feature 2, threshold 0.55 → left=9(Defer), right=10
|
||||||
|
// normalized feature 2 = 0.0/20.0 = 0.0 <= 0.55 → left (node 9) → Defer
|
||||||
|
let v = scanner_ensemble_predict(&raw);
|
||||||
|
assert_eq!(v.path, EnsemblePath::Mlp);
|
||||||
|
assert_eq!(v.reason, "ensemble:mlp");
|
||||||
|
// MLP output is deterministic for these inputs
|
||||||
|
assert!(v.score >= 0.0 && v.score <= 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_clamps() {
|
||||||
|
// Values beyond max should be clamped to 1.0
|
||||||
|
let mut raw = [0.0f32; 12];
|
||||||
|
raw[0] = 100.0; // max is 10.0
|
||||||
|
let normed = normalize(&raw);
|
||||||
|
assert!((normed[0] - 1.0).abs() < f64::EPSILON as f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_negative_clamps() {
|
||||||
|
let mut raw = [0.0f32; 12];
|
||||||
|
raw[0] = -5.0; // min is 0.0
|
||||||
|
let normed = normalize(&raw);
|
||||||
|
assert!((normed[0] - 0.0).abs() < f64::EPSILON as f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_verdict_into_scanner_verdict() {
|
||||||
|
let v = EnsembleVerdict {
|
||||||
|
action: ScannerAction::Block,
|
||||||
|
score: 0.85,
|
||||||
|
reason: "ensemble:mlp",
|
||||||
|
path: EnsemblePath::Mlp,
|
||||||
|
};
|
||||||
|
let sv: ScannerVerdict = v.into();
|
||||||
|
assert_eq!(sv.action, ScannerAction::Block);
|
||||||
|
assert!((sv.score - 0.85).abs() < f64::EPSILON);
|
||||||
|
assert_eq!(sv.reason, "ensemble:mlp");
|
||||||
|
}
|
||||||
|
}
|
||||||
156
src/ensemble/tree.rs
Normal file
156
src/ensemble/tree.rs
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
/// Decision from a tree leaf node.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum TreeDecision {
|
||||||
|
/// High-confidence block — overrides MLP.
|
||||||
|
Block,
|
||||||
|
/// High-confidence allow — overrides MLP.
|
||||||
|
Allow,
|
||||||
|
/// Low-confidence — defer to MLP for scoring.
|
||||||
|
Defer,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Packed tree node: `(feature_index, threshold, left_child, right_child)`.
|
||||||
|
///
|
||||||
|
/// Leaf nodes are encoded with `feature_index = 255`.
|
||||||
|
/// For leaves the threshold encodes the decision:
|
||||||
|
/// - `< 0.25` → Allow
|
||||||
|
/// - `> 0.75` → Block
|
||||||
|
/// - otherwise → Defer
|
||||||
|
pub type PackedNode = (u8, f32, u16, u16);
|
||||||
|
|
||||||
|
/// Walk a packed decision tree. O(depth), zero allocation.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn tree_predict(nodes: &[PackedNode], input: &[f32]) -> TreeDecision {
|
||||||
|
let mut idx = 0usize;
|
||||||
|
loop {
|
||||||
|
let (feature, threshold, left, right) = nodes[idx];
|
||||||
|
if feature == 255 {
|
||||||
|
return if threshold < 0.25 {
|
||||||
|
TreeDecision::Allow
|
||||||
|
} else if threshold > 0.75 {
|
||||||
|
TreeDecision::Block
|
||||||
|
} else {
|
||||||
|
TreeDecision::Defer
|
||||||
|
};
|
||||||
|
}
|
||||||
|
idx = if input[feature as usize] <= threshold {
|
||||||
|
left as usize
|
||||||
|
} else {
|
||||||
|
right as usize
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
// Simple tree:
|
||||||
|
// node 0: feature 0, threshold 0.5 → left=1, right=2
|
||||||
|
// node 1: leaf Allow (threshold 0.0)
|
||||||
|
// node 2: leaf Block (threshold 1.0)
|
||||||
|
const SIMPLE_TREE: [PackedNode; 3] = [
|
||||||
|
(0, 0.5, 1, 2),
|
||||||
|
(255, 0.0, 0, 0), // Allow
|
||||||
|
(255, 1.0, 0, 0), // Block
|
||||||
|
];
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_allow() {
|
||||||
|
let input = [0.3]; // <= 0.5 → left → Allow
|
||||||
|
assert_eq!(tree_predict(&SIMPLE_TREE, &input), TreeDecision::Allow);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_block() {
|
||||||
|
let input = [0.8]; // > 0.5 → right → Block
|
||||||
|
assert_eq!(tree_predict(&SIMPLE_TREE, &input), TreeDecision::Block);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_defer() {
|
||||||
|
// Tree with a Defer leaf
|
||||||
|
let tree: [PackedNode; 2] = [
|
||||||
|
(0, 0.5, 1, 1),
|
||||||
|
(255, 0.5, 0, 0), // Defer
|
||||||
|
];
|
||||||
|
let input = [0.3];
|
||||||
|
assert_eq!(tree_predict(&tree, &input), TreeDecision::Defer);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_boundary_allow() {
|
||||||
|
// threshold exactly 0.25 is NOT < 0.25, so it should Defer
|
||||||
|
let tree: [PackedNode; 1] = [(255, 0.25, 0, 0)];
|
||||||
|
assert_eq!(tree_predict(&tree, &[]), TreeDecision::Defer);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_boundary_block() {
|
||||||
|
// threshold exactly 0.75 is NOT > 0.75, so it should Defer
|
||||||
|
let tree: [PackedNode; 1] = [(255, 0.75, 0, 0)];
|
||||||
|
assert_eq!(tree_predict(&tree, &[]), TreeDecision::Defer);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_deeper_tree() {
|
||||||
|
// Depth-3 tree with 4 features
|
||||||
|
let tree: [PackedNode; 7] = [
|
||||||
|
(0, 0.5, 1, 2), // root: feature 0
|
||||||
|
(1, 0.3, 3, 4), // left: feature 1
|
||||||
|
(2, 0.7, 5, 6), // right: feature 2
|
||||||
|
(255, 0.0, 0, 0), // Allow
|
||||||
|
(255, 0.5, 0, 0), // Defer
|
||||||
|
(255, 1.0, 0, 0), // Block
|
||||||
|
(255, 0.0, 0, 0), // Allow
|
||||||
|
];
|
||||||
|
// feature 0=0.2 (<=0.5→left=1), feature 1=0.1 (<=0.3→left=3) → Allow
|
||||||
|
assert_eq!(tree_predict(&tree, &[0.2, 0.1, 0.0, 0.0]), TreeDecision::Allow);
|
||||||
|
// feature 0=0.2 (<=0.5→left=1), feature 1=0.5 (>0.3→right=4) → Defer
|
||||||
|
assert_eq!(tree_predict(&tree, &[0.2, 0.5, 0.0, 0.0]), TreeDecision::Defer);
|
||||||
|
// feature 0=0.8 (>0.5→right=2), feature 2=0.3 (<=0.7→left=5) → Block
|
||||||
|
assert_eq!(tree_predict(&tree, &[0.8, 0.0, 0.3, 0.0]), TreeDecision::Block);
|
||||||
|
// feature 0=0.8 (>0.5→right=2), feature 2=0.9 (>0.7→right=6) → Allow
|
||||||
|
assert_eq!(tree_predict(&tree, &[0.8, 0.0, 0.9, 0.0]), TreeDecision::Allow);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod proptests {
|
||||||
|
use super::*;
|
||||||
|
use proptest::prelude::*;
|
||||||
|
|
||||||
|
/// Generate a valid packed tree (complete binary tree of given depth).
|
||||||
|
/// All internal nodes split on feature 0 at threshold 0.5.
|
||||||
|
/// Leaves cycle through Allow/Block/Defer.
|
||||||
|
fn make_complete_tree(depth: u32) -> Vec<PackedNode> {
|
||||||
|
let num_nodes = (1u32 << (depth + 1)) - 1;
|
||||||
|
let num_internal = (1u32 << depth) - 1;
|
||||||
|
let mut nodes = Vec::with_capacity(num_nodes as usize);
|
||||||
|
let decisions = [0.0f32, 1.0, 0.5]; // Allow, Block, Defer
|
||||||
|
for i in 0..num_nodes {
|
||||||
|
if i < num_internal {
|
||||||
|
let left = 2 * i + 1;
|
||||||
|
let right = 2 * i + 2;
|
||||||
|
nodes.push((0u8, 0.5f32, left as u16, right as u16));
|
||||||
|
} else {
|
||||||
|
let leaf_idx = (i - num_internal) as usize;
|
||||||
|
nodes.push((255u8, decisions[leaf_idx % 3], 0u16, 0u16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
proptest! {
|
||||||
|
#[test]
|
||||||
|
fn tree_always_terminates(val in 0.0f32..1.0, depth in 0u32..5) {
|
||||||
|
let tree = make_complete_tree(depth);
|
||||||
|
let input = [val; 16]; // enough features for any tree
|
||||||
|
let result = tree_predict(&tree, &input);
|
||||||
|
prop_assert!(matches!(
|
||||||
|
result,
|
||||||
|
TreeDecision::Allow | TreeDecision::Block | TreeDecision::Defer
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user