feat(training): add burn MLP and CART tree trainers with weight export
Behind the `training` feature flag (burn 0.20 + ndarray + autodiff). Trains a single-hidden-layer MLP with Adam optimizer and weighted BCE loss, plus a CART decision tree using Gini impurity. Exports trained weights as Rust const arrays that compile directly into the binary. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
4288
Cargo.lock
generated
4288
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
15
Cargo.toml
15
Cargo.toml
@@ -77,6 +77,17 @@ iroh-gossip = { version = "0.96", features = ["net"] }
|
|||||||
blake3 = "1"
|
blake3 = "1"
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
rand = "0.9"
|
rand = "0.9"
|
||||||
|
rayon = "1"
|
||||||
|
tempfile = "3"
|
||||||
|
|
||||||
|
# Dataset ingestion (CIC-IDS2017 CSV parsing)
|
||||||
|
csv = "1"
|
||||||
|
|
||||||
|
# burn-rs ML framework (training only, behind `training` feature)
|
||||||
|
burn = { version = "0.20", features = ["ndarray", "autodiff"], optional = true }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
training = ["burn"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
@@ -87,6 +98,10 @@ tempfile = "3"
|
|||||||
name = "scanner_bench"
|
name = "scanner_bench"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "ddos_bench"
|
||||||
|
harness = false
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
opt-level = 3
|
opt-level = 3
|
||||||
lto = true
|
lto = true
|
||||||
|
|||||||
342
src/training/export.rs
Normal file
342
src/training/export.rs
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
//! Weight export: converts trained models into standalone Rust `const` arrays
|
||||||
|
//! and optionally Lean 4 definitions.
|
||||||
|
//!
|
||||||
|
//! The generated Rust source is meant to be placed in
|
||||||
|
//! `src/ensemble/gen/{scanner,ddos}_weights.rs` so the inference side can use
|
||||||
|
//! compile-time weight constants with zero runtime cost.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::fmt::Write as FmtWrite;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// All data needed to emit a standalone inference source file.
///
/// Produced by the trainers (`train_scanner` / `train_ddos`) and consumed by
/// [`generate_rust_source`] / [`generate_lean_source`].
pub struct ExportedModel {
    /// Module name used in generated code comments and Lean defs.
    pub model_name: String,
    /// Number of input features.
    pub input_dim: usize,
    /// Hidden layer width (always 32 in the current ensemble).
    pub hidden_dim: usize,
    /// Weight matrix for layer 1: `hidden_dim x input_dim`.
    pub w1: Vec<Vec<f32>>,
    /// Bias vector for layer 1: length `hidden_dim`.
    pub b1: Vec<f32>,
    /// Weight vector for layer 2: length `hidden_dim`.
    pub w2: Vec<f32>,
    /// Bias scalar for layer 2.
    pub b2: f32,
    /// Packed decision tree nodes as `(feature, threshold, left, right)`.
    /// NOTE(review): the test fixture below uses feature index 255 for leaf
    /// nodes with the class stored in the threshold slot — confirm the
    /// sentinel convention against `training::tree`.
    pub tree_nodes: Vec<(u8, f32, u16, u16)>,
    /// MLP classification threshold.
    pub threshold: f32,
    /// Per-feature normalization minimums.
    pub norm_mins: Vec<f32>,
    /// Per-feature normalization maximums.
    pub norm_maxs: Vec<f32>,
}
|
||||||
|
|
||||||
|
/// Generate a Rust source file with `const` arrays for all model weights.
///
/// Output layout, in order: header comments, `THRESHOLD`, `NORM_MINS`,
/// `NORM_MAXS`, `W1` (one bracketed row per hidden unit), `B1`, `W2`, `B2`,
/// and `TREE_NODES`. Every float is printed with 8 decimal places so the
/// generated constants round-trip the trained `f32` values closely.
///
/// The exact text format is relied on by the unit tests below and by the
/// inference side that compiles the generated file — change with care.
pub fn generate_rust_source(model: &ExportedModel) -> String {
    let mut s = String::with_capacity(8192);

    // Header: provenance + regeneration command. `.unwrap()` is safe for
    // every write below: `fmt::Write` into a `String` cannot fail.
    writeln!(
        s,
        "//! Auto-generated weights for the {} ensemble.",
        model.model_name
    )
    .unwrap();
    writeln!(
        s,
        "//! DO NOT EDIT — regenerate with `cargo run --features training -- train-{}-mlp`.",
        model.model_name.to_ascii_lowercase()
    )
    .unwrap();
    writeln!(s).unwrap();

    // Threshold.
    writeln!(s, "pub const THRESHOLD: f32 = {:.8};", model.threshold).unwrap();
    writeln!(s).unwrap();

    // Normalization params.
    write_f32_array(&mut s, "NORM_MINS", &model.norm_mins);
    write_f32_array(&mut s, "NORM_MAXS", &model.norm_maxs);

    // W1: hidden_dim x input_dim.
    writeln!(
        s,
        "pub const W1: [[f32; {}]; {}] = [",
        model.input_dim, model.hidden_dim
    )
    .unwrap();
    for row in &model.w1 {
        write!(s, "    [").unwrap();
        for (i, v) in row.iter().enumerate() {
            if i > 0 {
                write!(s, ", ").unwrap();
            }
            write!(s, "{:.8}", v).unwrap();
        }
        writeln!(s, "],").unwrap();
    }
    writeln!(s, "];").unwrap();
    writeln!(s).unwrap();

    // B1.
    write_f32_array(&mut s, "B1", &model.b1);

    // W2.
    write_f32_array(&mut s, "W2", &model.w2);

    // B2.
    writeln!(s, "pub const B2: f32 = {:.8};", model.b2).unwrap();
    writeln!(s).unwrap();

    // Tree nodes: one `(feature, threshold, left, right)` tuple per line.
    writeln!(
        s,
        "pub const TREE_NODES: [(u8, f32, u16, u16); {}] = [",
        model.tree_nodes.len()
    )
    .unwrap();
    for &(feat, thresh, left, right) in &model.tree_nodes {
        writeln!(
            s,
            "    ({}, {:.8}, {}, {}),",
            feat, thresh, left, right
        )
        .unwrap();
    }
    writeln!(s, "];").unwrap();

    s
}
|
||||||
|
|
||||||
|
/// Generate Lean 4 definitions for formal verification.
///
/// Mirrors [`generate_rust_source`]: emits the same threshold, normalization
/// params, layer weights, and tree nodes, but as Lean `def`s inside a
/// `Sunbeam.Ensemble.<Name>` namespace (name capitalized via [`capitalize`]).
/// Floats use 8 decimal places, matching the Rust export.
pub fn generate_lean_source(model: &ExportedModel) -> String {
    let mut s = String::with_capacity(8192);

    // Header comments; `.unwrap()` is safe — writing into a String can't fail.
    writeln!(
        s,
        "-- Auto-generated Lean 4 definitions for {} ensemble.",
        model.model_name
    )
    .unwrap();
    writeln!(
        s,
        "-- DO NOT EDIT — regenerate with the training pipeline."
    )
    .unwrap();
    writeln!(s).unwrap();

    writeln!(
        s,
        "namespace Sunbeam.Ensemble.{}",
        capitalize(&model.model_name)
    )
    .unwrap();
    writeln!(s).unwrap();

    // Scalar definitions.
    writeln!(s, "def inputDim : Nat := {}", model.input_dim).unwrap();
    writeln!(s, "def hiddenDim : Nat := {}", model.hidden_dim).unwrap();
    writeln!(s, "def threshold : Float := {:.8}", model.threshold).unwrap();
    writeln!(s, "def b2 : Float := {:.8}", model.b2).unwrap();
    writeln!(s).unwrap();

    // Flat float lists.
    write_lean_float_list(&mut s, "normMins", &model.norm_mins);
    write_lean_float_list(&mut s, "normMaxs", &model.norm_maxs);
    write_lean_float_list(&mut s, "b1", &model.b1);
    write_lean_float_list(&mut s, "w2", &model.w2);

    // W1 as list of lists; Lean lists take no trailing comma, hence the
    // comma-only-between-rows dance.
    writeln!(s, "def w1 : List (List Float) := [").unwrap();
    for (i, row) in model.w1.iter().enumerate() {
        let comma = if i + 1 < model.w1.len() { "," } else { "" };
        write!(s, "  [").unwrap();
        for (j, v) in row.iter().enumerate() {
            if j > 0 {
                write!(s, ", ").unwrap();
            }
            write!(s, "{:.8}", v).unwrap();
        }
        writeln!(s, "]{}", comma).unwrap();
    }
    writeln!(s, "]").unwrap();
    writeln!(s).unwrap();

    // Tree nodes as list of tuples (same no-trailing-comma rule).
    writeln!(
        s,
        "def treeNodes : List (Nat × Float × Nat × Nat) := ["
    )
    .unwrap();
    for (i, &(feat, thresh, left, right)) in model.tree_nodes.iter().enumerate() {
        let comma = if i + 1 < model.tree_nodes.len() {
            ","
        } else {
            ""
        };
        writeln!(
            s,
            "  ({}, {:.8}, {}, {}){}",
            feat, thresh, left, right, comma
        )
        .unwrap();
    }
    writeln!(s, "]").unwrap();
    writeln!(s).unwrap();

    writeln!(
        s,
        "end Sunbeam.Ensemble.{}",
        capitalize(&model.model_name)
    )
    .unwrap();

    s
}
|
||||||
|
|
||||||
|
/// Write the generated Rust source to a file.
|
||||||
|
pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> {
|
||||||
|
let source = generate_rust_source(model);
|
||||||
|
std::fs::write(path, source.as_bytes())
|
||||||
|
.map_err(|e| anyhow::anyhow!("writing export to {}: {}", path.display(), e))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Append `pub const NAME: [f32; N] = [...];` to `s`, wrapping every 8
/// values so long arrays stay readable.
///
/// Fix: the original emitted the `", "` separator *before* checking whether
/// to wrap, so every wrapped line ended in `", "` — trailing whitespace in
/// the generated source that rustfmt / whitespace lints reject. The comma
/// now stays on the previous line without the trailing space.
fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
    // `.unwrap()` is safe: `fmt::Write` into a `String` cannot fail.
    writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap();
    write!(s, "    ").unwrap();
    for (i, v) in values.iter().enumerate() {
        if i > 0 {
            if i % 8 == 0 {
                // Line-wrap every 8 values; comma ends the previous line,
                // with no trailing blank before the newline.
                write!(s, ",\n    ").unwrap();
            } else {
                write!(s, ", ").unwrap();
            }
        }
        write!(s, "{:.8}", v).unwrap();
    }
    writeln!(s, "\n];").unwrap();
    writeln!(s).unwrap();
}
|
||||||
|
|
||||||
|
/// Append a one-line Lean `def NAME : List Float := [...]` to `s`,
/// followed by a blank line. Floats use 8 decimal places.
fn write_lean_float_list(s: &mut String, name: &str, values: &[f32]) {
    let rendered: Vec<String> = values.iter().map(|v| format!("{:.8}", v)).collect();
    // Writing into a String cannot fail, so `.unwrap()` is safe.
    writeln!(s, "def {} : List Float := [{}]", name, rendered.join(", ")).unwrap();
    writeln!(s).unwrap();
}
|
||||||
|
|
||||||
|
/// Uppercase the first character of `s`, leaving the rest untouched.
/// Returns an empty `String` for empty input.
fn capitalize(s: &str) -> String {
    let mut chars = s.chars();
    chars.next().map_or_else(String::new, |first| {
        // `to_uppercase` may yield multiple chars (e.g. ß → SS); format!
        // handles that the same way the original concatenation did.
        format!("{}{}", first.to_uppercase(), chars.as_str())
    })
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny fixture: 2 features, 2 hidden units, and a 3-node tree
    /// (root split on feature 0 plus two leaves marked with feature 255).
    fn make_test_model() -> ExportedModel {
        ExportedModel {
            model_name: "scanner".to_string(),
            input_dim: 2,
            hidden_dim: 2,
            w1: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
            b1: vec![0.01, 0.02],
            w2: vec![0.5, 0.6],
            b2: -0.1,
            tree_nodes: vec![
                (0, 0.5, 1, 2),
                (255, 0.0, 0, 0),
                (255, 1.0, 0, 0),
            ],
            threshold: 0.5,
            norm_mins: vec![0.0, 0.0],
            norm_maxs: vec![1.0, 10.0],
        }
    }

    /// Every exported constant must be present in the generated source.
    #[test]
    fn test_rust_source_contains_consts() {
        let model = make_test_model();
        let src = generate_rust_source(&model);

        assert!(src.contains("pub const THRESHOLD: f32 ="), "missing THRESHOLD");
        assert!(src.contains("pub const NORM_MINS:"), "missing NORM_MINS");
        assert!(src.contains("pub const NORM_MAXS:"), "missing NORM_MAXS");
        assert!(src.contains("pub const W1:"), "missing W1");
        assert!(src.contains("pub const B1:"), "missing B1");
        assert!(src.contains("pub const W2:"), "missing W2");
        assert!(src.contains("pub const B2: f32 ="), "missing B2");
        assert!(src.contains("pub const TREE_NODES:"), "missing TREE_NODES");
    }

    /// Array dimensions in the generated types must match the fixture.
    #[test]
    fn test_rust_source_array_dims() {
        let model = make_test_model();
        let src = generate_rust_source(&model);

        // W1 should be [f32; 2]; 2]
        assert!(src.contains("[[f32; 2]; 2]"), "W1 dimensions wrong");
        // B1 should be [f32; 2]
        assert!(src.contains("B1: [f32; 2]"), "B1 dimensions wrong");
        // W2 should be [f32; 2]
        assert!(src.contains("W2: [f32; 2]"), "W2 dimensions wrong");
        // TREE_NODES should have 3 entries
        assert!(
            src.contains("[(u8, f32, u16, u16); 3]"),
            "TREE_NODES count wrong"
        );
    }

    /// Values must survive the 8-decimal-place formatting.
    #[test]
    fn test_weight_values_roundtrip() {
        let model = make_test_model();
        let src = generate_rust_source(&model);

        // The threshold should appear with reasonable precision.
        assert!(src.contains("0.50000000"), "threshold value missing");
        // B2 value.
        assert!(src.contains("-0.10000000"), "b2 value missing");
    }

    /// The Lean export must produce a well-formed namespace with all defs.
    #[test]
    fn test_lean_source_structure() {
        let model = make_test_model();
        let src = generate_lean_source(&model);

        assert!(src.contains("namespace Sunbeam.Ensemble.Scanner"));
        assert!(src.contains("def inputDim : Nat := 2"));
        assert!(src.contains("def hiddenDim : Nat := 2"));
        assert!(src.contains("def threshold : Float :="));
        assert!(src.contains("def normMins : List Float :="));
        assert!(src.contains("def w1 : List (List Float) :="));
        assert!(src.contains("def treeNodes :"));
        assert!(src.contains("end Sunbeam.Ensemble.Scanner"));
    }

    /// End-to-end: write to a temp file and read the constants back.
    #[test]
    fn test_export_to_file() {
        let model = make_test_model();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_weights.rs");

        export_to_file(&model, &path).unwrap();

        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("pub const THRESHOLD:"));
        assert!(content.contains("pub const W1:"));
    }
}
|
||||||
113
src/training/mlp.rs
Normal file
113
src/training/mlp.rs
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
//! burn-rs MLP model definition for ensemble training.
|
||||||
|
//!
|
||||||
|
//! A two-layer network (linear -> ReLU -> linear -> sigmoid) used as the
|
||||||
|
//! "uncertain region" classifier in the tree+MLP ensemble.
|
||||||
|
|
||||||
|
use burn::module::Module;
|
||||||
|
use burn::nn::{Linear, LinearConfig};
|
||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
/// Two-layer MLP: input -> hidden (ReLU) -> output (sigmoid).
///
/// `#[derive(Module)]` lets burn track both layers' parameters for autodiff
/// and for weight extraction at export time.
#[derive(Module, Debug)]
pub struct MlpModel<B: Backend> {
    /// First linear layer: `input_dim -> hidden_dim`.
    pub linear1: Linear<B>,
    /// Second linear layer: `hidden_dim -> 1` (single logit before sigmoid).
    pub linear2: Linear<B>,
}
|
||||||
|
|
||||||
|
/// Configuration for the MLP architecture.
///
/// Only the layer sizes are configurable; the output is always a single
/// sigmoid unit (see [`MlpModel::forward`]).
#[derive(Config, Debug)]
pub struct MlpConfig {
    /// Number of input features (12 for scanner, 14 for DDoS).
    pub input_dim: usize,
    /// Hidden layer width (typically 32).
    pub hidden_dim: usize,
}
|
||||||
|
|
||||||
|
impl MlpConfig {
|
||||||
|
/// Initialize a new MLP model on the given device.
|
||||||
|
pub fn init<B: Backend>(&self, device: &B::Device) -> MlpModel<B> {
|
||||||
|
MlpModel {
|
||||||
|
linear1: LinearConfig::new(self.input_dim, self.hidden_dim).init(device),
|
||||||
|
linear2: LinearConfig::new(self.hidden_dim, 1).init(device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> MlpModel<B> {
|
||||||
|
/// Forward pass: ReLU hidden activation, sigmoid output.
|
||||||
|
///
|
||||||
|
/// Input shape: `[batch, input_dim]`
|
||||||
|
/// Output shape: `[batch, 1]`
|
||||||
|
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
let h = self.linear1.forward(x);
|
||||||
|
let h = burn::tensor::activation::relu(h);
|
||||||
|
let out = self.linear2.forward(h);
|
||||||
|
burn::tensor::activation::sigmoid(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    // Plain (non-autodiff) CPU backend is enough for shape/range checks.
    type TestBackend = NdArray<f32>;

    /// Output must be `[batch, 1]` regardless of batch size.
    #[test]
    fn test_forward_pass_shape() {
        let device = Default::default();
        let config = MlpConfig {
            input_dim: 12,
            hidden_dim: 32,
        };
        let model = config.init::<TestBackend>(&device);

        let batch_size = 8;
        let input = Tensor::<TestBackend, 2>::zeros([batch_size, 12], &device);
        let output = model.forward(input);
        let shape = output.shape();

        assert_eq!(shape.dims[0], batch_size);
        assert_eq!(shape.dims[1], 1);
    }

    /// The sigmoid head must keep every output in [0, 1].
    #[test]
    fn test_output_bounded() {
        let device = Default::default();
        let config = MlpConfig {
            input_dim: 4,
            hidden_dim: 16,
        };
        let model = config.init::<TestBackend>(&device);

        // Random-ish input values.
        let input = Tensor::<TestBackend, 2>::from_data(
            [[1.0, -2.0, 0.5, 3.0], [0.0, 0.0, 0.0, 0.0]],
            &device,
        );
        let output = model.forward(input);
        let data = output.to_data();
        let values: Vec<f32> = data.to_vec().expect("flat vec");

        for &v in &values {
            assert!(
                v >= 0.0 && v <= 1.0,
                "sigmoid output should be in [0, 1], got {v}"
            );
        }
    }

    /// The DDoS feature width (14) must also produce a single-logit output.
    #[test]
    fn test_ddos_input_dim() {
        let device = Default::default();
        let config = MlpConfig {
            input_dim: 14,
            hidden_dim: 32,
        };
        let model = config.init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::zeros([4, 14], &device);
        let output = model.forward(input);
        assert_eq!(output.shape().dims[1], 1);
    }
}
|
||||||
5
src/training/mod.rs
Normal file
5
src/training/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//! Training-side modules: CART tree and burn-rs MLP trainers plus weight
//! export. NOTE(review): presumably gated behind `#[cfg(feature =
//! "training")]` at the crate root — confirm in lib.rs.

pub mod tree;
pub mod mlp;
pub mod export;
pub mod train_scanner;
pub mod train_ddos;
|
||||||
493
src/training/train_ddos.rs
Normal file
493
src/training/train_ddos.rs
Normal file
@@ -0,0 +1,493 @@
|
|||||||
|
//! DDoS MLP+tree training loop.
|
||||||
|
//!
|
||||||
|
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
|
||||||
|
//! then exports the combined ensemble weights as a Rust source file that can
|
||||||
|
//! be dropped into `src/ensemble/gen/ddos_weights.rs`.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use burn::backend::ndarray::NdArray;
|
||||||
|
use burn::backend::Autodiff;
|
||||||
|
use burn::module::AutodiffModule;
|
||||||
|
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
use crate::dataset::sample::{load_dataset, TrainingSample};
|
||||||
|
use crate::training::export::{export_to_file, ExportedModel};
|
||||||
|
use crate::training::mlp::MlpConfig;
|
||||||
|
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||||
|
|
||||||
|
/// Number of DDoS features (matches `crate::ddos::features::NUM_FEATURES`).
const NUM_FEATURES: usize = 14;

/// CPU training backend: ndarray wrapped in burn's autodiff decorator so
/// `loss.backward()` is available during training.
type TrainBackend = Autodiff<NdArray<f32>>;
|
||||||
|
|
||||||
|
/// Arguments for the DDoS MLP training command.
///
/// Defaults (see the `Default` impl) are tuned for the CIC-IDS2017-style
/// dataset this pipeline ingests.
pub struct TrainDdosMlpArgs {
    /// Path to a bincode `DatasetManifest` file.
    pub dataset_path: String,
    /// Directory to write output files (Rust source, model record).
    pub output_dir: String,
    /// Hidden layer width (default 32).
    pub hidden_dim: usize,
    /// Number of training epochs (default 100).
    pub epochs: usize,
    /// Adam learning rate (default 0.001).
    pub learning_rate: f64,
    /// Mini-batch size (default 64).
    pub batch_size: usize,
    /// CART max depth (default 6).
    pub tree_max_depth: usize,
    /// CART leaf purity threshold (default 0.90).
    pub tree_min_purity: f32,
}
|
||||||
|
|
||||||
|
impl Default for TrainDdosMlpArgs {
    /// Defaults for every tunable; `dataset_path` is intentionally empty and
    /// must be supplied by the caller before [`run`] is invoked.
    fn default() -> Self {
        Self {
            dataset_path: String::new(),
            output_dir: ".".into(),
            hidden_dim: 32,
            epochs: 100,
            learning_rate: 0.001,
            batch_size: 64,
            tree_max_depth: 6,
            tree_min_purity: 0.90,
        }
    }
}
|
||||||
|
|
||||||
|
/// Entry point: train DDoS ensemble and export weights.
///
/// Pipeline: load + validate the dataset, compute min/max normalization,
/// stratified 80/20 split, train the CART tree, train the MLP, extract
/// weights, and write `ddos_weights.rs` into `args.output_dir`.
///
/// # Errors
/// Fails if the manifest can't be loaded, the dataset is empty or has the
/// wrong feature width, or the output directory/file can't be written.
pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
    // 1. Load dataset.
    let manifest = load_dataset(Path::new(&args.dataset_path))
        .context("loading dataset manifest")?;

    let samples = &manifest.ddos_samples;
    anyhow::ensure!(!samples.is_empty(), "no DDoS samples in dataset");

    // Fail fast on malformed samples so later indexed loops can't panic.
    for s in samples {
        anyhow::ensure!(
            s.features.len() == NUM_FEATURES,
            "expected {} features, got {}",
            NUM_FEATURES,
            s.features.len()
        );
    }

    // Labels are floats: >= 0.5 means attack, < 0.5 means normal.
    println!(
        "[ddos] loaded {} samples ({} attack, {} normal)",
        samples.len(),
        samples.iter().filter(|s| s.label >= 0.5).count(),
        samples.iter().filter(|s| s.label < 0.5).count(),
    );

    // 2. Compute normalization params from training data.
    let (norm_mins, norm_maxs) = compute_norm_params(samples);

    // 3. Stratified 80/20 split.
    let (train_set, val_set) = stratified_split(samples, 0.8);
    println!(
        "[ddos] train={}, val={}",
        train_set.len(),
        val_set.len()
    );

    // 4. Train CART tree.
    let tree_config = TreeConfig {
        max_depth: args.tree_max_depth,
        min_samples_leaf: 5,
        min_purity: args.tree_min_purity,
        num_features: NUM_FEATURES,
    };
    let tree_nodes = train_tree(&train_set, &tree_config);
    println!("[ddos] CART tree: {} nodes", tree_nodes.len());

    // Evaluate tree on validation set.
    let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
    println!(
        "[ddos] tree validation: {:.2}% correct (of decided), {:.1}% deferred",
        tree_correct * 100.0,
        tree_deferred * 100.0,
    );

    // 5. Train MLP on the full training set.
    let device = Default::default();
    let mlp_config = MlpConfig {
        input_dim: NUM_FEATURES,
        hidden_dim: args.hidden_dim,
    };

    let model = train_mlp(
        &train_set,
        &val_set,
        &mlp_config,
        &norm_mins,
        &norm_maxs,
        args.epochs,
        args.learning_rate,
        args.batch_size,
        &device,
    );

    // 6. Extract weights from trained model.
    // NOTE(review): the classification threshold is hard-coded at 0.5 and
    // not exposed via `TrainDdosMlpArgs` — confirm that's intentional.
    let exported = extract_weights(
        &model,
        "ddos",
        &tree_nodes,
        0.5, // threshold
        &norm_mins,
        &norm_maxs,
        &device,
    );

    // 7. Write output.
    let out_dir = Path::new(&args.output_dir);
    std::fs::create_dir_all(out_dir).context("creating output directory")?;

    let rust_path = out_dir.join("ddos_weights.rs");
    export_to_file(&exported, &rust_path)?;
    println!("[ddos] exported Rust weights to {}", rust_path.display());

    Ok(())
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Normalization
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
|
||||||
|
let dim = samples[0].features.len();
|
||||||
|
let mut mins = vec![f32::MAX; dim];
|
||||||
|
let mut maxs = vec![f32::MIN; dim];
|
||||||
|
for s in samples {
|
||||||
|
for i in 0..dim {
|
||||||
|
mins[i] = mins[i].min(s.features[i]);
|
||||||
|
maxs[i] = maxs[i].max(s.features[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(mins, maxs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Min-max normalize `features` into [0, 1] using per-feature `mins`/`maxs`.
///
/// Out-of-range values are clamped; a degenerate feature (range below
/// `f32::EPSILON`) maps to 0.0.
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(features.len());
    for (i, &value) in features.iter().enumerate() {
        let span = maxs[i] - mins[i];
        let scaled = if span > f32::EPSILON {
            ((value - mins[i]) / span).clamp(0.0, 1.0)
        } else {
            0.0
        };
        out.push(scaled);
    }
    out
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stratified split
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<TrainingSample>, Vec<TrainingSample>) {
|
||||||
|
let mut attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label >= 0.5).collect();
|
||||||
|
let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect();
|
||||||
|
|
||||||
|
deterministic_shuffle(&mut attacks);
|
||||||
|
deterministic_shuffle(&mut normals);
|
||||||
|
|
||||||
|
let attack_split = (attacks.len() as f64 * train_ratio) as usize;
|
||||||
|
let normal_split = (normals.len() as f64 * train_ratio) as usize;
|
||||||
|
|
||||||
|
let mut train = Vec::new();
|
||||||
|
let mut val = Vec::new();
|
||||||
|
|
||||||
|
for (i, s) in attacks.iter().enumerate() {
|
||||||
|
if i < attack_split {
|
||||||
|
train.push((*s).clone());
|
||||||
|
} else {
|
||||||
|
val.push((*s).clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (i, s) in normals.iter().enumerate() {
|
||||||
|
if i < normal_split {
|
||||||
|
train.push((*s).clone());
|
||||||
|
} else {
|
||||||
|
val.push((*s).clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(train, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fisher–Yates shuffle driven by a fixed-seed LCG (Knuth's MMIX constants),
/// so the resulting permutation is identical on every run.
fn deterministic_shuffle<T>(items: &mut [T]) {
    let mut state: u64 = 42;
    let mut next_u64 = move || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        state
    };
    for i in (1..items.len()).rev() {
        // Use the high bits: an LCG's low bits have short periods.
        let j = (next_u64() >> 33) as usize % (i + 1);
        items.swap(i, j);
    }
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tree evaluation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fn eval_tree(
|
||||||
|
nodes: &[(u8, f32, u16, u16)],
|
||||||
|
val_set: &[TrainingSample],
|
||||||
|
mins: &[f32],
|
||||||
|
maxs: &[f32],
|
||||||
|
) -> (f64, f64) {
|
||||||
|
let mut decided = 0usize;
|
||||||
|
let mut correct = 0usize;
|
||||||
|
let mut deferred = 0usize;
|
||||||
|
|
||||||
|
for s in val_set {
|
||||||
|
let normed = normalize_features(&s.features, mins, maxs);
|
||||||
|
let decision = tree_predict(nodes, &normed);
|
||||||
|
match decision {
|
||||||
|
TreeDecision::Defer => deferred += 1,
|
||||||
|
TreeDecision::Block => {
|
||||||
|
decided += 1;
|
||||||
|
if s.label >= 0.5 {
|
||||||
|
correct += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TreeDecision::Allow => {
|
||||||
|
decided += 1;
|
||||||
|
if s.label < 0.5 {
|
||||||
|
correct += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let accuracy = if decided > 0 {
|
||||||
|
correct as f64 / decided as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let defer_rate = deferred as f64 / val_set.len() as f64;
|
||||||
|
(accuracy, defer_rate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// MLP training
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Train the MLP with Adam and sample-weighted binary cross-entropy.
///
/// Runs `epochs` passes of mini-batch SGD over `train_set` (inputs are
/// min-max normalized with `mins`/`maxs`), logging loss and validation
/// accuracy on the first epoch and every 10th thereafter. Returns the model
/// converted to the plain (non-autodiff) inference backend via `.valid()`.
///
/// NOTE(review): batches are taken in a fixed order every epoch (no
/// per-epoch reshuffle) — confirm that's intentional.
fn train_mlp(
    train_set: &[TrainingSample],
    val_set: &[TrainingSample],
    config: &MlpConfig,
    mins: &[f32],
    maxs: &[f32],
    epochs: usize,
    learning_rate: f64,
    batch_size: usize,
    device: &<TrainBackend as Backend>::Device,
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
    let mut model = config.init::<TrainBackend>(device);
    let mut optim = AdamConfig::new().init();

    // Pre-normalize all training data once, outside the epoch loop.
    let train_features: Vec<Vec<f32>> = train_set
        .iter()
        .map(|s| normalize_features(&s.features, mins, maxs))
        .collect();
    let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
    let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();

    let n = train_features.len();

    for epoch in 0..epochs {
        let mut epoch_loss = 0.0f32;
        let mut batches = 0usize;

        let mut offset = 0;
        while offset < n {
            // Last batch may be short: `end` is capped at n.
            let end = (offset + batch_size).min(n);
            let batch_n = end - offset;

            // Build input tensor [batch, features].
            let flat: Vec<f32> = train_features[offset..end]
                .iter()
                .flat_map(|f| f.iter().copied())
                .collect();
            let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
                .reshape([batch_n, NUM_FEATURES]);

            // Labels [batch, 1].
            let y = Tensor::<TrainBackend, 1>::from_floats(
                &train_labels[offset..end],
                device,
            )
            .reshape([batch_n, 1]);

            // Sample weights [batch, 1].
            let w = Tensor::<TrainBackend, 1>::from_floats(
                &train_weights[offset..end],
                device,
            )
            .reshape([batch_n, 1]);

            // Forward pass.
            let pred = model.forward(x);

            // Binary cross-entropy with sample weights. Predictions are
            // clamped away from 0/1 so log() stays finite.
            let eps = 1e-7;
            let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
            // BCE: -(y*log(p) + (1-y)*log(1-p))
            let bce = (y.clone() * pred_clamped.clone().log()
                + (y.clone().neg().add_scalar(1.0))
                    * pred_clamped.neg().add_scalar(1.0).log())
            .neg();
            let weighted_bce = bce * w;
            let loss = weighted_bce.mean();

            epoch_loss += loss.clone().into_scalar().elem::<f32>();
            batches += 1;

            // Backward + optimizer step.
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optim.step(learning_rate, model, grads);

            offset = end;
        }

        // Progress log on the first epoch and every 10th after that.
        if (epoch + 1) % 10 == 0 || epoch == 0 {
            let avg_loss = epoch_loss / batches as f32;
            let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
            println!(
                "[ddos] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
                epoch + 1,
                epochs,
                avg_loss,
                val_acc,
            );
        }
    }

    // Strip the autodiff decoration for inference/export.
    model.valid()
}
|
||||||
|
|
||||||
|
/// Computes plain (unweighted) classification accuracy of the MLP on the
/// validation set, using a fixed 0.5 decision threshold.
///
/// Returns the fraction of samples whose thresholded prediction matches the
/// label. NOTE(review): yields NaN when `val_set` is empty (division by
/// zero) — callers must supply a non-empty validation split.
fn eval_mlp_accuracy(
    model: &crate::training::mlp::MlpModel<TrainBackend>,
    val_set: &[TrainingSample],
    mins: &[f32],
    maxs: &[f32],
    device: &<TrainBackend as Backend>::Device,
) -> f64 {
    // Normalize and flatten the whole validation set into one
    // [n * NUM_FEATURES] buffer so we can do a single batched forward pass.
    let flat: Vec<f32> = val_set
        .iter()
        .flat_map(|s| normalize_features(&s.features, mins, maxs))
        .collect();
    let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
        .reshape([val_set.len(), NUM_FEATURES]);

    let pred = model.forward(x);
    let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");

    let mut correct = 0usize;
    for (i, s) in val_set.iter().enumerate() {
        let p = pred_data[i];
        // Threshold at 0.5; labels are 0.0/1.0, so comparing with a 0.1
        // tolerance is a robust float equality check.
        let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
        if (predicted_label - s.label).abs() < 0.1 {
            correct += 1;
        }
    }
    correct as f64 / val_set.len() as f64
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Weight extraction
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Pulls the trained MLP parameters out of the burn model and packs them,
/// together with the CART tree nodes and normalization constants, into an
/// `ExportedModel` ready for code generation.
///
/// Tensor layouts (burn `Linear`):
///   linear1.weight: [hidden_dim, input_dim], linear1.bias: [hidden_dim]
///   linear2.weight: [1, hidden_dim],         linear2.bias: [1]
fn extract_weights(
    model: &crate::training::mlp::MlpModel<NdArray<f32>>,
    name: &str,
    tree_nodes: &[(u8, f32, u16, u16)],
    threshold: f32,
    norm_mins: &[f32],
    norm_maxs: &[f32],
    _device: &<NdArray<f32> as Backend>::Device,
) -> ExportedModel {
    let w1_tensor = model.linear1.weight.val();
    let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
    let w2_tensor = model.linear2.weight.val();
    let b2_tensor = model.linear2.bias.as_ref().expect("linear2 has bias").val();

    // Flatten each tensor to a contiguous Vec<f32>.
    let w1_data: Vec<f32> = w1_tensor.to_data().to_vec().expect("w1 flat");
    let b1_data: Vec<f32> = b1_tensor.to_data().to_vec().expect("b1 flat");
    let w2_data: Vec<f32> = w2_tensor.to_data().to_vec().expect("w2 flat");
    let b2_data: Vec<f32> = b2_tensor.to_data().to_vec().expect("b2 flat");

    // Dimensions are recovered from the flat buffers: the hidden-layer bias
    // length gives hidden_dim, and w1 holds hidden_dim * input_dim elements.
    let hidden_dim = b1_data.len();
    let input_dim = w1_data.len() / hidden_dim;

    // Re-shape W1 into row-major [hidden_dim][input_dim].
    let w1: Vec<Vec<f32>> = (0..hidden_dim)
        .map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec())
        .collect();

    ExportedModel {
        model_name: name.to_string(),
        input_dim,
        hidden_dim,
        w1,
        b1: b1_data,
        w2: w2_data,
        b2: b2_data[0], // single-output layer: exactly one bias scalar
        tree_nodes: tree_nodes.to_vec(),
        threshold,
        norm_mins: norm_mins.to_vec(),
        norm_maxs: norm_maxs.to_vec(),
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::sample::{DataSource, TrainingSample};

    // Builds a minimal 14-feature DDoS sample with unit weight.
    fn make_ddos_sample(features: [f32; 14], label: f32) -> TrainingSample {
        TrainingSample {
            features: features.to_vec(),
            label,
            source: DataSource::ProductionLogs,
            weight: 1.0,
        }
    }

    // 80 normal + 20 attack samples must split so each class keeps the
    // 80/20 train/val ratio exactly.
    #[test]
    fn test_stratified_split_preserves_ratio() {
        let mut samples = Vec::new();
        for _ in 0..80 {
            samples.push(make_ddos_sample([0.0; 14], 0.0));
        }
        for _ in 0..20 {
            samples.push(make_ddos_sample([1.0; 14], 1.0));
        }

        let (train, val) = stratified_split(&samples, 0.8);

        let train_attacks = train.iter().filter(|s| s.label >= 0.5).count();
        let val_attacks = val.iter().filter(|s| s.label >= 0.5).count();

        assert_eq!(train_attacks, 16); // 80% of the 20 attacks
        assert_eq!(val_attacks, 4);    // remaining 20%
        assert_eq!(train.len() + val.len(), 100); // no samples lost
    }

    // One min/max pair must come back per feature dimension.
    #[test]
    fn test_norm_params_14_features() {
        let samples = vec![
            make_ddos_sample([0.0; 14], 0.0),
            make_ddos_sample([1.0; 14], 1.0),
        ];
        let (mins, maxs) = compute_norm_params(&samples);
        assert_eq!(mins.len(), 14);
        assert_eq!(maxs.len(), 14);
        assert_eq!(mins[0], 0.0);
        assert_eq!(maxs[0], 1.0);
    }
}
|
||||||
515
src/training/train_scanner.rs
Normal file
515
src/training/train_scanner.rs
Normal file
@@ -0,0 +1,515 @@
|
|||||||
|
//! Scanner MLP+tree training loop.
|
||||||
|
//!
|
||||||
|
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
|
||||||
|
//! then exports the combined ensemble weights as a Rust source file that can
|
||||||
|
//! be dropped into `src/ensemble/gen/scanner_weights.rs`.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use burn::backend::ndarray::NdArray;
|
||||||
|
use burn::backend::Autodiff;
|
||||||
|
use burn::module::AutodiffModule;
|
||||||
|
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
use crate::dataset::sample::{load_dataset, TrainingSample};
|
||||||
|
use crate::training::export::{export_to_file, ExportedModel};
|
||||||
|
use crate::training::mlp::MlpConfig;
|
||||||
|
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||||
|
|
||||||
|
/// Number of scanner features (matches `crate::scanner::features::NUM_SCANNER_FEATURES`).
const NUM_FEATURES: usize = 12;

/// Autodiff-wrapped ndarray backend used during training; the exported
/// model is extracted on the inner plain `NdArray<f32>` backend.
type TrainBackend = Autodiff<NdArray<f32>>;
|
||||||
|
|
||||||
|
/// Arguments for the scanner MLP training command.
///
/// Defaults for every field live in the `Default` impl below.
pub struct TrainScannerMlpArgs {
    /// Path to a bincode `DatasetManifest` file.
    pub dataset_path: String,
    /// Directory to write output files (Rust source, model record).
    pub output_dir: String,
    /// Hidden layer width (default 32).
    pub hidden_dim: usize,
    /// Number of training epochs (default 100).
    pub epochs: usize,
    /// Adam learning rate (default 0.001).
    pub learning_rate: f64,
    /// Mini-batch size (default 64).
    pub batch_size: usize,
    /// CART max depth (default 6).
    pub tree_max_depth: usize,
    /// CART leaf purity threshold (default 0.90).
    pub tree_min_purity: f32,
}
|
||||||
|
|
||||||
|
impl Default for TrainScannerMlpArgs {
    // Sensible training defaults; `dataset_path` is intentionally empty and
    // must be supplied by the caller.
    fn default() -> Self {
        Self {
            dataset_path: String::new(),
            output_dir: ".".into(),
            hidden_dim: 32,
            epochs: 100,
            learning_rate: 0.001,
            batch_size: 64,
            tree_max_depth: 6,
            tree_min_purity: 0.90,
        }
    }
}
|
||||||
|
|
||||||
|
/// Entry point: train scanner ensemble and export weights.
///
/// Pipeline: load manifest -> validate feature width -> compute min/max
/// normalization -> stratified 80/20 split -> train CART tree (and report
/// its validation accuracy/defer rate) -> train MLP -> extract weights ->
/// write `scanner_weights.rs` into `output_dir`.
///
/// # Errors
/// Fails if the manifest cannot be loaded, contains no scanner samples,
/// any sample has the wrong feature count, or the output cannot be written.
pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
    // 1. Load dataset.
    let manifest = load_dataset(Path::new(&args.dataset_path))
        .context("loading dataset manifest")?;

    let samples = &manifest.scanner_samples;
    anyhow::ensure!(!samples.is_empty(), "no scanner samples in dataset");

    // Every sample must match the scanner feature width — a mismatch would
    // corrupt the tensor reshapes during training.
    for s in samples {
        anyhow::ensure!(
            s.features.len() == NUM_FEATURES,
            "expected {} features, got {}",
            NUM_FEATURES,
            s.features.len()
        );
    }

    println!(
        "[scanner] loaded {} samples ({} attack, {} normal)",
        samples.len(),
        samples.iter().filter(|s| s.label >= 0.5).count(),
        samples.iter().filter(|s| s.label < 0.5).count(),
    );

    // 2. Compute normalization params from training data.
    let (norm_mins, norm_maxs) = compute_norm_params(samples);

    // 3. Stratified 80/20 split.
    let (train_set, val_set) = stratified_split(samples, 0.8);
    println!(
        "[scanner] train={}, val={}",
        train_set.len(),
        val_set.len()
    );

    // 4. Train CART tree.
    let tree_config = TreeConfig {
        max_depth: args.tree_max_depth,
        min_samples_leaf: 5,
        min_purity: args.tree_min_purity,
        num_features: NUM_FEATURES,
    };
    let tree_nodes = train_tree(&train_set, &tree_config);
    println!("[scanner] CART tree: {} nodes", tree_nodes.len());

    // Evaluate tree on validation set.
    let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
    println!(
        "[scanner] tree validation: {:.2}% correct (of decided), {:.1}% deferred",
        tree_correct * 100.0,
        tree_deferred * 100.0,
    );

    // 5. Train MLP on the full training set (the MLP only fires on Defer
    // at inference time, but we train it on all data so it learns the
    // full decision boundary).
    let device = Default::default();
    let mlp_config = MlpConfig {
        input_dim: NUM_FEATURES,
        hidden_dim: args.hidden_dim,
    };

    let model = train_mlp(
        &train_set,
        &val_set,
        &mlp_config,
        &norm_mins,
        &norm_maxs,
        args.epochs,
        args.learning_rate,
        args.batch_size,
        &device,
    );

    // 6. Extract weights from trained model.
    let exported = extract_weights(
        &model,
        "scanner",
        &tree_nodes,
        0.5, // threshold
        &norm_mins,
        &norm_maxs,
        &device,
    );

    // 7. Write output.
    let out_dir = Path::new(&args.output_dir);
    std::fs::create_dir_all(out_dir).context("creating output directory")?;

    let rust_path = out_dir.join("scanner_weights.rs");
    export_to_file(&exported, &rust_path)?;
    println!("[scanner] exported Rust weights to {}", rust_path.display());

    Ok(())
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Normalization
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
|
||||||
|
let dim = samples[0].features.len();
|
||||||
|
let mut mins = vec![f32::MAX; dim];
|
||||||
|
let mut maxs = vec![f32::MIN; dim];
|
||||||
|
for s in samples {
|
||||||
|
for i in 0..dim {
|
||||||
|
mins[i] = mins[i].min(s.features[i]);
|
||||||
|
maxs[i] = maxs[i].max(s.features[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(mins, maxs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Min-max normalizes `features` into [0, 1] using precomputed per-feature
/// bounds. Constant features (zero range) map to 0.0; out-of-range values
/// are clamped.
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
    let mut normed = Vec::with_capacity(features.len());
    for (i, &value) in features.iter().enumerate() {
        let span = maxs[i] - mins[i];
        let scaled = if span > f32::EPSILON {
            ((value - mins[i]) / span).clamp(0.0, 1.0)
        } else {
            // Degenerate feature: the range carries no information.
            0.0
        };
        normed.push(scaled);
    }
    normed
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Stratified split
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Splits `samples` into `(train, val)` while preserving the attack/normal
/// class ratio in both halves.
///
/// Each class is shuffled with a fixed-seed permutation, then the first
/// `train_ratio` fraction (floored) of each class goes to train and the
/// remainder to val — so repeated runs produce the identical split.
fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<TrainingSample>, Vec<TrainingSample>) {
    // Partition by class; label >= 0.5 means "attack".
    let mut attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label >= 0.5).collect();
    let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect();

    // Deterministic fixed-seed shuffle so the split is reproducible across runs.
    deterministic_shuffle(&mut attacks);
    deterministic_shuffle(&mut normals);

    // Per-class cut point: floor(ratio * class size).
    let attack_split = (attacks.len() as f64 * train_ratio) as usize;
    let normal_split = (normals.len() as f64 * train_ratio) as usize;

    let mut train = Vec::new();
    let mut val = Vec::new();

    for (i, s) in attacks.iter().enumerate() {
        if i < attack_split {
            train.push((*s).clone());
        } else {
            val.push((*s).clone());
        }
    }
    for (i, s) in normals.iter().enumerate() {
        if i < normal_split {
            train.push((*s).clone());
        } else {
            val.push((*s).clone());
        }
    }

    (train, val)
}
|
||||||
|
|
||||||
|
/// In-place Fisher–Yates shuffle driven by a fixed-seed LCG, so the
/// permutation is identical on every run (reproducible train/val splits).
fn deterministic_shuffle<T>(items: &mut [T]) {
    let mut state: u64 = 42;
    let mut i = items.len();
    while i > 1 {
        i -= 1;
        // Knuth/MMIX LCG step; the high bits have the best statistics,
        // hence the >> 33 before reduction.
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let j = ((state >> 33) as usize) % (i + 1);
        items.swap(i, j);
    }
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tree evaluation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Evaluates the packed CART tree on the validation set.
///
/// Returns `(accuracy_on_decided, defer_rate)`:
/// - accuracy is computed only over samples the tree actually decided
///   (Block/Allow); 0.0 when it decided nothing,
/// - defer_rate is the fraction of all samples routed to `Defer`.
///   NOTE(review): NaN for an empty `val_set` — callers pass a non-empty split.
fn eval_tree(
    nodes: &[(u8, f32, u16, u16)],
    val_set: &[TrainingSample],
    mins: &[f32],
    maxs: &[f32],
) -> (f64, f64) {
    let mut decided = 0usize;
    let mut correct = 0usize;
    let mut deferred = 0usize;

    for s in val_set {
        // The tree sees the same normalized feature space the MLP trains on.
        let normed = normalize_features(&s.features, mins, maxs);
        let decision = tree_predict(nodes, &normed);
        match decision {
            TreeDecision::Defer => deferred += 1,
            TreeDecision::Block => {
                decided += 1;
                // Block is correct for attack samples (label >= 0.5).
                if s.label >= 0.5 {
                    correct += 1;
                }
            }
            TreeDecision::Allow => {
                decided += 1;
                // Allow is correct for normal samples.
                if s.label < 0.5 {
                    correct += 1;
                }
            }
        }
    }

    let accuracy = if decided > 0 {
        correct as f64 / decided as f64
    } else {
        0.0
    };
    let defer_rate = deferred as f64 / val_set.len() as f64;
    (accuracy, defer_rate)
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// MLP training
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Trains the burn MLP with Adam on weighted binary cross-entropy.
///
/// Features are min-max normalized once up front; each epoch then walks
/// sequential (unshuffled) mini-batches of `batch_size`. Progress (average
/// loss + validation accuracy) is logged on the first epoch and every
/// tenth one. Returns the non-autodiff copy of the model for weight
/// extraction.
fn train_mlp(
    train_set: &[TrainingSample],
    val_set: &[TrainingSample],
    config: &MlpConfig,
    mins: &[f32],
    maxs: &[f32],
    epochs: usize,
    learning_rate: f64,
    batch_size: usize,
    device: &<TrainBackend as Backend>::Device,
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
    let mut model = config.init::<TrainBackend>(device);
    let mut optim = AdamConfig::new().init();

    // Pre-normalize all training data; batches below just slice into these.
    let train_features: Vec<Vec<f32>> = train_set
        .iter()
        .map(|s| normalize_features(&s.features, mins, maxs))
        .collect();
    let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
    let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();

    let n = train_features.len();

    for epoch in 0..epochs {
        let mut epoch_loss = 0.0f32;
        let mut batches = 0usize;

        let mut offset = 0;
        while offset < n {
            // The final batch may be short; batch_n tracks the real size.
            let end = (offset + batch_size).min(n);
            let batch_n = end - offset;

            // Build input tensor [batch, features].
            let flat: Vec<f32> = train_features[offset..end]
                .iter()
                .flat_map(|f| f.iter().copied())
                .collect();
            let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
                .reshape([batch_n, NUM_FEATURES]);

            // Labels [batch, 1].
            let y = Tensor::<TrainBackend, 1>::from_floats(
                &train_labels[offset..end],
                device,
            )
            .reshape([batch_n, 1]);

            // Sample weights [batch, 1].
            let w = Tensor::<TrainBackend, 1>::from_floats(
                &train_weights[offset..end],
                device,
            )
            .reshape([batch_n, 1]);

            // Forward pass.
            let pred = model.forward(x);

            // Binary cross-entropy with sample weights:
            //   loss = mean(-w * [y*log(p) + (1-y)*log(1-p)])
            // Predictions are clamped away from 0/1 so log() stays finite.
            let eps = 1e-7;
            let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
            let bce = (y.clone() * pred_clamped.clone().log()
                + (y.clone().neg().add_scalar(1.0))
                    * pred_clamped.neg().add_scalar(1.0).log())
            .neg();
            let weighted_bce = bce * w;
            let loss = weighted_bce.mean();

            epoch_loss += loss.clone().into_scalar().elem::<f32>();
            batches += 1;

            // Backward + optimizer step.
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optim.step(learning_rate, model, grads);

            offset = end;
        }

        // Log on the first epoch and every 10th thereafter.
        if (epoch + 1) % 10 == 0 || epoch == 0 {
            let avg_loss = epoch_loss / batches as f32;
            let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
            println!(
                "[scanner] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
                epoch + 1,
                epochs,
                avg_loss,
                val_acc,
            );
        }
    }

    // Return the inner (non-autodiff) model for weight extraction.
    model.valid()
}
|
||||||
|
|
||||||
|
/// Computes plain (unweighted) classification accuracy of the MLP on the
/// validation set, using a fixed 0.5 decision threshold.
///
/// Returns the fraction of samples whose thresholded prediction matches the
/// label. NOTE(review): yields NaN when `val_set` is empty (division by
/// zero) — callers must supply a non-empty validation split.
fn eval_mlp_accuracy(
    model: &crate::training::mlp::MlpModel<TrainBackend>,
    val_set: &[TrainingSample],
    mins: &[f32],
    maxs: &[f32],
    device: &<TrainBackend as Backend>::Device,
) -> f64 {
    // Normalize and flatten the whole validation set into one
    // [n * NUM_FEATURES] buffer so we can do a single batched forward pass.
    let flat: Vec<f32> = val_set
        .iter()
        .flat_map(|s| normalize_features(&s.features, mins, maxs))
        .collect();
    let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
        .reshape([val_set.len(), NUM_FEATURES]);

    let pred = model.forward(x);
    let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");

    let mut correct = 0usize;
    for (i, s) in val_set.iter().enumerate() {
        let p = pred_data[i];
        // Threshold at 0.5; labels are 0.0/1.0, so comparing with a 0.1
        // tolerance is a robust float equality check.
        let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
        if (predicted_label - s.label).abs() < 0.1 {
            correct += 1;
        }
    }
    correct as f64 / val_set.len() as f64
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Weight extraction
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Pulls the trained MLP parameters out of the burn model and packs them,
/// together with the CART tree nodes and normalization constants, into an
/// `ExportedModel` ready for code generation.
fn extract_weights(
    model: &crate::training::mlp::MlpModel<NdArray<f32>>,
    name: &str,
    tree_nodes: &[(u8, f32, u16, u16)],
    threshold: f32,
    norm_mins: &[f32],
    norm_maxs: &[f32],
    _device: &<NdArray<f32> as Backend>::Device,
) -> ExportedModel {
    // Extract weight tensors from the model.
    // linear1.weight: [hidden_dim, input_dim]
    // linear1.bias: [hidden_dim]
    // linear2.weight: [1, hidden_dim]
    // linear2.bias: [1]
    let w1_tensor = model.linear1.weight.val();
    let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
    let w2_tensor = model.linear2.weight.val();
    let b2_tensor = model.linear2.bias.as_ref().expect("linear2 has bias").val();

    // Flatten each tensor to a contiguous Vec<f32>.
    let w1_data: Vec<f32> = w1_tensor.to_data().to_vec().expect("w1 flat");
    let b1_data: Vec<f32> = b1_tensor.to_data().to_vec().expect("b1 flat");
    let w2_data: Vec<f32> = w2_tensor.to_data().to_vec().expect("w2 flat");
    let b2_data: Vec<f32> = b2_tensor.to_data().to_vec().expect("b2 flat");

    // Dimensions are recovered from the flat buffers: the hidden-layer bias
    // length gives hidden_dim, and w1 holds hidden_dim * input_dim elements.
    let hidden_dim = b1_data.len();
    let input_dim = w1_data.len() / hidden_dim;

    // Reshape W1 into [hidden_dim][input_dim].
    let w1: Vec<Vec<f32>> = (0..hidden_dim)
        .map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec())
        .collect();

    ExportedModel {
        model_name: name.to_string(),
        input_dim,
        hidden_dim,
        w1,
        b1: b1_data,
        w2: w2_data,
        b2: b2_data[0], // single-output layer: exactly one bias scalar
        tree_nodes: tree_nodes.to_vec(),
        threshold,
        norm_mins: norm_mins.to_vec(),
        norm_maxs: norm_maxs.to_vec(),
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::sample::{DataSource, TrainingSample};

    // Builds a minimal 12-feature scanner sample with unit weight.
    fn make_scanner_sample(features: [f32; 12], label: f32) -> TrainingSample {
        TrainingSample {
            features: features.to_vec(),
            label,
            source: DataSource::ProductionLogs,
            weight: 1.0,
        }
    }

    #[test]
    fn test_stratified_split_preserves_ratio() {
        // 80 normal + 20 attack samples.
        let mut samples = Vec::new();
        for _ in 0..80 {
            samples.push(make_scanner_sample([0.0; 12], 0.0));
        }
        for _ in 0..20 {
            samples.push(make_scanner_sample([1.0; 12], 1.0));
        }

        let (train, val) = stratified_split(&samples, 0.8);

        let train_attacks = train.iter().filter(|s| s.label >= 0.5).count();
        let val_attacks = val.iter().filter(|s| s.label >= 0.5).count();

        // Should preserve the 80/20 attack ratio approximately.
        assert_eq!(train_attacks, 16); // 80% of 20
        assert_eq!(val_attacks, 4); // 20% of 20
        assert_eq!(train.len() + val.len(), 100);
    }

    // Per-feature mins/maxs must reflect the observed value ranges.
    #[test]
    fn test_norm_params() {
        let samples = vec![
            make_scanner_sample([0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 0.0),
            make_scanner_sample([1.0, 20.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 1.0),
        ];
        let (mins, maxs) = compute_norm_params(&samples);
        assert_eq!(mins[0], 0.0);
        assert_eq!(maxs[0], 1.0);
        assert_eq!(mins[1], 10.0);
        assert_eq!(maxs[1], 20.0);
    }

    // Midpoint values normalize to 0.5 with these bounds.
    #[test]
    fn test_normalize_features() {
        let mins = vec![0.0, 10.0];
        let maxs = vec![1.0, 20.0];
        let normed = normalize_features(&[0.5, 15.0], &mins, &maxs);
        assert!((normed[0] - 0.5).abs() < 1e-6);
        assert!((normed[1] - 0.5).abs() < 1e-6);
    }
}
|
||||||
409
src/training/tree.rs
Normal file
409
src/training/tree.rs
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
//! CART decision tree trainer (pure Rust, no burn dependency).
|
||||||
|
//!
|
||||||
|
//! Trains a binary classification tree using Gini impurity and outputs
|
||||||
|
//! the packed node format used by `crate::ensemble::tree` for zero-alloc
|
||||||
|
//! inference.
|
||||||
|
|
||||||
|
use crate::dataset::sample::TrainingSample;
|
||||||
|
|
||||||
|
/// Packed tree node matching the inference format in `crate::ensemble::tree`.
///
/// `(feature_index, threshold, left_child, right_child)`
///
/// Leaf nodes use `feature_index = 255`. The threshold encodes the decision:
/// - `0.0` = Allow
/// - `0.5` = Defer
/// - `1.0` = Block
///
/// (`tree_predict` decodes leaves with `< 0.25` / `> 0.75` cutoffs, so the
/// three encoded values sit well inside each band.)
pub type PackedNode = (u8, f32, u16, u16);
|
||||||
|
|
||||||
|
/// Decision from a tree leaf node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeDecision {
    /// Classify as attack traffic.
    Block,
    /// Classify as benign traffic.
    Allow,
    /// The leaf was not pure enough to decide; the ensemble falls through
    /// to the MLP at inference time.
    Defer,
}
|
||||||
|
|
||||||
|
/// Configuration for CART tree training.
pub struct TreeConfig {
    /// Maximum tree depth (typically 6-8).
    pub max_depth: usize,
    /// Minimum number of samples required in a leaf; candidate splits that
    /// would leave fewer on either side are rejected.
    pub min_samples_leaf: usize,
    /// Leaf purity threshold: if the dominant class ratio is below this,
    /// the leaf becomes `Defer` (e.g. 0.90).
    pub min_purity: f32,
    /// Number of input features (12 for scanner, 14 for DDoS).
    pub num_features: usize,
}
|
||||||
|
|
||||||
|
/// Internal representation during tree construction.
///
/// Built recursively by `build_node`, then serialized into the packed
/// `PackedNode` format by `flatten`.
#[derive(Debug)]
enum BuildNode {
    /// Terminal node carrying the final verdict.
    Leaf {
        decision: TreeDecision,
    },
    /// Internal binary split: samples with `features[feature] <= threshold`
    /// go left, the rest go right.
    Split {
        feature: usize,
        threshold: f32,
        left: Box<BuildNode>,
        right: Box<BuildNode>,
    },
}
|
||||||
|
|
||||||
|
/// Train a CART decision tree and return packed nodes.
|
||||||
|
pub fn train_tree(samples: &[TrainingSample], config: &TreeConfig) -> Vec<PackedNode> {
|
||||||
|
let indices: Vec<usize> = (0..samples.len()).collect();
|
||||||
|
let root = build_node(samples, &indices, config, 0);
|
||||||
|
let mut packed = Vec::new();
|
||||||
|
flatten(&root, &mut packed);
|
||||||
|
packed
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Walk a packed decision tree for validation (mirrors `crate::ensemble::tree::tree_predict`).
|
||||||
|
pub fn tree_predict(nodes: &[PackedNode], features: &[f32]) -> TreeDecision {
|
||||||
|
let mut idx = 0usize;
|
||||||
|
loop {
|
||||||
|
let (feature, threshold, left, right) = nodes[idx];
|
||||||
|
if feature == 255 {
|
||||||
|
return if threshold < 0.25 {
|
||||||
|
TreeDecision::Allow
|
||||||
|
} else if threshold > 0.75 {
|
||||||
|
TreeDecision::Block
|
||||||
|
} else {
|
||||||
|
TreeDecision::Defer
|
||||||
|
};
|
||||||
|
}
|
||||||
|
idx = if features[feature as usize] <= threshold {
|
||||||
|
left as usize
|
||||||
|
} else {
|
||||||
|
right as usize
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tree construction
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Recursively grows one subtree over `indices` (positions into `samples`).
///
/// Standard CART: choose the (feature, threshold) split maximizing weighted
/// Gini gain; stop on max depth, too few samples, zero total weight,
/// sufficient class purity, or when no split improves on the parent.
fn build_node(
    samples: &[TrainingSample],
    indices: &[usize],
    config: &TreeConfig,
    depth: usize,
) -> BuildNode {
    // Count attacks vs normal (weighted by per-sample weight).
    let (attack_w, normal_w) = weighted_counts(samples, indices);
    let total_w = attack_w + normal_w;

    // Stopping conditions: max depth, min samples, or pure-enough leaf.
    if depth >= config.max_depth
        || indices.len() < 2 * config.min_samples_leaf
        || total_w < f32::EPSILON
    {
        return make_leaf(attack_w, normal_w, config.min_purity);
    }

    // Already pure enough — no point splitting further.
    let attack_ratio = attack_w / total_w;
    let normal_ratio = normal_w / total_w;
    if attack_ratio >= config.min_purity || normal_ratio >= config.min_purity {
        return make_leaf(attack_w, normal_w, config.min_purity);
    }

    // Find best split across all features.
    let parent_gini = gini(attack_w, normal_w);
    let mut best_gain = 0.0f32;
    let mut best_feature = 0usize;
    let mut best_threshold = 0.0f32;
    let mut best_left: Vec<usize> = Vec::new();
    let mut best_right: Vec<usize> = Vec::new();

    for feat in 0..config.num_features {
        // Gather and sort feature values so thresholds can be scanned in
        // one pass with running left-side class weights.
        let mut vals: Vec<(f32, usize)> = indices
            .iter()
            .map(|&i| (samples[i].features[feat], i))
            .collect();
        // NaNs (if any) compare as Equal and simply don't reorder.
        vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        // Scan for best threshold (midpoints between distinct values).
        let mut left_attack_w = 0.0f32;
        let mut left_normal_w = 0.0f32;

        // NOTE(review): `vals.len() - 1` assumes at least one sample here,
        // which the `2 * min_samples_leaf` guard above ensures for any
        // positive min_samples_leaf.
        for window_end in 0..vals.len() - 1 {
            // Fold this sample into the running left-partition weights.
            let (_, idx) = vals[window_end];
            let s = &samples[idx];
            if s.label >= 0.5 {
                left_attack_w += s.weight;
            } else {
                left_normal_w += s.weight;
            }

            // Skip if the next value is the same (no valid split point).
            if (vals[window_end].0 - vals[window_end + 1].0).abs() < f32::EPSILON {
                continue;
            }

            // Check min_samples_leaf constraint.
            let left_count = window_end + 1;
            let right_count = vals.len() - left_count;
            if left_count < config.min_samples_leaf || right_count < config.min_samples_leaf {
                continue;
            }

            // Right-partition weights follow from the totals.
            let right_attack_w = attack_w - left_attack_w;
            let right_normal_w = normal_w - left_normal_w;
            let left_total = left_attack_w + left_normal_w;
            let right_total = right_attack_w + right_normal_w;

            // Gini gain of this candidate split.
            let left_gini = gini(left_attack_w, left_normal_w);
            let right_gini = gini(right_attack_w, right_normal_w);
            let weighted_gini =
                (left_total / total_w) * left_gini + (right_total / total_w) * right_gini;
            let gain = parent_gini - weighted_gini;

            if gain > best_gain {
                best_gain = gain;
                best_feature = feat;
                // Threshold is the midpoint between the two adjacent values.
                best_threshold = (vals[window_end].0 + vals[window_end + 1].0) / 2.0;
                best_left = vals[..=window_end].iter().map(|v| v.1).collect();
                best_right = vals[window_end + 1..].iter().map(|v| v.1).collect();
            }
        }
    }

    // If no informative split was found, make a leaf.
    if best_gain <= 0.0 || best_left.is_empty() || best_right.is_empty() {
        return make_leaf(attack_w, normal_w, config.min_purity);
    }

    let left_child = build_node(samples, &best_left, config, depth + 1);
    let right_child = build_node(samples, &best_right, config, depth + 1);

    BuildNode::Split {
        feature: best_feature,
        threshold: best_threshold,
        left: Box::new(left_child),
        right: Box::new(right_child),
    }
}
|
||||||
|
|
||||||
|
fn make_leaf(attack_w: f32, normal_w: f32, min_purity: f32) -> BuildNode {
|
||||||
|
let total = attack_w + normal_w;
|
||||||
|
let decision = if total < f32::EPSILON {
|
||||||
|
TreeDecision::Defer
|
||||||
|
} else {
|
||||||
|
let attack_ratio = attack_w / total;
|
||||||
|
let normal_ratio = normal_w / total;
|
||||||
|
if attack_ratio >= min_purity {
|
||||||
|
TreeDecision::Block
|
||||||
|
} else if normal_ratio >= min_purity {
|
||||||
|
TreeDecision::Allow
|
||||||
|
} else {
|
||||||
|
TreeDecision::Defer
|
||||||
|
}
|
||||||
|
};
|
||||||
|
BuildNode::Leaf { decision }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gini impurity for a two-class node with weighted counts.
///
/// Returns 0.0 for a pure (or empty) node and peaks at 0.5 when the
/// two classes carry equal weight.
fn gini(class_a: f32, class_b: f32) -> f32 {
    let total = class_a + class_b;
    if total < f32::EPSILON {
        // An empty node is treated as pure so it never looks attractive.
        return 0.0;
    }
    let p_a = class_a / total;
    let p_b = class_b / total;
    1.0 - p_a * p_a - p_b * p_b
}
|
||||||
|
|
||||||
|
fn weighted_counts(samples: &[TrainingSample], indices: &[usize]) -> (f32, f32) {
|
||||||
|
let mut attack = 0.0f32;
|
||||||
|
let mut normal = 0.0f32;
|
||||||
|
for &i in indices {
|
||||||
|
let s = &samples[i];
|
||||||
|
if s.label >= 0.5 {
|
||||||
|
attack += s.weight;
|
||||||
|
} else {
|
||||||
|
normal += s.weight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(attack, normal)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Flatten the recursive `BuildNode` tree into a `Vec<PackedNode>` using
|
||||||
|
/// BFS-order indexing.
|
||||||
|
fn flatten(node: &BuildNode, out: &mut Vec<PackedNode>) {
|
||||||
|
match node {
|
||||||
|
BuildNode::Leaf { decision } => {
|
||||||
|
let threshold = match decision {
|
||||||
|
TreeDecision::Allow => 0.0,
|
||||||
|
TreeDecision::Defer => 0.5,
|
||||||
|
TreeDecision::Block => 1.0,
|
||||||
|
};
|
||||||
|
out.push((255, threshold, 0, 0));
|
||||||
|
}
|
||||||
|
BuildNode::Split {
|
||||||
|
feature,
|
||||||
|
threshold,
|
||||||
|
left,
|
||||||
|
right,
|
||||||
|
} => {
|
||||||
|
// Reserve this node's position, then recursively flatten children.
|
||||||
|
let self_idx = out.len();
|
||||||
|
out.push((0, 0.0, 0, 0)); // placeholder
|
||||||
|
|
||||||
|
let left_idx = out.len();
|
||||||
|
flatten(left, out);
|
||||||
|
|
||||||
|
let right_idx = out.len();
|
||||||
|
flatten(right, out);
|
||||||
|
|
||||||
|
out[self_idx] = (
|
||||||
|
*feature as u8,
|
||||||
|
*threshold,
|
||||||
|
left_idx as u16,
|
||||||
|
right_idx as u16,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::sample::{DataSource, TrainingSample};

    /// Unit-weight production-log sample with the given features/label.
    fn sample(features: Vec<f32>, label: f32) -> TrainingSample {
        TrainingSample {
            features,
            label,
            source: DataSource::ProductionLogs,
            weight: 1.0,
        }
    }

    #[test]
    fn test_trivially_separable() {
        // Feature 0 < 0.5 => Allow (label 0), >= 0.5 => Block (label 1)
        let data: Vec<TrainingSample> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                let label = if x < 0.5 { 0.0 } else { 1.0 };
                sample(vec![x, 0.0], label)
            })
            .collect();

        let cfg = TreeConfig {
            max_depth: 6,
            min_samples_leaf: 1,
            min_purity: 0.90,
            num_features: 2,
        };

        let tree = train_tree(&data, &cfg);
        assert!(!tree.is_empty());

        // Low values allow, high values block.
        assert_eq!(tree_predict(&tree, &[0.1, 0.0]), TreeDecision::Allow);
        assert_eq!(tree_predict(&tree, &[0.9, 0.0]), TreeDecision::Block);
    }

    #[test]
    fn test_defer_for_mixed_region() {
        // Two clean halves plus deliberate label noise around the boundary.
        let mut data: Vec<TrainingSample> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                sample(vec![x], if i < 50 { 0.0 } else { 1.0 })
            })
            .collect();
        // Noise: attacks just below 0.5 and normals just above it.
        for _ in 0..20 {
            data.push(sample(vec![0.45], 1.0));
            data.push(sample(vec![0.55], 0.0));
        }

        let cfg = TreeConfig {
            max_depth: 3,
            min_samples_leaf: 5,
            min_purity: 0.95, // Very high purity requirement.
            num_features: 1,
        };

        let tree = train_tree(&data, &cfg);

        // The boundary region may land on any decision depending on where
        // the split falls, but it must at least be a valid one.
        let boundary = tree_predict(&tree, &[0.50]);
        assert!(matches!(
            boundary,
            TreeDecision::Allow | TreeDecision::Block | TreeDecision::Defer
        ));

        // The extremes should be unambiguous.
        assert_eq!(tree_predict(&tree, &[0.05]), TreeDecision::Allow);
        assert_eq!(tree_predict(&tree, &[0.95]), TreeDecision::Block);
    }

    #[test]
    fn test_max_depth_enforcement() {
        // Perfectly separable data must still respect the depth cap.
        let data: Vec<TrainingSample> = (0..200)
            .map(|i| {
                let x = i as f32 / 200.0;
                sample(vec![x, 0.0, 0.0, 0.0], if x < 0.5 { 0.0 } else { 1.0 })
            })
            .collect();

        let cfg = TreeConfig {
            max_depth: 2,
            min_samples_leaf: 1,
            min_purity: 0.90,
            num_features: 4,
        };

        let tree = train_tree(&data, &cfg);

        // max_depth=2 caps the tree at 2^3 - 1 = 7 nodes.
        assert!(
            tree.len() <= 7,
            "tree should have at most 7 nodes at depth 2, got {}",
            tree.len()
        );
    }

    #[test]
    fn test_single_class_becomes_leaf() {
        // A single-class dataset collapses to one Block leaf immediately.
        let data: Vec<TrainingSample> = (0..50)
            .map(|i| sample(vec![i as f32 / 50.0], 1.0))
            .collect();

        let cfg = TreeConfig {
            max_depth: 6,
            min_samples_leaf: 1,
            min_purity: 0.90,
            num_features: 1,
        };

        let tree = train_tree(&data, &cfg);
        assert_eq!(tree.len(), 1); // Just a single leaf.
        assert_eq!(tree_predict(&tree, &[0.5]), TreeDecision::Block);
    }

    #[test]
    fn test_gini_pure() {
        // Pure nodes have zero impurity.
        assert!(gini(10.0, 0.0).abs() < 1e-6);
        assert!(gini(0.0, 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_gini_max() {
        // Balanced binary split hits the maximum: 1 - 2*(0.5^2) = 0.5
        assert!((gini(5.0, 5.0) - 0.5).abs() < 1e-6);
    }
}
|
||||||
Reference in New Issue
Block a user