From 5daed3ecb08d07ab655e80b4592389f39b9185f8 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Tue, 10 Mar 2026 23:38:21 +0000 Subject: [PATCH] chore: update scanner/ddos trainers, benchmarks, and tests Expand DDoS feature vector to 14 dimensions (cookie_ratio, referer_ratio, accept_language_ratio, suspicious_path_ratio). Add heuristic auto-labeling to DDoS trainer. Update benchmarks and tests to match new feature vectors. Signed-off-by: Sienna Meridian Satterwhite --- benches/ddos_bench.rs | 40 ++++ benches/scanner_bench.rs | 35 ++++ src/ddos/train.rs | 389 ++++++++++++++++++++++++--------------- src/scanner/train.rs | 292 +++++++++++++++++++++++++++-- tests/ddos_test.rs | 3 +- 5 files changed, 599 insertions(+), 160 deletions(-) create mode 100644 benches/ddos_bench.rs diff --git a/benches/ddos_bench.rs b/benches/ddos_bench.rs new file mode 100644 index 0000000..2cc524f --- /dev/null +++ b/benches/ddos_bench.rs @@ -0,0 +1,40 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use sunbeam_proxy::ensemble::ddos::ddos_ensemble_predict; +use sunbeam_proxy::ensemble::gen::ddos_weights; +use sunbeam_proxy::ensemble::mlp::mlp_predict_32; +use sunbeam_proxy::ensemble::tree::tree_predict; + +fn bench_ensemble_ddos_full(c: &mut Criterion) { + let raw: [f32; 14] = [5.0, 10.0, 2.0, 0.1, 50.0, 0.5, 3.0, 0.3, 500.0, 2.0, 0.8, 0.7, 0.9, 0.1]; + c.bench_function("ensemble::ddos full predict", |b| { + b.iter(|| ddos_ensemble_predict(black_box(&raw))) + }); +} + +fn bench_ensemble_ddos_tree_only(c: &mut Criterion) { + let input: [f32; 14] = [0.5; 14]; + c.bench_function("ensemble::ddos tree_only", |b| { + b.iter(|| tree_predict(black_box(&ddos_weights::TREE_NODES), black_box(&input))) + }); +} + +fn bench_ensemble_ddos_mlp_only(c: &mut Criterion) { + let input: [f32; 14] = [0.5; 14]; + c.bench_function("ensemble::ddos mlp_only", |b| { + b.iter(|| mlp_predict_32::<14>( + black_box(&ddos_weights::W1), + black_box(&ddos_weights::B1), + 
black_box(&ddos_weights::W2), + black_box(ddos_weights::B2), + black_box(&input), + )) + }); +} + +criterion_group!( + benches, + bench_ensemble_ddos_full, + bench_ensemble_ddos_tree_only, + bench_ensemble_ddos_mlp_only, +); +criterion_main!(benches); diff --git a/benches/scanner_bench.rs b/benches/scanner_bench.rs index ffbfe30..6039cb6 100644 --- a/benches/scanner_bench.rs +++ b/benches/scanner_bench.rs @@ -1,5 +1,9 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use sunbeam_proxy::config::RouteConfig; +use sunbeam_proxy::ensemble::gen::scanner_weights; +use sunbeam_proxy::ensemble::mlp::mlp_predict_32; +use sunbeam_proxy::ensemble::scanner::scanner_ensemble_predict; +use sunbeam_proxy::ensemble::tree::tree_predict; use sunbeam_proxy::scanner::detector::ScannerDetector; use sunbeam_proxy::scanner::features::{ self, fx_hash_bytes, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS, @@ -250,6 +254,34 @@ fn bench_extract_features(c: &mut Criterion) { }); } +fn bench_ensemble_scanner_full(c: &mut Criterion) { + // Raw features simulating a scanner probe + let raw: [f32; 12] = [0.8, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0]; + c.bench_function("ensemble::scanner full predict", |b| { + b.iter(|| scanner_ensemble_predict(black_box(&raw))) + }); +} + +fn bench_ensemble_scanner_tree_only(c: &mut Criterion) { + let input: [f32; 12] = [0.8, 0.3, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0]; + c.bench_function("ensemble::scanner tree_only", |b| { + b.iter(|| tree_predict(black_box(&scanner_weights::TREE_NODES), black_box(&input))) + }); +} + +fn bench_ensemble_scanner_mlp_only(c: &mut Criterion) { + let input: [f32; 12] = [0.5; 12]; + c.bench_function("ensemble::scanner mlp_only", |b| { + b.iter(|| mlp_predict_32::<12>( + black_box(&scanner_weights::W1), + black_box(&scanner_weights::B1), + black_box(&scanner_weights::W2), + black_box(scanner_weights::B2), + black_box(&input), + )) + }); +} + criterion_group!( benches, 
bench_check_normal_browser, @@ -260,5 +292,8 @@ criterion_group!( bench_check_deep_path, bench_check_api_legitimate, bench_extract_features, + bench_ensemble_scanner_full, + bench_ensemble_scanner_tree_only, + bench_ensemble_scanner_mlp_only, ); criterion_main!(benches); diff --git a/src/ddos/train.rs b/src/ddos/train.rs index d9ea323..4c3789f 100644 --- a/src/ddos/train.rs +++ b/src/ddos/train.rs @@ -41,6 +41,34 @@ fn default_no_cookies_threshold() -> f64 { 0.05 } fn default_no_cookies_path_count() -> f64 { 20.0 } fn default_min_events() -> usize { 10 } +impl HeuristicThresholds { + pub fn new( + request_rate: f64, + path_repetition: f64, + error_rate: f64, + suspicious_path_ratio: f64, + no_cookies_threshold: f64, + no_cookies_path_count: f64, + min_events: usize, + ) -> Self { + Self { + request_rate, + path_repetition, + error_rate, + suspicious_path_ratio, + no_cookies_threshold, + no_cookies_path_count, + min_events, + } + } +} + +pub struct DdosTrainResult { + pub model: SerializedModel, + pub attack_count: usize, + pub normal_count: usize, +} + pub struct TrainArgs { pub input: String, pub output: String, @@ -82,147 +110,15 @@ fn parse_timestamp(ts: &str) -> f64 { } -pub fn run(args: TrainArgs) -> Result<()> { - eprintln!("Parsing logs from {}...", args.input); +/// Core training pipeline: parse logs, extract features, label IPs, build KNN model. 
+pub fn train_model(args: &TrainArgs) -> Result { + let ip_states = parse_logs(&args.input)?; - // Parse logs into per-IP state - let mut ip_states: FxHashMap = FxHashMap::default(); - let file = std::fs::File::open(&args.input) - .with_context(|| format!("opening {}", args.input))?; - let reader = std::io::BufReader::new(file); - - let mut total_lines = 0u64; - let mut parse_errors = 0u64; - - for line in reader.lines() { - let line = line?; - total_lines += 1; - let entry: AuditLog = match serde_json::from_str(&line) { - Ok(e) => e, - Err(_) => { - parse_errors += 1; - continue; - } - }; - - // Skip non-audit entries - if entry.fields.method.is_empty() { - continue; - } - - let ip = audit_log::strip_port(&entry.fields.client_ip).to_string(); - let ts = parse_timestamp(&entry.timestamp); - - let state = ip_states.entry(ip).or_default(); - state.timestamps.push(ts); - state.methods.push(method_to_u8(&entry.fields.method)); - state.path_hashes.push(fx_hash(&entry.fields.path)); - state.host_hashes.push(fx_hash(&entry.fields.host)); - state - .user_agent_hashes - .push(fx_hash(&entry.fields.user_agent)); - state.statuses.push(entry.fields.status); - state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32); - state - .content_lengths - .push(entry.fields.content_length.min(u32::MAX as u64) as u32); - state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false)); - state.has_referer.push( - entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false), - ); - state.has_accept_language.push( - entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false), - ); - state.suspicious_paths.push( - crate::ddos::features::is_suspicious_path(&entry.fields.path), - ); - } - - eprintln!( - "Parsed {} lines ({} errors), {} unique IPs", - total_lines, - parse_errors, - ip_states.len() - ); - - // Extract feature vectors per IP (using sliding windows) let window_secs = args.window_secs as f64; - let mut ip_features: FxHashMap> = 
FxHashMap::default(); - - for (ip, state) in &ip_states { - let n = state.timestamps.len(); - if n < args.min_events { - continue; - } - // Extract one feature vector per window - let mut features = Vec::new(); - let mut start = 0; - for end in 1..n { - let span = state.timestamps[end] - state.timestamps[start]; - if span >= window_secs || end == n - 1 { - let fv = state.extract_features_for_window(start, end + 1, window_secs); - features.push(fv); - start = end + 1; - } - } - if !features.is_empty() { - ip_features.insert(ip.clone(), features); - } - } + let ip_features = extract_ip_features(&ip_states, args.min_events, window_secs); // Label IPs - let mut ip_labels: FxHashMap = FxHashMap::default(); - - if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) { - // IP list mode - let attack_ips: FxHashSet = std::fs::read_to_string(attack_file) - .context("reading attack IPs file")? - .lines() - .map(|l| l.trim().to_string()) - .filter(|l| !l.is_empty()) - .collect(); - let normal_ips: FxHashSet = std::fs::read_to_string(normal_file) - .context("reading normal IPs file")? 
- .lines() - .map(|l| l.trim().to_string()) - .filter(|l| !l.is_empty()) - .collect(); - - for ip in ip_features.keys() { - if attack_ips.contains(ip) { - ip_labels.insert(ip.clone(), TrafficLabel::Attack); - } else if normal_ips.contains(ip) { - ip_labels.insert(ip.clone(), TrafficLabel::Normal); - } - } - } else if let Some(heuristics_file) = &args.heuristics { - // Heuristic auto-labeling - let heuristics_str = std::fs::read_to_string(heuristics_file) - .context("reading heuristics file")?; - let thresholds: HeuristicThresholds = - toml::from_str(&heuristics_str).context("parsing heuristics TOML")?; - - for (ip, features) in &ip_features { - // Use the aggregate (last/max) feature vector for labeling - let avg = average_features(features); - let is_attack = avg[0] > thresholds.request_rate // request_rate - || avg[7] > thresholds.path_repetition // path_repetition - || avg[3] > thresholds.error_rate // error_rate - || avg[13] > thresholds.suspicious_path_ratio // suspicious_path_ratio - || (avg[10] < thresholds.no_cookies_threshold // no cookies + high unique paths - && avg[1] > thresholds.no_cookies_path_count); - ip_labels.insert( - ip.clone(), - if is_attack { - TrafficLabel::Attack - } else { - TrafficLabel::Normal - }, - ); - } - } else { - bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling"); - } + let ip_labels = label_ips(args, &ip_features)?; // Build training dataset let mut all_points: Vec = Vec::new(); @@ -246,12 +142,6 @@ pub fn run(args: TrainArgs) -> Result<()> { .filter(|&&l| l == TrafficLabel::Attack) .count(); let normal_count = all_labels.len() - attack_count; - eprintln!( - "Training with {} points ({} attack, {} normal)", - all_points.len(), - attack_count, - normal_count - ); // Normalize let norm_params = NormParams::from_data(&all_points); @@ -260,7 +150,6 @@ pub fn run(args: TrainArgs) -> Result<()> { .map(|v| norm_params.normalize(v)) .collect(); - // Serialize let model = SerializedModel { points: 
normalized, labels: all_labels, @@ -269,7 +158,215 @@ pub fn run(args: TrainArgs) -> Result<()> { threshold: args.threshold, }; - let encoded = bincode::serialize(&model).context("serializing model")?; + Ok(DdosTrainResult { + model, + attack_count, + normal_count, + }) +} + +/// Train a DDoS model from pre-parsed IP states with programmatic heuristic thresholds. +/// Used by the autotune pipeline to avoid re-parsing logs on each trial. +pub fn train_model_from_states( + ip_states: &FxHashMap, + thresholds: &HeuristicThresholds, + k: usize, + threshold: f64, + window_secs: u64, + min_events: usize, +) -> Result { + let window_secs_f64 = window_secs as f64; + let ip_features = extract_ip_features(ip_states, min_events, window_secs_f64); + + let mut ip_labels: FxHashMap = FxHashMap::default(); + for (ip, features) in &ip_features { + let avg = average_features(features); + let is_attack = avg[0] > thresholds.request_rate + || avg[7] > thresholds.path_repetition + || avg[3] > thresholds.error_rate + || avg[13] > thresholds.suspicious_path_ratio + || (avg[10] < thresholds.no_cookies_threshold + && avg[1] > thresholds.no_cookies_path_count); + ip_labels.insert( + ip.clone(), + if is_attack { TrafficLabel::Attack } else { TrafficLabel::Normal }, + ); + } + + let mut all_points: Vec = Vec::new(); + let mut all_labels: Vec = Vec::new(); + + for (ip, features) in &ip_features { + if let Some(&label) = ip_labels.get(ip) { + for fv in features { + all_points.push(*fv); + all_labels.push(label); + } + } + } + + if all_points.is_empty() { + bail!("No labeled data points found with these heuristic thresholds."); + } + + let attack_count = all_labels.iter().filter(|&&l| l == TrafficLabel::Attack).count(); + let normal_count = all_labels.len() - attack_count; + + let norm_params = NormParams::from_data(&all_points); + let normalized: Vec = all_points.iter().map(|v| norm_params.normalize(v)).collect(); + + let model = SerializedModel { + points: normalized, + labels: all_labels, + 
norm_params, + k, + threshold, + }; + + Ok(DdosTrainResult { model, attack_count, normal_count }) +} + +/// Parse audit logs into per-IP state maps. +pub fn parse_logs(input: &str) -> Result> { + let mut ip_states: FxHashMap = FxHashMap::default(); + let file = std::fs::File::open(input) + .with_context(|| format!("opening {}", input))?; + let reader = std::io::BufReader::new(file); + + for line in reader.lines() { + let line = line?; + let entry: AuditLog = match serde_json::from_str(&line) { + Ok(e) => e, + Err(_) => continue, + }; + if entry.fields.method.is_empty() { + continue; + } + + let ip = audit_log::strip_port(&entry.fields.client_ip).to_string(); + let ts = parse_timestamp(&entry.timestamp); + + let state = ip_states.entry(ip).or_default(); + state.timestamps.push(ts); + state.methods.push(method_to_u8(&entry.fields.method)); + state.path_hashes.push(fx_hash(&entry.fields.path)); + state.host_hashes.push(fx_hash(&entry.fields.host)); + state.user_agent_hashes.push(fx_hash(&entry.fields.user_agent)); + state.statuses.push(entry.fields.status); + state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32); + state.content_lengths.push(entry.fields.content_length.min(u32::MAX as u64) as u32); + state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false)); + state.has_referer.push( + entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false), + ); + state.has_accept_language.push( + entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false), + ); + state.suspicious_paths.push( + crate::ddos::features::is_suspicious_path(&entry.fields.path), + ); + } + + Ok(ip_states) +} + +/// Extract feature vectors per IP using consecutive, non-overlapping time windows. 
+pub fn extract_ip_features( + ip_states: &FxHashMap, + min_events: usize, + window_secs: f64, +) -> FxHashMap> { + let mut ip_features: FxHashMap> = FxHashMap::default(); + + for (ip, state) in ip_states { + let n = state.timestamps.len(); + if n < min_events { + continue; + } + let mut features = Vec::new(); + let mut start = 0; + for end in 1..n { + let span = state.timestamps[end] - state.timestamps[start]; + if span >= window_secs || end == n - 1 { + let fv = state.extract_features_for_window(start, end + 1, window_secs); + features.push(fv); + start = end + 1; + } + } + if !features.is_empty() { + ip_features.insert(ip.clone(), features); + } + } + + ip_features +} + +fn label_ips( + args: &TrainArgs, + ip_features: &FxHashMap>, +) -> Result> { + let mut ip_labels: FxHashMap = FxHashMap::default(); + + if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) { + let attack_ips: FxHashSet = std::fs::read_to_string(attack_file) + .context("reading attack IPs file")? + .lines() + .map(|l| l.trim().to_string()) + .filter(|l| !l.is_empty()) + .collect(); + let normal_ips: FxHashSet = std::fs::read_to_string(normal_file) + .context("reading normal IPs file")? 
+ .lines() + .map(|l| l.trim().to_string()) + .filter(|l| !l.is_empty()) + .collect(); + + for ip in ip_features.keys() { + if attack_ips.contains(ip) { + ip_labels.insert(ip.clone(), TrafficLabel::Attack); + } else if normal_ips.contains(ip) { + ip_labels.insert(ip.clone(), TrafficLabel::Normal); + } + } + } else if let Some(heuristics_file) = &args.heuristics { + let heuristics_str = std::fs::read_to_string(heuristics_file) + .context("reading heuristics file")?; + let thresholds: HeuristicThresholds = + toml::from_str(&heuristics_str).context("parsing heuristics TOML")?; + + for (ip, features) in ip_features { + let avg = average_features(features); + let is_attack = avg[0] > thresholds.request_rate + || avg[7] > thresholds.path_repetition + || avg[3] > thresholds.error_rate + || avg[13] > thresholds.suspicious_path_ratio + || (avg[10] < thresholds.no_cookies_threshold + && avg[1] > thresholds.no_cookies_path_count); + ip_labels.insert( + ip.clone(), + if is_attack { TrafficLabel::Attack } else { TrafficLabel::Normal }, + ); + } + } else { + bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling"); + } + + Ok(ip_labels) +} + +pub fn run(args: TrainArgs) -> Result<()> { + eprintln!("Parsing logs from {}...", args.input); + + let result = train_model(&args)?; + + eprintln!( + "Training with {} points ({} attack, {} normal)", + result.model.points.len(), + result.attack_count, + result.normal_count + ); + + let encoded = bincode::serialize(&result.model).context("serializing model")?; std::fs::write(&args.output, &encoded) .with_context(|| format!("writing model to {}", args.output))?; @@ -277,7 +374,7 @@ pub fn run(args: TrainArgs) -> Result<()> { "Model saved to {} ({} bytes, {} points)", args.output, encoded.len(), - model.points.len() + result.model.points.len() ); Ok(()) diff --git a/src/scanner/train.rs b/src/scanner/train.rs index 96e4d24..0dc770b 100644 --- a/src/scanner/train.rs +++ b/src/scanner/train.rs @@ -18,7 +18,7 @@ pub 
struct TrainScannerArgs { } /// Default suspicious fragments — matches the DDoS feature list plus extras. -const DEFAULT_FRAGMENTS: &[&str] = &[ +pub const DEFAULT_FRAGMENTS: &[&str] = &[ ".env", ".git", ".bak", ".sql", ".tar", ".zip", "wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc", "phpinfo", "phpmyadmin", "php-info", @@ -32,9 +32,218 @@ const DEFAULT_FRAGMENTS: &[&str] = &[ const ATTACK_EXTENSIONS: &[&str] = &[".env", ".sql", ".bak", ".git/config"]; const TRAVERSAL_MARKERS: &[&str] = &["..", "%00", "%0a"]; -struct LabeledSample { - features: ScannerFeatureVector, - label: f64, // 1.0 = attack, 0.0 = normal +pub struct LabeledSample { + pub features: ScannerFeatureVector, + pub label: f64, // 1.0 = attack, 0.0 = normal +} + +pub struct ScannerTrainResult { + pub model: ScannerModel, + pub train_metrics: Metrics, + pub test_metrics: Metrics, +} + +/// Core training pipeline: parse logs, label, train, evaluate. Returns the trained model and metrics. +pub fn train_and_evaluate( + args: &TrainScannerArgs, + learning_rate: f64, + epochs: usize, + class_weight_multiplier: f64, +) -> Result { + let mut fragments: Vec = DEFAULT_FRAGMENTS.iter().map(|s| s.to_string()).collect(); + let fragment_hashes: FxHashSet = fragments + .iter() + .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) + .collect(); + let extension_hashes: FxHashSet = features::SUSPICIOUS_EXTENSIONS_LIST + .iter() + .map(|e| fx_hash_bytes(e.as_bytes())) + .collect(); + + let mut samples: Vec = Vec::new(); + let file = std::fs::File::open(&args.input) + .with_context(|| format!("opening {}", args.input))?; + let reader = std::io::BufReader::new(file); + let mut log_hosts: FxHashSet = FxHashSet::default(); + let mut parsed_entries: Vec<(AuditFields, String)> = Vec::new(); + + for line in reader.lines() { + let line = line?; + if line.trim().is_empty() { + continue; + } + let entry: AuditLog = match serde_json::from_str(&line) { + Ok(e) => e, + Err(_) => continue, + }; + let host_prefix 
= entry + .fields + .host + .split('.') + .next() + .unwrap_or("") + .to_string(); + log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes())); + parsed_entries.push((entry.fields, host_prefix)); + } + + for (fields, host_prefix) in &parsed_entries { + let has_cookies = fields.has_cookies.unwrap_or(false); + let has_referer = fields + .referer + .as_ref() + .map(|r| r != "-" && !r.is_empty()) + .unwrap_or(false); + let has_accept_language = fields + .accept_language + .as_ref() + .map(|a| a != "-" && !a.is_empty()) + .unwrap_or(false); + + let feats = features::extract_features( + &fields.method, + &fields.path, + host_prefix, + has_cookies, + has_referer, + has_accept_language, + "-", + &fields.user_agent, + fields.content_length, + &fragment_hashes, + &extension_hashes, + &log_hosts, + ); + + let label = if let Some(ref gt) = fields.label { + match gt.as_str() { + "attack" | "anomalous" => Some(1.0), + "normal" => Some(0.0), + _ => None, + } + } else { + label_request( + &fields.path, + has_cookies, + has_referer, + has_accept_language, + &fields.user_agent, + host_prefix, + fields.status, + &log_hosts, + &fragment_hashes, + ) + }; + + if let Some(l) = label { + samples.push(LabeledSample { + features: feats, + label: l, + }); + } + } + + if args.csic { + let csic_entries = crate::scanner::csic::fetch_csic_dataset()?; + for (_, host_prefix) in &csic_entries { + log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes())); + } + for (fields, host_prefix) in &csic_entries { + let has_cookies = fields.has_cookies.unwrap_or(false); + let has_referer = fields + .referer + .as_ref() + .map(|r| r != "-" && !r.is_empty()) + .unwrap_or(false); + let has_accept_language = fields + .accept_language + .as_ref() + .map(|a| a != "-" && !a.is_empty()) + .unwrap_or(false); + + let feats = features::extract_features( + &fields.method, + &fields.path, + host_prefix, + has_cookies, + has_referer, + has_accept_language, + "-", + &fields.user_agent, + fields.content_length, + 
&fragment_hashes, + &extension_hashes, + &log_hosts, + ); + + let label = match fields.label.as_deref() { + Some("attack" | "anomalous") => 1.0, + Some("normal") => 0.0, + _ => continue, + }; + + samples.push(LabeledSample { + features: feats, + label, + }); + } + } + + if let Some(wordlist_dir) = &args.wordlists { + ingest_wordlists( + wordlist_dir, + &mut samples, + &mut fragments, + &fragment_hashes, + &extension_hashes, + &log_hosts, + )?; + } + + if samples.is_empty() { + anyhow::bail!("no training samples found"); + } + + // Stratified 80/20 train/test split + let (train_samples, test_samples) = stratified_split(&mut samples, 0.8, 42); + + // Compute normalization params from training set only + let train_feature_vecs: Vec = + train_samples.iter().map(|s| s.features).collect(); + let norm_params = ScannerNormParams::from_data(&train_feature_vecs); + + let train_normalized: Vec = + train_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect(); + + // Train logistic regression with configurable params + let weights = train_logistic_regression_weighted( + &train_normalized, + &train_samples, + epochs, + learning_rate, + class_weight_multiplier, + ); + + let model = ScannerModel { + weights, + threshold: args.threshold, + norm_params: norm_params.clone(), + fragments, + }; + + let train_metrics = evaluate(&train_normalized, &train_samples, &weights, args.threshold); + + let test_feature_vecs: Vec = + test_samples.iter().map(|s| s.features).collect(); + let test_normalized: Vec = + test_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect(); + let test_metrics = evaluate(&test_normalized, &test_samples, &weights, args.threshold); + + Ok(ScannerTrainResult { + model, + train_metrics, + test_metrics, + }) } pub fn run(args: TrainScannerArgs) -> Result<()> { @@ -295,24 +504,28 @@ pub fn run(args: TrainScannerArgs) -> Result<()> { Ok(()) } -struct Metrics { - tp: u32, - fp: u32, - tn: u32, - fn_: u32, +pub struct Metrics { + pub tp: u32, + pub fp: 
u32, + pub tn: u32, + pub fn_: u32, } impl Metrics { - fn precision(&self) -> f64 { + pub fn precision(&self) -> f64 { if self.tp + self.fp > 0 { self.tp as f64 / (self.tp + self.fp) as f64 } else { 0.0 } } - fn recall(&self) -> f64 { + pub fn recall(&self) -> f64 { if self.tp + self.fn_ > 0 { self.tp as f64 / (self.tp + self.fn_) as f64 } else { 0.0 } } - fn f1(&self) -> f64 { + pub fn f1(&self) -> f64 { + self.fbeta(1.0) + } + pub fn fbeta(&self, beta: f64) -> f64 { let p = self.precision(); let r = self.recall(); - if p + r > 0.0 { 2.0 * p * r / (p + r) } else { 0.0 } + let b2 = beta * beta; + if p + r > 0.0 { (1.0 + b2) * p * r / (b2 * p + r) } else { 0.0 } } fn print(&self, total: usize) { let acc = (self.tp + self.tn) as f64 / total as f64 * 100.0; @@ -457,6 +670,59 @@ fn train_logistic_regression( weights } +/// Like `train_logistic_regression` but with configurable class_weight_multiplier. +/// A multiplier > 1.0 increases the weight of the minority (attack) class further. +fn train_logistic_regression_weighted( + normalized: &[ScannerFeatureVector], + samples: &[LabeledSample], + epochs: usize, + learning_rate: f64, + class_weight_multiplier: f64, +) -> [f64; NUM_SCANNER_WEIGHTS] { + let mut weights = [0.0f64; NUM_SCANNER_WEIGHTS]; + let n = samples.len() as f64; + + let n_attack = samples.iter().filter(|s| s.label > 0.5).count() as f64; + let n_normal = n - n_attack; + let (w_attack, w_normal) = if n_attack > 0.0 && n_normal > 0.0 { + (n / (2.0 * n_attack) * class_weight_multiplier, n / (2.0 * n_normal)) + } else { + (1.0, 1.0) + }; + + for _epoch in 0..epochs { + let mut gradients = [0.0f64; NUM_SCANNER_WEIGHTS]; + + for (i, sample) in samples.iter().enumerate() { + let f = &normalized[i]; + let mut z = weights[NUM_SCANNER_FEATURES + 2]; + for j in 0..NUM_SCANNER_FEATURES { + z += weights[j] * f[j]; + } + z += weights[12] * f[0] * (1.0 - f[3]); + z += weights[13] * (1.0 - f[9]) * (1.0 - f[5]); + + let prediction = sigmoid(z); + let error = prediction - 
sample.label; + let cw = if sample.label > 0.5 { w_attack } else { w_normal }; + let weighted_error = error * cw; + + for j in 0..NUM_SCANNER_FEATURES { + gradients[j] += weighted_error * f[j]; + } + gradients[12] += weighted_error * f[0] * (1.0 - f[3]); + gradients[13] += weighted_error * (1.0 - f[9]) * (1.0 - f[5]); + gradients[14] += weighted_error; + } + + for j in 0..NUM_SCANNER_WEIGHTS { + weights[j] -= learning_rate * gradients[j] / n; + } + } + + weights +} + #[allow(clippy::too_many_arguments)] fn label_request( path: &str, diff --git a/tests/ddos_test.rs b/tests/ddos_test.rs index dc9daf6..3223762 100644 --- a/tests/ddos_test.rs +++ b/tests/ddos_test.rs @@ -46,13 +46,14 @@ fn make_model( fn default_ddos_config() -> DDoSConfig { DDoSConfig { - model_path: String::new(), + model_path: Some(String::new()), k: 5, threshold: 0.6, window_secs: 60, window_capacity: 1000, min_events: 10, enabled: true, + use_ensemble: false, } }