// Copyright Sunbeam Studios 2026 // SPDX-License-Identifier: Apache-2.0 //! Scanner MLP+tree training loop using burn's SupervisedTraining. //! //! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP //! with cosine annealing + early stopping, then exports the combined ensemble //! weights as a Rust source file for `src/ensemble/gen/scanner_weights.rs`. use anyhow::{Context, Result}; use std::path::Path; use burn::backend::Autodiff; use burn::backend::Wgpu; use burn::data::dataloader::DataLoaderBuilder; use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig; use burn::optim::AdamConfig; use burn::prelude::*; use burn::record::CompactRecorder; use burn::train::metric::{AccuracyMetric, LossMetric}; use burn::train::{Learner, SupervisedTraining}; use crate::dataset::sample::{load_dataset, TrainingSample}; use crate::training::batch::{SampleBatcher, SampleDataset}; use crate::training::export::{export_to_file, ExportedModel}; use crate::training::mlp::MlpConfig; use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision}; /// Number of scanner features (matches `crate::scanner::features::NUM_SCANNER_FEATURES`). const NUM_FEATURES: usize = 12; type TrainBackend = Autodiff>; /// Arguments for the scanner MLP training command. pub struct TrainScannerMlpArgs { /// Path to a bincode `DatasetManifest` file. pub dataset_path: String, /// Directory to write output files (Rust source, model record). pub output_dir: String, /// Hidden layer width (default 32). pub hidden_dim: usize, /// Number of training epochs (default 100). pub epochs: usize, /// Adam learning rate (default 0.001). pub learning_rate: f64, /// Mini-batch size (default 64). pub batch_size: usize, /// CART max depth (default 8). pub tree_max_depth: usize, /// CART leaf purity threshold (default 0.98). pub tree_min_purity: f32, /// Minimum samples in a leaf node (default 2). pub min_samples_leaf: usize, /// Weight for cookie feature (feature 3: has_cookies). 0.0 = ignore, 1.0 = full weight. pub cookie_weight: f32, } impl Default for TrainScannerMlpArgs { fn default() -> Self { Self { dataset_path: String::new(), output_dir: ".".into(), hidden_dim: 32, epochs: 100, learning_rate: 0.0001, batch_size: 64, tree_max_depth: 8, tree_min_purity: 0.98, min_samples_leaf: 2, cookie_weight: 1.0, } } } /// Index of the has_cookies feature in the scanner feature vector. const COOKIE_FEATURE_IDX: usize = 3; /// Entry point: train scanner ensemble and export weights. pub fn run(args: TrainScannerMlpArgs) -> Result<()> { // 1. Load dataset. let manifest = load_dataset(Path::new(&args.dataset_path)) .context("loading dataset manifest")?; let samples = &manifest.scanner_samples; anyhow::ensure!(!samples.is_empty(), "no scanner samples in dataset"); for s in samples { anyhow::ensure!( s.features.len() == NUM_FEATURES, "expected {} features, got {}", NUM_FEATURES, s.features.len() ); } println!( "[scanner] loaded {} samples ({} attack, {} normal)", samples.len(), samples.iter().filter(|s| s.label >= 0.5).count(), samples.iter().filter(|s| s.label < 0.5).count(), ); // 2. Compute normalization params from training data. let (norm_mins, norm_maxs) = compute_norm_params(samples); // Apply cookie_weight: for the MLP, we scale the normalization range so // the feature contributes less gradient signal. For the CART tree, scaling // doesn't help (the tree just adjusts its threshold), so we mask the feature // to a constant on a fraction of training samples to degrade its Gini gain. if args.cookie_weight < 1.0 - f32::EPSILON { println!( "[scanner] cookie_weight={:.2} (feature {} influence reduced)", args.cookie_weight, COOKIE_FEATURE_IDX, ); } // MLP norm adjustment: scale the cookie feature's normalization range. let mut mlp_norm_maxs = norm_maxs.clone(); if args.cookie_weight < 1.0 - f32::EPSILON { let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX]; if range > f32::EPSILON && args.cookie_weight > f32::EPSILON { mlp_norm_maxs[COOKIE_FEATURE_IDX] = range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX]; } } // 3. Stratified 80/20 split. let (train_set, val_set) = stratified_split(samples, 0.8); println!( "[scanner] train={}, val={}", train_set.len(), val_set.len() ); // 4. Train CART tree (with cookie feature masking for reduced weight). let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight); let tree_config = TreeConfig { max_depth: args.tree_max_depth, min_samples_leaf: args.min_samples_leaf, min_purity: args.tree_min_purity, num_features: NUM_FEATURES, }; let tree_nodes = train_tree(&tree_train_set, &tree_config); println!("[scanner] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth); // Evaluate tree on validation set (use original norms — tree learned on masked features). let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs); println!( "[scanner] tree validation: {:.2}% correct (of decided), {:.1}% deferred", tree_correct * 100.0, tree_deferred * 100.0, ); // 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling). let device = Default::default(); let mlp_config = MlpConfig { input_dim: NUM_FEATURES, hidden_dim: args.hidden_dim, }; let artifact_dir = Path::new(&args.output_dir).join("scanner_artifacts"); std::fs::create_dir_all(&artifact_dir).ok(); let model = train_mlp( &train_set, &val_set, &mlp_config, &norm_mins, &mlp_norm_maxs, args.epochs, args.learning_rate, args.batch_size, &device, &artifact_dir, ); // 6. Extract weights from trained model (export mlp_norm_maxs so inference // automatically applies the same cookie scaling). let exported = extract_weights( &model, "scanner", &tree_nodes, 0.5, &norm_mins, &mlp_norm_maxs, &device, ); // 7. Write output. let out_dir = Path::new(&args.output_dir); std::fs::create_dir_all(out_dir).context("creating output directory")?; let rust_path = out_dir.join("scanner_weights.rs"); export_to_file(&exported, &rust_path)?; println!("[scanner] exported Rust weights to {}", rust_path.display()); Ok(()) } // --------------------------------------------------------------------------- // Cookie feature masking for CART trees // --------------------------------------------------------------------------- /// Mask the cookie feature to reduce its influence on CART tree training. /// /// Scaling a binary feature doesn't reduce its Gini gain — the tree just adjusts /// the split threshold. Instead, we mask (set to 0.5) a fraction of samples so /// the feature's apparent class-separation degrades. /// /// - `cookie_weight = 0.0` → fully masked (feature is constant 0.5, zero info gain) /// - `cookie_weight = 0.5` → 50% of samples masked (noisy, reduced gain) /// - `cookie_weight = 1.0` → no masking (full feature) fn mask_cookie_feature( samples: &[TrainingSample], cookie_idx: usize, cookie_weight: f32, ) -> Vec { if cookie_weight >= 1.0 - f32::EPSILON { return samples.to_vec(); } samples .iter() .enumerate() .map(|(i, s)| { let mut s2 = s.clone(); if cookie_weight < f32::EPSILON { s2.features[cookie_idx] = 0.5; } else { let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42); let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32; if r > cookie_weight { s2.features[cookie_idx] = 0.5; } } s2 }) .collect() } // --------------------------------------------------------------------------- // Normalization // --------------------------------------------------------------------------- fn compute_norm_params(samples: &[TrainingSample]) -> (Vec, Vec) { let dim = samples[0].features.len(); let mut mins = vec![f32::MAX; dim]; let mut maxs = vec![f32::MIN; dim]; for s in samples { for i in 0..dim { mins[i] = mins[i].min(s.features[i]); maxs[i] = maxs[i].max(s.features[i]); } } (mins, maxs) } // --------------------------------------------------------------------------- // Stratified split // --------------------------------------------------------------------------- fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec, Vec) { let mut attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label >= 0.5).collect(); let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect(); deterministic_shuffle(&mut attacks); deterministic_shuffle(&mut normals); let attack_split = (attacks.len() as f64 * train_ratio) as usize; let normal_split = (normals.len() as f64 * train_ratio) as usize; let mut train = Vec::new(); let mut val = Vec::new(); for (i, s) in attacks.iter().enumerate() { if i < attack_split { train.push((*s).clone()); } else { val.push((*s).clone()); } } for (i, s) in normals.iter().enumerate() { if i < normal_split { train.push((*s).clone()); } else { val.push((*s).clone()); } } (train, val) } fn deterministic_shuffle(items: &mut [T]) { let mut rng = 42u64; for i in (1..items.len()).rev() { rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); let j = (rng >> 33) as usize % (i + 1); items.swap(i, j); } } // --------------------------------------------------------------------------- // Tree evaluation // --------------------------------------------------------------------------- fn eval_tree( nodes: &[(u8, f32, u16, u16)], val_set: &[TrainingSample], mins: &[f32], maxs: &[f32], ) -> (f64, f64) { let mut decided = 0usize; let mut correct = 0usize; let mut deferred = 0usize; for s in val_set { let normed = normalize_features(&s.features, mins, maxs); let decision = tree_predict(nodes, &normed); match decision { TreeDecision::Defer => deferred += 1, TreeDecision::Block => { decided += 1; if s.label >= 0.5 { correct += 1; } } TreeDecision::Allow => { decided += 1; if s.label < 0.5 { correct += 1; } } } } let accuracy = if decided > 0 { correct as f64 / decided as f64 } else { 0.0 }; let defer_rate = deferred as f64 / val_set.len() as f64; (accuracy, defer_rate) } fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec { features .iter() .enumerate() .map(|(i, &v)| { let range = maxs[i] - mins[i]; if range > f32::EPSILON { ((v - mins[i]) / range).clamp(0.0, 1.0) } else { 0.0 } }) .collect() } // --------------------------------------------------------------------------- // MLP training via SupervisedTraining // --------------------------------------------------------------------------- fn train_mlp( train_set: &[TrainingSample], val_set: &[TrainingSample], config: &MlpConfig, mins: &[f32], maxs: &[f32], epochs: usize, learning_rate: f64, batch_size: usize, device: &::Device, artifact_dir: &Path, ) -> crate::training::mlp::MlpModel> { let model = config.init::(device); let train_dataset = SampleDataset::new(train_set, mins, maxs); let val_dataset = SampleDataset::new(val_set, mins, maxs); let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new()) .batch_size(batch_size) .shuffle(42) .num_workers(1) .build(train_dataset); let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new()) .batch_size(batch_size) .num_workers(1) .build(val_dataset); // Cosine annealing: initial_lr must be in (0.0, 1.0]. let lr = learning_rate.min(1.0); let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs) .init() .expect("valid cosine annealing config"); let learner = Learner::new( model, AdamConfig::new().init(), lr_scheduler, ); let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid) .metric_train_numeric(AccuracyMetric::new()) .metric_valid_numeric(AccuracyMetric::new()) .metric_train_numeric(LossMetric::new()) .metric_valid_numeric(LossMetric::new()) .with_file_checkpointer(CompactRecorder::new()) .num_epochs(epochs) .summary() .launch(learner); result.model } // --------------------------------------------------------------------------- // Weight extraction // --------------------------------------------------------------------------- fn extract_weights( model: &crate::training::mlp::MlpModel>, name: &str, tree_nodes: &[(u8, f32, u16, u16)], threshold: f32, norm_mins: &[f32], norm_maxs: &[f32], _device: & as Backend>::Device, ) -> ExportedModel { let w1_tensor = model.linear1.weight.val(); let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val(); let w2_tensor = model.linear2.weight.val(); let b2_tensor = model.linear2.bias.as_ref().expect("linear2 has bias").val(); let w1_data: Vec = w1_tensor.to_data().to_vec().expect("w1 flat"); let b1_data: Vec = b1_tensor.to_data().to_vec().expect("b1 flat"); let w2_data: Vec = w2_tensor.to_data().to_vec().expect("w2 flat"); let b2_data: Vec = b2_tensor.to_data().to_vec().expect("b2 flat"); let hidden_dim = b1_data.len(); let input_dim = w1_data.len() / hidden_dim; let w1: Vec> = (0..hidden_dim) .map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec()) .collect(); ExportedModel { model_name: name.to_string(), input_dim, hidden_dim, w1, b1: b1_data, w2: w2_data, b2: b2_data[0], tree_nodes: tree_nodes.to_vec(), threshold, norm_mins: norm_mins.to_vec(), norm_maxs: norm_maxs.to_vec(), } } #[cfg(test)] mod tests { use super::*; use crate::dataset::sample::{DataSource, TrainingSample}; fn make_scanner_sample(features: [f32; 12], label: f32) -> TrainingSample { TrainingSample { features: features.to_vec(), label, source: DataSource::ProductionLogs, weight: 1.0, } } #[test] fn test_stratified_split_preserves_ratio() { let mut samples = Vec::new(); for _ in 0..80 { samples.push(make_scanner_sample([0.0; 12], 0.0)); } for _ in 0..20 { samples.push(make_scanner_sample([1.0; 12], 1.0)); } let (train, val) = stratified_split(&samples, 0.8); let train_attacks = train.iter().filter(|s| s.label >= 0.5).count(); let val_attacks = val.iter().filter(|s| s.label >= 0.5).count(); assert_eq!(train_attacks, 16); assert_eq!(val_attacks, 4); assert_eq!(train.len() + val.len(), 100); } #[test] fn test_norm_params() { let samples = vec![ make_scanner_sample([0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 0.0), make_scanner_sample([1.0, 20.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 1.0), ]; let (mins, maxs) = compute_norm_params(&samples); assert_eq!(mins[0], 0.0); assert_eq!(maxs[0], 1.0); assert_eq!(mins[1], 10.0); assert_eq!(maxs[1], 20.0); } }