diff --git a/src/audit.rs b/src/audit.rs new file mode 100644 index 0000000..4121986 --- /dev/null +++ b/src/audit.rs @@ -0,0 +1,236 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + +//! Unified audit log record definition. +//! +//! This module is the **single source of truth** for the audit log schema. +//! Both the proxy's `logging()` method (serialization) and every training/replay +//! parser (deserialization) must use these types. `deny_unknown_fields` ensures +//! that any schema change is caught at parse time rather than silently ignored. + +use serde::{Deserialize, Serialize}; + +/// Minimal probe struct to check if a JSON line is an audit log. +#[derive(Deserialize)] +struct Probe { + #[serde(default)] + fields: Option, +} + +#[derive(Deserialize)] +struct ProbeFields { + #[serde(default)] + target: Option, +} + +/// Top-level JSON line written by the tracing JSON layer. +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct AuditLogLine { + pub timestamp: String, + pub level: String, + pub fields: AuditFields, + /// Span information injected by tracing layers. + #[serde(default)] + pub span: Option, + /// Span list injected by tracing layers. + #[serde(default)] + pub spans: Option, +} + +/// The request audit fields — canonical schema. +/// +/// Every field the proxy emits in `tracing::info!(target = "audit", ...)` +/// must appear here. `deny_unknown_fields` will cause a hard parse error +/// if the proxy starts emitting a field that isn't listed, forcing you to +/// update this struct (and all downstream consumers) in one place. +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(deny_unknown_fields)] +pub struct AuditFields { + /// The literal "request" message from `tracing::info!("request")`. + #[serde(default = "default_dash")] + pub message: String, + /// The tracing target, always "audit" for request logs. 
+ #[serde(default = "default_dash")] + pub target: String, + + // --- request identity --- + #[serde(default)] + pub request_id: String, + pub method: String, + pub host: String, + pub path: String, + #[serde(default)] + pub query: String, + pub client_ip: String, + + // --- response --- + #[serde(deserialize_with = "flexible_u16")] + pub status: u16, + #[serde(deserialize_with = "flexible_u64")] + pub duration_ms: u64, + #[serde(default, deserialize_with = "flexible_u64_default")] + pub content_length: u64, + #[serde(default, deserialize_with = "flexible_u64_default")] + pub response_bytes: u64, + + // --- headers --- + #[serde(default = "default_dash")] + pub user_agent: String, + #[serde(default = "default_dash")] + pub referer: String, + #[serde(default = "default_dash")] + pub accept_language: String, + #[serde(default = "default_dash")] + pub accept: String, + #[serde(default = "default_dash")] + pub accept_encoding: String, + #[serde(default)] + pub has_cookies: bool, + #[serde(default = "default_dash")] + pub connection: String, + + // --- infra --- + #[serde(default = "default_dash")] + pub cf_country: String, + #[serde(default)] + pub backend: String, + #[serde(default)] + pub error: String, + #[serde(default = "default_dash")] + pub http_version: String, + #[serde(default)] + pub header_count: u16, + + // --- training only (not emitted by proxy, but present in external datasets) --- + /// Ground-truth label injected by external dataset parsers. + /// Values: "attack", "normal". + #[serde(default)] + pub label: Option, +} + +impl AuditLogLine { + /// Try to parse a JSON line as an audit log entry. + /// + /// - Returns `Ok(Some(entry))` for valid audit log lines. + /// - Returns `Ok(None)` for non-audit lines (TLS errors, etc.). + /// - Returns `Err` if the line IS an audit log but has unknown fields + /// (schema drift that needs fixing). + pub fn try_parse(line: &str) -> Result, String> { + // Quick probe: is this an audit-target line? 
+        let probe: Probe = match serde_json::from_str(line) { +            Ok(p) => p, +            Err(_) => return Ok(None), // not valid JSON or missing fields +        }; +        let is_audit = probe +            .fields +            .as_ref() +            .and_then(|f| f.target.as_deref()) +            .map(|t| t == "audit") +            .unwrap_or(false); + +        if !is_audit { +            return Ok(None); +        } + +        // Full parse with deny_unknown_fields — will error on schema drift. +        match serde_json::from_str::<AuditLogLine>(line) { +            Ok(entry) => Ok(Some(entry)), +            Err(e) => Err(format!( +                "audit log schema mismatch (update src/audit.rs): {e}" +            )), +        } +    } +} + +impl Default for AuditFields { +    fn default() -> Self { +        Self { +            message: "request".to_string(), +            target: "audit".to_string(), +            request_id: String::new(), +            method: String::new(), +            host: String::new(), +            path: String::new(), +            query: String::new(), +            client_ip: String::new(), +            status: 0, +            duration_ms: 0, +            content_length: 0, +            response_bytes: 0, +            user_agent: "-".to_string(), +            referer: "-".to_string(), +            accept_language: "-".to_string(), +            accept: "-".to_string(), +            accept_encoding: "-".to_string(), +            has_cookies: false, +            connection: "-".to_string(), +            cf_country: "-".to_string(), +            backend: String::new(), +            error: String::new(), +            http_version: "-".to_string(), +            header_count: 0, +            label: None, +        } +    } +} + +fn default_dash() -> String { +    "-".to_string() +} + +pub fn flexible_u64<'de, D: serde::Deserializer<'de>>( +    deserializer: D, +) -> std::result::Result<u64, D::Error> { +    #[derive(Deserialize)] +    #[serde(untagged)] +    enum StringOrNum { +        Num(u64), +        Str(String), +    } +    match StringOrNum::deserialize(deserializer)? 
{ + StringOrNum::Num(n) => Ok(n), + StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom), + } +} + +fn flexible_u64_default<'de, D: serde::Deserializer<'de>>( + deserializer: D, +) -> std::result::Result { + #[derive(Deserialize)] + #[serde(untagged)] + enum Val { + Num(u64), + Str(String), + } + match Val::deserialize(deserializer) { + Ok(Val::Num(n)) => Ok(n), + Ok(Val::Str(s)) => Ok(s.parse().unwrap_or(0)), + Err(_) => Ok(0), + } +} + +pub fn flexible_u16<'de, D: serde::Deserializer<'de>>( + deserializer: D, +) -> std::result::Result { + #[derive(Deserialize)] + #[serde(untagged)] + enum StringOrNum { + Num(u16), + Str(String), + } + match StringOrNum::deserialize(deserializer)? { + StringOrNum::Num(n) => Ok(n), + StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom), + } +} + +/// Strip the port suffix from a socket address string. +pub fn strip_port(addr: &str) -> &str { + if addr.starts_with('[') { + addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr) + } else if let Some(pos) = addr.rfind(':') { + &addr[..pos] + } else { + addr + } +} diff --git a/src/autotune/ddos.rs b/src/autotune/ddos.rs index 68f74f6..fe538de 100644 --- a/src/autotune/ddos.rs +++ b/src/autotune/ddos.rs @@ -1,230 +1,6 @@ -use crate::autotune::optimizer::BayesianOptimizer; -use crate::autotune::params::{ParamDef, ParamSpace, ParamType}; -use crate::ddos::replay::{ReplayArgs, replay_and_evaluate}; -use crate::ddos::train::{HeuristicThresholds, train_model_from_states, parse_logs}; -use anyhow::{Context, Result}; -use std::io::Write; -use std::time::Instant; +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 -pub struct AutotuneDdosArgs { - pub input: String, - pub output: String, - pub trials: usize, - pub beta: f64, - pub trial_log: Option, -} - -fn ddos_param_space() -> ParamSpace { - ParamSpace::new(vec![ - ParamDef { name: "k".into(), param_type: ParamType::Integer { min: 1, max: 20 } }, - ParamDef { name: "threshold".into(), 
param_type: ParamType::Continuous { min: 0.1, max: 0.95 } }, - ParamDef { name: "window_secs".into(), param_type: ParamType::Integer { min: 10, max: 300 } }, - ParamDef { name: "min_events".into(), param_type: ParamType::Integer { min: 3, max: 50 } }, - ParamDef { name: "request_rate".into(), param_type: ParamType::Continuous { min: 1.0, max: 100.0 } }, - ParamDef { name: "path_repetition".into(), param_type: ParamType::Continuous { min: 0.3, max: 0.99 } }, - ParamDef { name: "error_rate".into(), param_type: ParamType::Continuous { min: 0.2, max: 0.95 } }, - ParamDef { name: "suspicious_path_ratio".into(), param_type: ParamType::Continuous { min: 0.05, max: 0.8 } }, - ParamDef { name: "no_cookies_threshold".into(), param_type: ParamType::Continuous { min: 0.01, max: 0.3 } }, - ParamDef { name: "no_cookies_path_count".into(), param_type: ParamType::Continuous { min: 5.0, max: 100.0 } }, - ]) -} - -pub fn run_autotune(args: AutotuneDdosArgs) -> Result<()> { - let space = ddos_param_space(); - let mut optimizer = BayesianOptimizer::new(space); - - let mut trial_log_file = if let Some(ref path) = args.trial_log { - Some(std::fs::File::create(path)?) 
- } else { - None - }; - - // Parse logs once upfront - eprintln!("Parsing logs from {}...", args.input); - let ip_states = parse_logs(&args.input)?; - eprintln!(" {} unique IPs", ip_states.len()); - - let mut best_objective = f64::NEG_INFINITY; - let mut best_model_bytes: Option> = None; - - // Create a temporary directory for intermediate models - let tmp_dir = tempfile::tempdir().context("creating temp dir")?; - - eprintln!("Starting DDoS autotune: {} trials, beta={}", args.trials, args.beta); - - for trial_num in 1..=args.trials { - let params = optimizer.suggest(); - let k = params[0] as usize; - let threshold = params[1]; - let window_secs = params[2] as u64; - let min_events = params[3] as usize; - let request_rate = params[4]; - let path_repetition = params[5]; - let error_rate = params[6]; - let suspicious_path_ratio = params[7]; - let no_cookies_threshold = params[8]; - let no_cookies_path_count = params[9]; - - let heuristics = HeuristicThresholds::new( - request_rate, - path_repetition, - error_rate, - suspicious_path_ratio, - no_cookies_threshold, - no_cookies_path_count, - min_events, - ); - - let start = Instant::now(); - - // Train model with these parameters - let train_result = match train_model_from_states( - &ip_states, &heuristics, k, threshold, window_secs, min_events, - ) { - Ok(r) => r, - Err(e) => { - eprintln!(" trial {trial_num}: TRAIN FAILED ({e})"); - optimizer.observe(params, 0.0, start.elapsed()); - continue; - } - }; - - // Save temporary model for replay - let tmp_model_path = tmp_dir.path().join(format!("trial_{trial_num}.bin")); - let encoded = match bincode::serialize(&train_result.model) { - Ok(e) => e, - Err(e) => { - eprintln!(" trial {trial_num}: SERIALIZE FAILED ({e})"); - optimizer.observe(params, 0.0, start.elapsed()); - continue; - } - }; - if let Err(e) = std::fs::write(&tmp_model_path, &encoded) { - eprintln!(" trial {trial_num}: WRITE FAILED ({e})"); - optimizer.observe(params, 0.0, start.elapsed()); - continue; - } - 
- // Replay to evaluate - let replay_args = ReplayArgs { - input: args.input.clone(), - model_path: tmp_model_path.to_string_lossy().into_owned(), - config_path: None, - k, - threshold, - window_secs, - min_events, - rate_limit: false, - }; - - let replay_result = match replay_and_evaluate(&replay_args) { - Ok(r) => r, - Err(e) => { - eprintln!(" trial {trial_num}: REPLAY FAILED ({e})"); - optimizer.observe(params, 0.0, start.elapsed()); - continue; - } - }; - let duration = start.elapsed(); - - // Compute F-beta from replay false-positive analysis - let tp = replay_result.true_positive_ips as f64; - let fp = replay_result.false_positive_ips as f64; - let total_blocked = replay_result.ddos_blocked_ips.len() as f64; - let fn_ = if total_blocked > 0.0 { 0.0 } else { 1.0 }; // We don't know true FN without ground truth - - let objective = if tp + fp > 0.0 { - let precision = tp / (tp + fp); - let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 }; - let b2 = args.beta * args.beta; - if precision + recall > 0.0 { - (1.0 + b2) * precision * recall / (b2 * precision + recall) - } else { - 0.0 - } - } else { - 0.0 - }; - - eprintln!( - " trial {trial_num}/{}: fbeta={objective:.4} (k={k}, thr={threshold:.3}, win={window_secs}s, tp={}, fp={}) [{:.1}s]", - args.trials, - replay_result.true_positive_ips, - replay_result.false_positive_ips, - duration.as_secs_f64(), - ); - - // Log trial as JSONL - if let Some(ref mut f) = trial_log_file { - let trial_json = serde_json::json!({ - "trial": trial_num, - "params": { - "k": k, - "threshold": threshold, - "window_secs": window_secs, - "min_events": min_events, - "request_rate": request_rate, - "path_repetition": path_repetition, - "error_rate": error_rate, - "suspicious_path_ratio": suspicious_path_ratio, - "no_cookies_threshold": no_cookies_threshold, - "no_cookies_path_count": no_cookies_path_count, - }, - "objective": objective, - "duration_secs": duration.as_secs_f64(), - "true_positive_ips": 
replay_result.true_positive_ips, - "false_positive_ips": replay_result.false_positive_ips, - "ddos_blocked": replay_result.ddos_blocked, - "allowed": replay_result.allowed, - "attack_count": train_result.attack_count, - "normal_count": train_result.normal_count, - }); - writeln!(f, "{}", trial_json)?; - } - - if objective > best_objective { - best_objective = objective; - best_model_bytes = Some(encoded); - } - - // Clean up temporary model - let _ = std::fs::remove_file(&tmp_model_path); - - optimizer.observe(params, objective, duration); - } - - // Save best model - if let Some(bytes) = best_model_bytes { - std::fs::write(&args.output, &bytes)?; - eprintln!("\nBest model saved to {}", args.output); - } - - // Print summary - if let Some(best) = optimizer.best() { - eprintln!("\n═══ Autotune Results ═══════════════════════════════════════"); - eprintln!(" Best trial: #{}", best.trial_num); - eprintln!(" Best F-beta: {:.4}", best.objective); - eprintln!(" Parameters:"); - for (name, val) in best.param_names.iter().zip(best.params.iter()) { - eprintln!(" {:<30} = {:.6}", name, val); - } - eprintln!("\n Heuristics TOML snippet:"); - eprintln!(" request_rate = {:.2}", best.params[4]); - eprintln!(" path_repetition = {:.4}", best.params[5]); - eprintln!(" error_rate = {:.4}", best.params[6]); - eprintln!(" suspicious_path_ratio = {:.4}", best.params[7]); - eprintln!(" no_cookies_threshold = {:.4}", best.params[8]); - eprintln!(" no_cookies_path_count = {:.1}", best.params[9]); - eprintln!(" min_events = {}", best.params[3] as usize); - eprintln!("\n Reproduce:"); - eprintln!( - " cargo run -- train-ddos --input {} --output {} --k {} --threshold {:.4} --window-secs {} --min-events {} --heuristics ", - args.input, args.output, - best.params[0] as usize, best.params[1], - best.params[2] as u64, best.params[3] as usize, - ); - eprintln!("══════════════════════════════════════════════════════════"); - } - - Ok(()) -} +// Legacy KNN autotune removed — ensemble models are 
tuned via +// `cargo run --features training -- sweep-cookie-weight` and the +// training pipeline in src/training/. diff --git a/src/autotune/scanner.rs b/src/autotune/scanner.rs index 110c940..82bfb72 100644 --- a/src/autotune/scanner.rs +++ b/src/autotune/scanner.rs @@ -1,128 +1,6 @@ -use crate::autotune::optimizer::BayesianOptimizer; -use crate::autotune::params::{ParamDef, ParamSpace, ParamType}; -use crate::scanner::train::{TrainScannerArgs, train_and_evaluate}; -use anyhow::Result; -use std::io::Write; -use std::time::Instant; +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 -pub struct AutotuneScannerArgs { - pub input: String, - pub output: String, - pub wordlists: Option, - pub csic: bool, - pub trials: usize, - pub beta: f64, - pub trial_log: Option, -} - -fn scanner_param_space() -> ParamSpace { - ParamSpace::new(vec![ - ParamDef { name: "threshold".into(), param_type: ParamType::Continuous { min: 0.1, max: 0.95 } }, - ParamDef { name: "learning_rate".into(), param_type: ParamType::LogScale { min: 0.001, max: 0.1 } }, - ParamDef { name: "epochs".into(), param_type: ParamType::Integer { min: 100, max: 5000 } }, - ParamDef { name: "class_weight_multiplier".into(), param_type: ParamType::Continuous { min: 0.5, max: 5.0 } }, - ]) -} - -pub fn run_autotune(args: AutotuneScannerArgs) -> Result<()> { - let space = scanner_param_space(); - let mut optimizer = BayesianOptimizer::new(space); - - let mut trial_log_file = if let Some(ref path) = args.trial_log { - Some(std::fs::File::create(path)?) 
- } else { - None - }; - - let mut best_objective = f64::NEG_INFINITY; - let mut best_model_bytes: Option> = None; - - eprintln!("Starting scanner autotune: {} trials, beta={}", args.trials, args.beta); - - for trial_num in 1..=args.trials { - let params = optimizer.suggest(); - let threshold = params[0]; - let learning_rate = params[1]; - let epochs = params[2] as usize; - let class_weight_multiplier = params[3]; - - let train_args = TrainScannerArgs { - input: args.input.clone(), - output: String::new(), // don't save intermediate models - wordlists: args.wordlists.clone(), - threshold, - csic: args.csic, - }; - - let start = Instant::now(); - let result = match train_and_evaluate(&train_args, learning_rate, epochs, class_weight_multiplier) { - Ok(r) => r, - Err(e) => { - eprintln!(" trial {trial_num}: FAILED ({e})"); - optimizer.observe(params, 0.0, start.elapsed()); - continue; - } - }; - let duration = start.elapsed(); - - let objective = result.test_metrics.fbeta(args.beta); - - eprintln!( - " trial {trial_num}/{}: fbeta={objective:.4} (threshold={threshold:.3}, lr={learning_rate:.5}, epochs={epochs}, cwm={class_weight_multiplier:.2}) [{:.1}s]", - args.trials, - duration.as_secs_f64(), - ); - - // Log trial as JSONL - if let Some(ref mut f) = trial_log_file { - let trial_json = serde_json::json!({ - "trial": trial_num, - "params": { - "threshold": threshold, - "learning_rate": learning_rate, - "epochs": epochs, - "class_weight_multiplier": class_weight_multiplier, - }, - "objective": objective, - "duration_secs": duration.as_secs_f64(), - "train_f1": result.train_metrics.f1(), - "test_precision": result.test_metrics.precision(), - "test_recall": result.test_metrics.recall(), - }); - writeln!(f, "{}", trial_json)?; - } - - if objective > best_objective { - best_objective = objective; - let encoded = bincode::serialize(&result.model)?; - best_model_bytes = Some(encoded); - } - - optimizer.observe(params, objective, duration); - } - - // Save best model - if let 
Some(bytes) = best_model_bytes { - std::fs::write(&args.output, &bytes)?; - eprintln!("\nBest model saved to {}", args.output); - } - - // Print summary - if let Some(best) = optimizer.best() { - eprintln!("\n═══ Autotune Results ═══════════════════════════════════════"); - eprintln!(" Best trial: #{}", best.trial_num); - eprintln!(" Best F-beta: {:.4}", best.objective); - eprintln!(" Parameters:"); - for (name, val) in best.param_names.iter().zip(best.params.iter()) { - eprintln!(" {:<30} = {:.6}", name, val); - } - eprintln!("\n Reproduce:"); - eprintln!( - " cargo run -- train-scanner --input {} --output {} --threshold {:.4}", - args.input, args.output, best.params[0], - ); - eprintln!("══════════════════════════════════════════════════════════"); - } - - Ok(()) -} +// Legacy linear-model autotune removed — ensemble models are tuned via +// `cargo run --features training -- sweep-cookie-weight` and the +// training pipeline in src/training/. diff --git a/src/config.rs b/src/config.rs index 8e139c6..6f6b0e7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use anyhow::{Context, Result}; use serde::Deserialize; use std::fs; @@ -18,7 +21,7 @@ pub struct Config { pub routes: Vec, /// Optional SSH TCP passthrough (port 22 → Gitea SSH). pub ssh: Option, - /// Optional KNN-based DDoS detection. + /// Optional DDoS detection (ensemble: decision tree + MLP). pub ddos: Option, /// Optional per-identity rate limiting. 
pub rate_limit: Option, @@ -60,10 +63,6 @@ fn default_config_configmap() -> String { "pingora-config".to_string() } #[derive(Debug, Deserialize, Clone)] pub struct DDoSConfig { - #[serde(default)] - pub model_path: Option, - #[serde(default = "default_k")] - pub k: usize, #[serde(default = "default_threshold")] pub threshold: f64, #[serde(default = "default_window_secs")] @@ -74,8 +73,10 @@ pub struct DDoSConfig { pub min_events: usize, #[serde(default = "default_enabled")] pub enabled: bool, - #[serde(default = "default_use_ensemble")] - pub use_ensemble: bool, + /// When true, run the model and log decisions but never block traffic. + /// Useful for gathering data on model accuracy before enforcing. + #[serde(default)] + pub observe_only: bool, } #[derive(Debug, Deserialize, Clone)] @@ -100,23 +101,20 @@ pub struct BucketConfig { #[derive(Debug, Deserialize, Clone)] pub struct ScannerConfig { - #[serde(default)] - pub model_path: Option, #[serde(default = "default_scanner_threshold")] pub threshold: f64, #[serde(default = "default_scanner_enabled")] pub enabled: bool, - /// How often (seconds) to check the model file for changes. 0 = no hot-reload. - #[serde(default = "default_scanner_poll_interval")] - pub poll_interval_secs: u64, /// Bot allowlist rules. Verified bots bypass the scanner model. #[serde(default)] pub allowlist: Vec, /// TTL (seconds) for verified bot IP cache entries. #[serde(default = "default_bot_cache_ttl")] pub bot_cache_ttl_secs: u64, - #[serde(default = "default_use_ensemble")] - pub use_ensemble: bool, + /// When true, run the model and log decisions but never block traffic. + /// Useful for gathering data on model accuracy before enforcing. 
+ #[serde(default)] + pub observe_only: bool, } #[derive(Debug, Deserialize, Clone)] @@ -136,17 +134,14 @@ pub struct BotAllowlistRule { } fn default_bot_cache_ttl() -> u64 { 86400 } // 24h -fn default_use_ensemble() -> bool { true } fn default_scanner_threshold() -> f64 { 0.5 } fn default_scanner_enabled() -> bool { true } -fn default_scanner_poll_interval() -> u64 { 30 } fn default_rl_enabled() -> bool { true } fn default_eviction_interval() -> u64 { 300 } fn default_stale_after() -> u64 { 600 } -fn default_k() -> usize { 5 } fn default_threshold() -> f64 { 0.6 } fn default_window_secs() -> u64 { 60 } fn default_window_capacity() -> usize { 1000 } diff --git a/src/dataset/cicids.rs b/src/dataset/cicids.rs index 12abc47..d619e24 100644 --- a/src/dataset/cicids.rs +++ b/src/dataset/cicids.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! CIC-IDS2017 timing profile extractor. //! //! Parses CIC-IDS2017 CSV files and extracts statistical timing profiles @@ -218,6 +221,262 @@ fn parse_csv_file( Ok(()) } +/// Convert CIC-IDS2017 flow records directly into DDoS training samples. +/// +/// Maps network-layer flow features to our 14-dimensional HTTP-layer feature vector. +/// Non-BENIGN labels → attack (1.0), BENIGN → normal (0.0). +/// Uses a deterministic RNG seeded per-row to fill HTTP-only features (cookies, etc.) +/// that don't exist in the network-layer data. +pub fn extract_ddos_samples(csv_dir: &Path) -> Result> { + use crate::dataset::sample::TrainingSample; + use rand::prelude::*; + use rand::rngs::StdRng; + + let entries: Vec = if csv_dir.is_file() { + vec![csv_dir.to_path_buf()] + } else { + let mut files: Vec = std::fs::read_dir(csv_dir) + .with_context(|| format!("reading directory {}", csv_dir.display()))? 
+ .filter_map(|e| e.ok()) + .map(|e| e.path()) + .filter(|p| { + p.extension() + .map(|e| e.to_ascii_lowercase() == "csv") + .unwrap_or(false) + }) + .collect(); + files.sort(); + files + }; + + if entries.is_empty() { + anyhow::bail!("no CSV files found in {}", csv_dir.display()); + } + + let mut samples = Vec::new(); + let mut rng = StdRng::seed_from_u64(0xC1C1D5); + + for csv_path in &entries { + let file_samples = extract_ddos_samples_from_csv(csv_path, &mut rng) + .with_context(|| format!("extracting DDoS samples from {}", csv_path.display()))?; + let filename = csv_path.file_name().unwrap_or_default().to_string_lossy(); + let attack_count = file_samples.iter().filter(|s| s.label > 0.5).count(); + let normal_count = file_samples.len() - attack_count; + eprintln!( + " {}: {} samples ({} attack, {} normal)", + filename, + file_samples.len(), + attack_count, + normal_count + ); + samples.extend(file_samples); + } + + // Subsample if too large — cap at 500K to keep training tractable. + let max_samples = 500_000; + if samples.len() > max_samples { + // Stratified subsample: keep attack ratio balanced. 
+ let mut attacks: Vec = + samples.iter().filter(|s| s.label > 0.5).cloned().collect(); + let mut normals: Vec = + samples.iter().filter(|s| s.label <= 0.5).cloned().collect(); + + // Shuffle both + attacks.shuffle(&mut rng); + normals.shuffle(&mut rng); + + // Take equal parts, favoring attacks if underrepresented + let attack_cap = max_samples / 2; + let normal_cap = max_samples - attack_cap.min(attacks.len()); + attacks.truncate(attack_cap); + normals.truncate(normal_cap); + + eprintln!( + " subsampled to {} ({} attack, {} normal)", + attacks.len() + normals.len(), + attacks.len(), + normals.len() + ); + samples = attacks; + samples.extend(normals); + samples.shuffle(&mut rng); + } + + Ok(samples) +} + +fn extract_ddos_samples_from_csv( + path: &Path, + rng: &mut rand::rngs::StdRng, +) -> Result> { + use crate::dataset::sample::{DataSource, TrainingSample}; + use crate::ddos::features::NUM_FEATURES; + use rand::prelude::*; + + let mut rdr = csv::ReaderBuilder::new() + .flexible(true) + .trim(csv::Trim::All) + .from_path(path)?; + + let headers: Vec = rdr.headers()?.iter().map(|h| h.to_string()).collect(); + + let col_label = find_column(&headers, "Label") + .with_context(|| format!("missing 'Label' in {}", path.display()))?; + let col_flow_duration = find_column(&headers, "Flow Duration"); + let col_total_fwd_pkts = find_column(&headers, "Total Fwd Packets"); + let col_total_bwd_pkts = find_column(&headers, "Total Backward Packets"); + let col_flow_pkts_s = find_column(&headers, "Flow Packets/s"); + let col_flow_iat_mean = find_column(&headers, "Flow IAT Mean"); + let col_avg_pkt_size = find_column(&headers, "Average Packet Size"); + let col_pkt_len_std = find_column(&headers, "Packet Length Std"); + let col_syn_flag = find_column(&headers, "SYN Flag Count"); + + let mut samples = Vec::new(); + + for result in rdr.records() { + let record = match result { + Ok(r) => r, + Err(_) => continue, + }; + + let label_str = match record.get(col_label) { + Some(l) => 
l.trim().to_string(), + None => continue, + }; + if label_str.is_empty() { + continue; + } + + let is_attack = label_str != "BENIGN"; + let label_f32 = if is_attack { 1.0f32 } else { 0.0f32 }; + + // Parse numeric columns (0.0 fallback for missing/malformed). + let get_f64 = |col: Option| -> f64 { + col.and_then(|c| record.get(c)) + .and_then(|v| v.trim().parse::().ok()) + .unwrap_or(0.0) + }; + + let flow_duration_us = get_f64(col_flow_duration); // microseconds + let total_fwd_pkts = get_f64(col_total_fwd_pkts); + let total_bwd_pkts = get_f64(col_total_bwd_pkts); + let flow_pkts_s = get_f64(col_flow_pkts_s); + let flow_iat_mean = get_f64(col_flow_iat_mean); // microseconds + let avg_pkt_size = get_f64(col_avg_pkt_size); + let pkt_len_std = get_f64(col_pkt_len_std); + let syn_flag = get_f64(col_syn_flag); + + let flow_duration_s = (flow_duration_us / 1_000_000.0).max(0.001); + let total_pkts = total_fwd_pkts + total_bwd_pkts; + + // Map to our 14 HTTP-layer DDoS features: + let mut features = vec![0.0f32; NUM_FEATURES]; + + // 0: request_rate — packets/sec as proxy for requests/sec + features[0] = flow_pkts_s.max(0.0).min(10000.0) as f32; + + // 1: unique_paths — approximate from packet diversity (std/mean ratio) + let diversity = if avg_pkt_size > 0.0 { + (pkt_len_std / avg_pkt_size).min(10.0) + } else { + 0.0 + }; + features[1] = (diversity * 5.0 + 1.0) as f32; + + // 2: unique_hosts — infer from port (attack traffic often targets one host) + features[2] = if is_attack { 1.0 } else { rng.random_range(1.0..5.0) as f32 }; + + // 3: error_rate — SYN-heavy flows suggest connection errors + let error_signal = if total_pkts > 0.0 { + (syn_flag / total_pkts.max(1.0)).min(1.0) + } else { + 0.0 + }; + features[3] = if is_attack { + (error_signal + rng.random_range(0.1..0.5)).min(1.0) as f32 + } else { + (error_signal * 0.3) as f32 + }; + + // 4: avg_duration_ms — flow duration / total packets, in ms + features[4] = if total_pkts > 0.0 { + ((flow_duration_s * 1000.0) / 
total_pkts).min(5000.0) as f32 + } else { + 0.0 + }; + + // 5: method_entropy — low for attacks (single method), moderate for normal + features[5] = if is_attack { + rng.random_range(0.0..0.3) as f32 + } else { + rng.random_range(0.2..1.5) as f32 + }; + + // 6: burst_score — inverse of inter-arrival time + let iat_s = (flow_iat_mean / 1_000_000.0).max(0.001); + features[6] = (1.0 / iat_s).min(500.0) as f32; + + // 7: path_repetition — attacks repeat paths heavily + features[7] = if is_attack { + rng.random_range(0.6..1.0) as f32 + } else { + rng.random_range(0.05..0.4) as f32 + }; + + // 8: avg_content_length — from average packet size + features[8] = avg_pkt_size.max(0.0).min(10000.0) as f32; + + // 9: unique_user_agents — low for attacks + features[9] = if is_attack { + rng.random_range(1.0..2.0) as f32 + } else { + rng.random_range(1.0..4.0) as f32 + }; + + // 10: cookie_ratio — bots don't send cookies + features[10] = if is_attack { + rng.random_range(0.0..0.1) as f32 + } else { + rng.random_range(0.6..1.0) as f32 + }; + + // 11: referer_ratio — bots rarely send referer + features[11] = if is_attack { + rng.random_range(0.0..0.1) as f32 + } else { + rng.random_range(0.3..1.0) as f32 + }; + + // 12: accept_language_ratio — bots don't send this + features[12] = if is_attack { + rng.random_range(0.0..0.1) as f32 + } else { + rng.random_range(0.6..1.0) as f32 + }; + + // 13: suspicious_path_ratio — attacks may probe paths + let is_web_attack = label_str.contains("Web Attack") + || label_str.contains("Bot") + || label_str.contains("Infiltration"); + features[13] = if is_web_attack { + rng.random_range(0.2..0.7) as f32 + } else if is_attack { + rng.random_range(0.0..0.3) as f32 + } else { + rng.random_range(0.0..0.05) as f32 + }; + + samples.push(TrainingSample { + features, + label: label_f32, + source: DataSource::SyntheticCicTiming, + weight: 0.7, // higher than pure synthetic (0.5), lower than prod (1.0) + }); + } + + Ok(samples) +} + /// Parse timing profiles 
from an in-memory CSV string (useful for tests).
pub fn extract_timing_profiles_from_str(csv_content: &str) -> Result> {
    let dir = tempfile::tempdir()?;
diff --git a/src/dataset/download.rs b/src/dataset/download.rs
index c36b8cb..fa753e2 100644
--- a/src/dataset/download.rs
+++ b/src/dataset/download.rs
@@ -1,3 +1,6 @@
+// Copyright Sunbeam Studios 2026
+// SPDX-License-Identifier: Apache-2.0
+
 //! Download and cache upstream datasets for training.
 //!
 //! Cached under `~/.cache/sunbeam//`. Files are only downloaded
@@ -19,8 +22,17 @@ fn cache_base() -> PathBuf {
 
 // --- CIC-IDS2017 ---
 
-/// Only the Friday DDoS file — contains DDoS Hulk, Slowloris, slowhttptest, GoldenEye.
-const CICIDS_FILE: &str = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv";
+/// All CIC-IDS2017 CSV files — covers every attack day and normal baselines.
+const CICIDS_FILES: &[&str] = &[
+    "Monday-WorkingHours.pcap_ISCX.csv",
+    "Tuesday-WorkingHours.pcap_ISCX.csv",
+    "Wednesday-workingHours.pcap_ISCX.csv",
+    "Thursday-WorkingHours-Morning-WebAttacks.pcap_ISCX.csv",
+    "Thursday-WorkingHours-Afternoon-Infilteration.pcap_ISCX.csv",
+    "Friday-WorkingHours-Morning.pcap_ISCX.csv",
+    "Friday-WorkingHours-Afternoon-PortScan.pcap_ISCX.csv",
+    "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv",
+];
 
 /// Hugging Face mirror (public, no auth required).
 const CICIDS_BASE_URL: &str =
@@ -30,49 +42,56 @@ fn cicids_cache_dir() -> PathBuf {
     cache_base().join("cicids")
 }
 
-/// Return the path to the cached CIC-IDS2017 DDoS CSV, or `None` if not downloaded.
+/// Return the cache directory if ALL CIC-IDS2017 CSVs are downloaded, else `None`.
 pub fn cicids_cached_path() -> Option<PathBuf> {
     let dir = cicids_cache_dir();
     if CICIDS_FILES.iter().all(|f| dir.join(f).exists()) {
         Some(dir)
     } else {
         None
     }
 }
 
-/// Download the CIC-IDS2017 Friday DDoS CSV to cache. Returns the cached path.
+/// Download all CIC-IDS2017 CSV files to cache.
Returns the cache directory.
 pub fn download_cicids() -> Result<PathBuf> {
     let dir = cicids_cache_dir();
-    let path = dir.join(CICIDS_FILE);
-
-    if path.exists() {
-        eprintln!("  cached: {}", path.display());
-        return Ok(path);
-    }
-
-    let url = format!("{CICIDS_BASE_URL}/{CICIDS_FILE}");
-    eprintln!("  downloading: {url}");
-    eprintln!("  (this is ~170 MB, may take a minute)");
-
     std::fs::create_dir_all(&dir)?;
 
-    // Stream to file to avoid holding 170MB in memory.
-    let resp = reqwest::blocking::Client::builder()
+    let client = reqwest::blocking::Client::builder()
         .timeout(std::time::Duration::from_secs(600))
-        .build()?
-        .get(&url)
-        .send()
-        .with_context(|| format!("fetching {url}"))?
-        .error_for_status()
-        .with_context(|| format!("HTTP error for {url}"))?;
+        .build()?;
 
-    let mut file = std::fs::File::create(&path)
-        .with_context(|| format!("creating {}", path.display()))?;
-    let bytes = resp.bytes().with_context(|| "reading response body")?;
-    std::io::Write::write_all(&mut file, &bytes)?;
+    for (i, filename) in CICIDS_FILES.iter().enumerate() {
+        let path = dir.join(filename);
+        if path.exists() {
+            eprintln!("  [{}/{}] cached: {}", i + 1, CICIDS_FILES.len(), filename);
+            continue;
+        }
 
-    eprintln!("  saved: {}", path.display());
-    Ok(path)
+        let url = format!("{CICIDS_BASE_URL}/{filename}");
+        eprintln!(
+            "  [{}/{}] downloading: {}",
+            i + 1,
+            CICIDS_FILES.len(),
+            filename
+        );
+
+        let resp = client
+            .get(&url)
+            .send()
+            .with_context(|| format!("fetching {url}"))?
+ .error_for_status() + .with_context(|| format!("HTTP error for {url}"))?; + + let mut file = std::fs::File::create(&path) + .with_context(|| format!("creating {}", path.display()))?; + let bytes = resp.bytes().with_context(|| "reading response body")?; + std::io::Write::write_all(&mut file, &bytes)?; + + eprintln!(" saved: {}", path.display()); + } + + Ok(dir) } // --- CSIC 2010 --- @@ -80,7 +99,10 @@ pub fn download_cicids() -> Result { /// Download CSIC 2010 dataset files to cache (delegates to scanner::csic). pub fn download_csic() -> Result<()> { if crate::scanner::csic::csic_is_cached() { - eprintln!(" cached: {}", crate::scanner::csic::csic_cache_path().display()); + eprintln!( + " cached: {}", + crate::scanner::csic::csic_cache_path().display() + ); return Ok(()); } // fetch_csic_dataset downloads, caches, and parses — we only need the download side-effect. @@ -96,9 +118,9 @@ pub fn download_all() -> Result<()> { download_csic()?; eprintln!(); - eprintln!("[2/2] CIC-IDS2017 DDoS timing profiles"); + eprintln!("[2/2] CIC-IDS2017 (all attack days + normal baselines)"); let path = download_cicids()?; - eprintln!(" ok: {}\n", path.display()); + eprintln!(" ok: {} ({} files)\n", path.display(), CICIDS_FILES.len()); eprintln!("all datasets cached."); Ok(()) @@ -116,4 +138,10 @@ mod tests { let cicids = cicids_cache_dir(); assert!(cicids.to_str().unwrap().contains("cicids")); } + + #[test] + fn test_all_files_listed() { + assert_eq!(CICIDS_FILES.len(), 8); + assert!(CICIDS_FILES.iter().all(|f| f.ends_with(".csv"))); + } } diff --git a/src/dataset/modsec.rs b/src/dataset/modsec.rs index a35eb23..5b4443e 100644 --- a/src/dataset/modsec.rs +++ b/src/dataset/modsec.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! Parser for OWASP ModSecurity audit log files (Serial / concurrent format). //! //! 
ModSecurity audit logs consist of multi-section entries delimited by boundary @@ -176,6 +179,7 @@ fn transaction_to_audit_fields( let content_length: u64 = get_header("content-length") .and_then(|v| v.parse().ok()) .unwrap_or(0); + let accept = get_header("accept").filter(|a| a != "-" && !a.is_empty()); // Section F: response status let status = sections @@ -216,11 +220,13 @@ fn transaction_to_audit_fields( duration_ms: 0, content_length, user_agent, - has_cookies: Some(has_cookies), - referer, - accept_language, + has_cookies, + referer: referer.unwrap_or_else(|| "-".to_string()), + accept_language: accept_language.unwrap_or_else(|| "-".to_string()), + accept: accept.unwrap_or_else(|| "-".to_string()), backend: "-".to_string(), label: Some(label.clone()), + ..AuditFields::default() }; Some((fields, label)) @@ -302,7 +308,7 @@ Content-Type: text/html assert_eq!(attack_fields.client_ip, "192.168.1.100"); assert_eq!(attack_fields.user_agent, "curl/7.68.0"); assert_eq!(attack_fields.status, 403); - assert!(!attack_fields.has_cookies.unwrap_or(true)); + assert!(!attack_fields.has_cookies); // Second entry: normal (no rule match). let (normal_fields, normal_label) = &results[1]; @@ -311,9 +317,9 @@ Content-Type: text/html assert_eq!(normal_fields.path, "/index.html"); assert_eq!(normal_fields.client_ip, "10.0.0.50"); assert_eq!(normal_fields.status, 200); - assert!(normal_fields.has_cookies.unwrap_or(false)); - assert!(normal_fields.referer.is_some()); - assert!(normal_fields.accept_language.is_some()); + assert!(normal_fields.has_cookies); + assert!(normal_fields.referer != "-"); + assert!(normal_fields.accept_language != "-"); } #[test] diff --git a/src/dataset/prepare.rs b/src/dataset/prepare.rs index c90d1cd..9ddbabf 100644 --- a/src/dataset/prepare.rs +++ b/src/dataset/prepare.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! Dataset preparation orchestrator. //! //! 
Combines production logs, external datasets (CSIC, OWASP ModSec), and @@ -30,6 +33,10 @@ pub struct PrepareDatasetArgs { pub seed: u64, /// Path to heuristics.toml for auto-labeling production logs. pub heuristics: Option, + /// Inject CSIC 2010 entries as labeled audit logs into production stream. + pub inject_csic: bool, + /// Inject OWASP ModSec entries as labeled audit logs (path to .log file). + pub inject_modsec: Option, } impl Default for PrepareDatasetArgs { @@ -41,6 +48,8 @@ impl Default for PrepareDatasetArgs { output: "dataset.bin".to_string(), seed: 42, heuristics: None, + inject_csic: false, + inject_modsec: None, } } } @@ -71,18 +80,21 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> { scanner_samples.extend(prod_scanner); ddos_samples.extend(prod_ddos); - // --- 2. CSIC 2010 (scanner) --- - eprintln!("fetching CSIC 2010 dataset..."); - let csic_entries = crate::scanner::csic::fetch_csic_dataset()?; - let csic_samples = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?; - eprintln!(" CSIC: {} scanner samples", csic_samples.len()); - scanner_samples.extend(csic_samples); + // --- 2. Inject external datasets as labeled audit log entries --- + // These go through the same feature extraction as production logs, + // with ground-truth labels (no heuristic labeling needed). + if args.inject_csic { + eprintln!("injecting CSIC 2010 as labeled audit entries..."); + let csic_entries = crate::scanner::csic::fetch_csic_dataset()?; + let csic_scanner = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?; + eprintln!(" CSIC injected: {} scanner samples", csic_scanner.len()); + scanner_samples.extend(csic_scanner); + } - // --- 3. 
OWASP ModSec (scanner) --- - if let Some(owasp_path) = &args.owasp { - eprintln!("parsing OWASP ModSec audit log from {owasp_path}..."); + if let Some(modsec_path) = &args.inject_modsec { + eprintln!("injecting ModSec audit log from {modsec_path}..."); let modsec_entries = - crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?; + crate::dataset::modsec::parse_modsec_audit_log(Path::new(modsec_path))?; let entries_with_host: Vec<(AuditFields, String)> = modsec_entries .into_iter() .map(|(fields, _label)| { @@ -90,16 +102,49 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> { (fields, host_prefix) }) .collect(); - let modsec_samples = + let modsec_scanner = entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?; - eprintln!(" OWASP: {} scanner samples", modsec_samples.len()); - scanner_samples.extend(modsec_samples); + eprintln!(" ModSec injected: {} scanner samples", modsec_scanner.len()); + scanner_samples.extend(modsec_scanner); } - // --- 4. CIC-IDS2017 timing profiles (from cache if downloaded) --- + // --- 3. Legacy OWASP path (kept for backwards compat) --- + if let Some(owasp_path) = &args.owasp { + if args.inject_modsec.as_deref() != Some(owasp_path.as_str()) { + eprintln!("parsing OWASP ModSec audit log from {owasp_path}..."); + let modsec_entries = + crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?; + let entries_with_host: Vec<(AuditFields, String)> = modsec_entries + .into_iter() + .map(|(fields, _label)| { + let host_prefix = fields.host.split('.').next().unwrap_or("").to_string(); + (fields, host_prefix) + }) + .collect(); + let modsec_samples = + entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?; + eprintln!(" OWASP: {} scanner samples", modsec_samples.len()); + scanner_samples.extend(modsec_samples); + } + } + + // --- 4. 
CIC-IDS2017 (direct DDoS samples + timing profiles for synthetic) --- let cicids_profiles = if let Some(cached_path) = crate::dataset::download::cicids_cached_path() { - eprintln!("extracting CIC-IDS2017 timing profiles from cache..."); + // Direct conversion: CIC-IDS2017 flows → DDoS training samples + eprintln!("extracting CIC-IDS2017 DDoS samples from cache..."); + let cicids_ddos = crate::dataset::cicids::extract_ddos_samples(&cached_path)?; + let attack_count = cicids_ddos.iter().filter(|s| s.label > 0.5).count(); + eprintln!( + " CIC-IDS2017 direct: {} DDoS samples ({} attack, {} normal)", + cicids_ddos.len(), + attack_count, + cicids_ddos.len() - attack_count + ); + ddos_samples.extend(cicids_ddos); + + // Also extract timing profiles for synthetic generation + eprintln!("extracting CIC-IDS2017 timing profiles..."); let profiles = crate::dataset::cicids::extract_timing_profiles(&cached_path)?; eprintln!(" extracted {} attack-type profiles", profiles.len()); profiles @@ -112,10 +157,10 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> { // --- 5. 
Synthetic data (both models, always generated) --- eprintln!("generating synthetic samples..."); let config = crate::dataset::synthetic::SyntheticConfig { - num_ddos_attack: 10000, - num_ddos_normal: 10000, - num_scanner_attack: 5000, - num_scanner_normal: 5000, + num_ddos_attack: 50000, + num_ddos_normal: 50000, + num_scanner_attack: 25000, + num_scanner_normal: 25000, seed: args.seed, }; @@ -240,17 +285,9 @@ fn parse_production_logs( // --- Scanner samples from production logs --- for (fields, host_prefix) in &parsed_entries { - let has_cookies = fields.has_cookies.unwrap_or(false); - let has_referer = fields - .referer - .as_ref() - .map(|r| r != "-" && !r.is_empty()) - .unwrap_or(false); - let has_accept_language = fields - .accept_language - .as_ref() - .map(|a| a != "-" && !a.is_empty()) - .unwrap_or(false); + let has_cookies = fields.has_cookies; + let has_referer = !fields.referer.is_empty() && fields.referer != "-"; + let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-"; let feats = features::extract_features( &fields.method, @@ -259,7 +296,7 @@ fn parse_production_logs( has_cookies, has_referer, has_accept_language, - "-", + &fields.accept, &fields.user_agent, fields.content_length, &fragment_hashes, @@ -352,20 +389,12 @@ fn extract_ddos_samples_from_entries( .push(fields.content_length.min(u32::MAX as u64) as u32); state .has_cookies - .push(fields.has_cookies.unwrap_or(false)); + .push(fields.has_cookies); state.has_referer.push( - fields - .referer - .as_deref() - .map(|r| r != "-") - .unwrap_or(false), + !fields.referer.is_empty() && fields.referer != "-", ); state.has_accept_language.push( - fields - .accept_language - .as_deref() - .map(|a| a != "-") - .unwrap_or(false), + !fields.accept_language.is_empty() && fields.accept_language != "-", ); state.suspicious_paths.push( crate::ddos::features::is_suspicious_path(&fields.path), @@ -462,17 +491,9 @@ fn entries_to_scanner_samples( let mut samples = Vec::new(); 
for (fields, host_prefix) in entries { - let has_cookies = fields.has_cookies.unwrap_or(false); - let has_referer = fields - .referer - .as_ref() - .map(|r| r != "-" && !r.is_empty()) - .unwrap_or(false); - let has_accept_language = fields - .accept_language - .as_ref() - .map(|a| a != "-" && !a.is_empty()) - .unwrap_or(false); + let has_cookies = fields.has_cookies; + let has_referer = !fields.referer.is_empty() && fields.referer != "-"; + let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-"; let feats = features::extract_features( &fields.method, @@ -481,7 +502,7 @@ fn entries_to_scanner_samples( has_cookies, has_referer, has_accept_language, - "-", + &fields.accept, &fields.user_agent, fields.content_length, &fragment_hashes, @@ -587,11 +608,12 @@ mod tests { duration_ms: 10, content_length: 0, user_agent: "Mozilla/5.0".to_string(), - has_cookies: Some(true), - referer: Some("https://test.sunbeam.pt".to_string()), - accept_language: Some("en-US".to_string()), + has_cookies: true, + referer: "https://test.sunbeam.pt".to_string(), + accept_language: "en-US".to_string(), backend: "test-svc:8080".to_string(), label: Some(label.to_string()), + ..AuditFields::default() }; (fields, "test".to_string()) } diff --git a/src/ddos/audit_log.rs b/src/ddos/audit_log.rs index e022b55..f6f9c51 100644 --- a/src/ddos/audit_log.rs +++ b/src/ddos/audit_log.rs @@ -1,83 +1,11 @@ -use serde::Deserialize; +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 -#[derive(Deserialize)] -pub struct AuditLog { - pub timestamp: String, - pub fields: AuditFields, -} +//! Re-exports from `crate::audit` — the canonical audit log definition. +//! +//! All new code should `use crate::audit::*` directly. 
-#[derive(Deserialize)] -pub struct AuditFields { - pub method: String, - pub host: String, - pub path: String, - pub client_ip: String, - #[serde(deserialize_with = "flexible_u16")] - pub status: u16, - #[serde(deserialize_with = "flexible_u64")] - pub duration_ms: u64, - #[serde(default)] - pub backend: String, - #[serde(default)] - pub content_length: u64, - #[serde(default = "default_ua")] - pub user_agent: String, - #[serde(default)] - pub query: String, - #[serde(default)] - pub has_cookies: Option, - #[serde(default)] - pub referer: Option, - #[serde(default)] - pub accept_language: Option, - /// Optional ground-truth label from external datasets (e.g. CSIC 2010). - /// Values: "attack", "normal". When present, trainers should use this - /// instead of heuristic labeling. - #[serde(default)] - pub label: Option, -} - -fn default_ua() -> String { - "-".to_string() -} - -pub fn flexible_u64<'de, D: serde::Deserializer<'de>>( - deserializer: D, -) -> std::result::Result { - #[derive(Deserialize)] - #[serde(untagged)] - enum StringOrNum { - Num(u64), - Str(String), - } - match StringOrNum::deserialize(deserializer)? { - StringOrNum::Num(n) => Ok(n), - StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom), - } -} - -pub fn flexible_u16<'de, D: serde::Deserializer<'de>>( - deserializer: D, -) -> std::result::Result { - #[derive(Deserialize)] - #[serde(untagged)] - enum StringOrNum { - Num(u16), - Str(String), - } - match StringOrNum::deserialize(deserializer)? { - StringOrNum::Num(n) => Ok(n), - StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom), - } -} - -/// Strip the port suffix from a socket address string. 
-pub fn strip_port(addr: &str) -> &str { - if addr.starts_with('[') { - addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr) - } else if let Some(pos) = addr.rfind(':') { - &addr[..pos] - } else { - addr - } -} +pub use crate::audit::strip_port; +pub use crate::audit::AuditFields; +pub use crate::audit::AuditLogLine as AuditLog; +pub use crate::audit::{flexible_u16, flexible_u64}; diff --git a/src/ddos/detector.rs b/src/ddos/detector.rs index ffc5183..d712061 100644 --- a/src/ddos/detector.rs +++ b/src/ddos/detector.rs @@ -1,6 +1,9 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use crate::config::DDoSConfig; use crate::ddos::features::{method_to_u8, IpState, RequestEvent}; -use crate::ddos::model::{DDoSAction, TrainedModel}; +use crate::ddos::model::DDoSAction; use rustc_hash::FxHashMap; use std::hash::{Hash, Hasher}; use std::net::IpAddr; @@ -10,12 +13,10 @@ use std::time::Instant; const NUM_SHARDS: usize = 256; pub struct DDoSDetector { - model: TrainedModel, shards: Vec>>, window_secs: u64, window_capacity: usize, min_events: usize, - use_ensemble: bool, } fn shard_index(ip: &IpAddr) -> usize { @@ -25,34 +26,15 @@ fn shard_index(ip: &IpAddr) -> usize { } impl DDoSDetector { - pub fn new(model: TrainedModel, config: &DDoSConfig) -> Self { + pub fn new(config: &DDoSConfig) -> Self { let shards = (0..NUM_SHARDS) .map(|_| RwLock::new(FxHashMap::default())) .collect(); Self { - model, shards, window_secs: config.window_secs, window_capacity: config.window_capacity, min_events: config.min_events, - use_ensemble: false, - } - } - - /// Create a detector that uses the ensemble (decision tree + MLP) path. - /// A dummy model is still needed for fallback, but ensemble inference - /// takes priority when `use_ensemble` is true. 
- pub fn new_ensemble(model: TrainedModel, config: &DDoSConfig) -> Self { - let shards = (0..NUM_SHARDS) - .map(|_| RwLock::new(FxHashMap::default())) - .collect(); - Self { - model, - shards, - window_secs: config.window_secs, - window_capacity: config.window_capacity, - min_events: config.min_events, - use_ensemble: true, } } @@ -99,24 +81,20 @@ impl DDoSDetector { let features = state.extract_features(self.window_secs); - if self.use_ensemble { - // Cast f64 features to f32 array for ensemble inference. - let mut f32_features = [0.0f32; 14]; - for (i, &v) in features.iter().enumerate().take(14) { - f32_features[i] = v as f32; - } - let ev = crate::ensemble::ddos::ddos_ensemble_predict(&f32_features); - crate::metrics::DDOS_ENSEMBLE_PATH - .with_label_values(&[match ev.path { - crate::ensemble::ddos::DDoSEnsemblePath::TreeBlock => "tree_block", - crate::ensemble::ddos::DDoSEnsemblePath::TreeAllow => "tree_allow", - crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp", - }]) - .inc(); - return ev.action; + // Cast f64 features to f32 array for ensemble inference. + let mut f32_features = [0.0f32; 14]; + for (i, &v) in features.iter().enumerate().take(14) { + f32_features[i] = v as f32; } - - self.model.classify(&features) + let ev = crate::ensemble::ddos::ddos_ensemble_predict(&f32_features); + crate::metrics::DDOS_ENSEMBLE_PATH + .with_label_values(&[match ev.path { + crate::ensemble::ddos::DDoSEnsemblePath::TreeBlock => "tree_block", + crate::ensemble::ddos::DDoSEnsemblePath::TreeAllow => "tree_allow", + crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp", + }]) + .inc(); + ev.action } /// Feed response data back into the IP's event history. @@ -125,10 +103,6 @@ impl DDoSDetector { // Status/duration from check() are 0-initialized; the next request // will have fresh data. This is intentionally a no-op for now. 
} - - pub fn point_count(&self) -> usize { - self.model.point_count() - } } fn fx_hash(s: &str) -> u64 { diff --git a/src/ddos/mod.rs b/src/ddos/mod.rs index 93f324d..c369aa6 100644 --- a/src/ddos/mod.rs +++ b/src/ddos/mod.rs @@ -1,6 +1,8 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + pub mod audit_log; pub mod detector; pub mod features; pub mod model; -pub mod replay; pub mod train; diff --git a/src/ddos/model.rs b/src/ddos/model.rs index e655915..74b1cba 100644 --- a/src/ddos/model.rs +++ b/src/ddos/model.rs @@ -1,183 +1,8 @@ -use crate::ddos::features::{FeatureVector, NormParams, NUM_FEATURES}; -use anyhow::{Context, Result}; -use serde::{Deserialize, Serialize}; -use std::path::Path; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -pub enum TrafficLabel { - Normal, - Attack, -} - -#[derive(Serialize, Deserialize)] -pub struct SerializedModel { - pub points: Vec, - pub labels: Vec, - pub norm_params: NormParams, - pub k: usize, - pub threshold: f64, -} - -pub struct TrainedModel { - /// Stored points (normalized). The kD-tree borrows these. - points: Vec<[f64; NUM_FEATURES]>, - labels: Vec, - norm_params: NormParams, - k: usize, - threshold: f64, -} +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum DDoSAction { Allow, Block, } - -impl TrainedModel { - pub fn load(path: &Path, k_override: Option, threshold_override: Option) -> Result { - let data = std::fs::read(path) - .with_context(|| format!("reading model from {}", path.display()))?; - let model: SerializedModel = - bincode::deserialize(&data).context("deserializing model")?; - Ok(Self { - points: model.points, - labels: model.labels, - norm_params: model.norm_params, - k: k_override.unwrap_or(model.k), - threshold: threshold_override.unwrap_or(model.threshold), - }) - } - - /// Create an empty model (no training points). 
Used when the ensemble - /// path is active and the KNN model is not needed. - pub fn empty(k: usize, threshold: f64) -> Self { - Self { - points: vec![], - labels: vec![], - norm_params: NormParams { - mins: [0.0; NUM_FEATURES], - maxs: [1.0; NUM_FEATURES], - }, - k, - threshold, - } - } - - pub fn from_serialized(model: SerializedModel) -> Self { - Self { - points: model.points, - labels: model.labels, - norm_params: model.norm_params, - k: model.k, - threshold: model.threshold, - } - } - - pub fn classify(&self, features: &FeatureVector) -> DDoSAction { - let normalized = self.norm_params.normalize(features); - - if self.points.is_empty() { - return DDoSAction::Allow; - } - - // Build tree on-the-fly for query. In production with many queries, - // we'd cache this, but the tree build is fast for <100K points. - // fnntw::Tree borrows data, so we build it here. - let tree = match fnntw::Tree::<'_, f64, NUM_FEATURES>::new(&self.points, 32) { - Ok(t) => t, - Err(_) => return DDoSAction::Allow, - }; - - let k = self.k.min(self.points.len()); - let result = tree.query_nearest_k(&normalized, k); - match result { - Ok((_distances, indices)) => { - let attack_count = indices - .iter() - .filter(|&&idx| self.labels[idx as usize] == TrafficLabel::Attack) - .count(); - let attack_frac = attack_count as f64 / k as f64; - if attack_frac >= self.threshold { - DDoSAction::Block - } else { - DDoSAction::Allow - } - } - Err(_) => DDoSAction::Allow, - } - } - - pub fn norm_params(&self) -> &NormParams { - &self.norm_params - } - - pub fn point_count(&self) -> usize { - self.points.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_classify_empty_model() { - let model = TrainedModel { - points: vec![], - labels: vec![], - norm_params: NormParams { - mins: [0.0; NUM_FEATURES], - maxs: [1.0; NUM_FEATURES], - }, - k: 5, - threshold: 0.6, - }; - assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow); - } - - fn make_test_points(n: usize) -> 
Vec { - (0..n) - .map(|i| { - let mut v = [0.0; NUM_FEATURES]; - for d in 0..NUM_FEATURES { - v[d] = ((i * (d + 1)) as f64 / n as f64) % 1.0; - } - v - }) - .collect() - } - - #[test] - fn test_classify_all_attack() { - let points = make_test_points(100); - let labels = vec![TrafficLabel::Attack; 100]; - let model = TrainedModel { - points, - labels, - norm_params: NormParams { - mins: [0.0; NUM_FEATURES], - maxs: [1.0; NUM_FEATURES], - }, - k: 5, - threshold: 0.6, - }; - assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Block); - } - - #[test] - fn test_classify_all_normal() { - let points = make_test_points(100); - let labels = vec![TrafficLabel::Normal; 100]; - let model = TrainedModel { - points, - labels, - norm_params: NormParams { - mins: [0.0; NUM_FEATURES], - maxs: [1.0; NUM_FEATURES], - }, - k: 5, - threshold: 0.6, - }; - assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow); - } -} diff --git a/src/ddos/replay.rs b/src/ddos/replay.rs deleted file mode 100644 index f7f26fc..0000000 --- a/src/ddos/replay.rs +++ /dev/null @@ -1,281 +0,0 @@ -use crate::config::{DDoSConfig, RateLimitConfig}; -use crate::ddos::audit_log::{self, AuditLog}; -use crate::ddos::detector::DDoSDetector; -use crate::ddos::model::{DDoSAction, TrainedModel}; -use crate::rate_limit::key::RateLimitKey; -use crate::rate_limit::limiter::{RateLimitResult, RateLimiter}; -use anyhow::{Context, Result}; -use rustc_hash::FxHashMap; -use std::io::BufRead; -use std::net::IpAddr; -use std::sync::Arc; - -pub struct ReplayArgs { - pub input: String, - pub model_path: String, - pub config_path: Option, - pub k: usize, - pub threshold: f64, - pub window_secs: u64, - pub min_events: usize, - pub rate_limit: bool, -} - -pub struct ReplayResult { - pub total: u64, - pub skipped: u64, - pub ddos_blocked: u64, - pub rate_limited: u64, - pub allowed: u64, - pub ddos_blocked_ips: FxHashMap, - pub rate_limited_ips: FxHashMap, - pub false_positive_ips: usize, - pub true_positive_ips: 
usize, -} - -/// Core replay pipeline: load model, replay logs, compute stats including false positive analysis. -pub fn replay_and_evaluate(args: &ReplayArgs) -> Result { - let model = TrainedModel::load( - std::path::Path::new(&args.model_path), - Some(args.k), - Some(args.threshold), - ) - .with_context(|| format!("loading model from {}", args.model_path))?; - - let ddos_cfg = DDoSConfig { - model_path: Some(args.model_path.clone()), - k: args.k, - threshold: args.threshold, - window_secs: args.window_secs, - window_capacity: 1000, - min_events: args.min_events, - enabled: true, - use_ensemble: false, - }; - let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg)); - - let rate_limiter = if args.rate_limit { - let rl_cfg = if let Some(cfg_path) = &args.config_path { - let cfg = crate::config::Config::load(cfg_path)?; - cfg.rate_limit.unwrap_or_else(default_rate_limit_config) - } else { - default_rate_limit_config() - }; - Some(RateLimiter::new(&rl_cfg)) - } else { - None - }; - - let file = std::fs::File::open(&args.input) - .with_context(|| format!("opening {}", args.input))?; - let reader = std::io::BufReader::new(file); - - let mut total = 0u64; - let mut skipped = 0u64; - let mut ddos_blocked = 0u64; - let mut rate_limited = 0u64; - let mut allowed = 0u64; - let mut ddos_blocked_ips: FxHashMap = FxHashMap::default(); - let mut rate_limited_ips: FxHashMap = FxHashMap::default(); - - for line in reader.lines() { - let line = line?; - let entry: AuditLog = match serde_json::from_str(&line) { - Ok(e) => e, - Err(_) => { - skipped += 1; - continue; - } - }; - - if entry.fields.method.is_empty() { - skipped += 1; - continue; - } - - total += 1; - - let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string(); - let ip: IpAddr = match ip_str.parse() { - Ok(ip) => ip, - Err(_) => { - skipped += 1; - continue; - } - }; - - let has_cookies = entry.fields.has_cookies.unwrap_or(false); - let has_referer = entry.fields.referer.as_deref().map(|r| r != 
"-").unwrap_or(false); - let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false); - let ddos_action = detector.check( - ip, - &entry.fields.method, - &entry.fields.path, - &entry.fields.host, - &entry.fields.user_agent, - entry.fields.content_length, - has_cookies, - has_referer, - has_accept_language, - ); - - if ddos_action == DDoSAction::Block { - ddos_blocked += 1; - *ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1; - continue; - } - - if let Some(limiter) = &rate_limiter { - let rl_key = RateLimitKey::Ip(ip); - if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) { - rate_limited += 1; - *rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1; - continue; - } - } - - allowed += 1; - } - - // Compute false positive / true positive counts - let (false_positive_ips, true_positive_ips) = - count_fp_tp(&args.input, &ddos_blocked_ips, &rate_limited_ips)?; - - Ok(ReplayResult { - total, - skipped, - ddos_blocked, - rate_limited, - allowed, - ddos_blocked_ips, - rate_limited_ips, - false_positive_ips, - true_positive_ips, - }) -} - -/// Count false-positive and true-positive IPs from blocked set. -/// FP = blocked IP where >60% of original responses were 2xx/3xx. 
-fn count_fp_tp( - input: &str, - ddos_blocked_ips: &FxHashMap, - rate_limited_ips: &FxHashMap, -) -> Result<(usize, usize)> { - let blocked_ips: rustc_hash::FxHashSet<&str> = ddos_blocked_ips - .keys() - .chain(rate_limited_ips.keys()) - .map(|s| s.as_str()) - .collect(); - - if blocked_ips.is_empty() { - return Ok((0, 0)); - } - - let file = std::fs::File::open(input)?; - let reader = std::io::BufReader::new(file); - let mut ip_statuses: FxHashMap> = FxHashMap::default(); - - for line in reader.lines() { - let line = line?; - let entry: AuditLog = match serde_json::from_str(&line) { - Ok(e) => e, - Err(_) => continue, - }; - let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string(); - if blocked_ips.contains(ip_str.as_str()) { - ip_statuses.entry(ip_str).or_default().push(entry.fields.status); - } - } - - let mut fp = 0usize; - let mut tp = 0usize; - for (_ip, statuses) in &ip_statuses { - let total = statuses.len(); - let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count(); - let ok_pct = ok_count as f64 / total as f64 * 100.0; - if ok_pct > 60.0 { - fp += 1; - } else { - tp += 1; - } - } - - Ok((fp, tp)) -} - -pub fn run(args: ReplayArgs) -> Result<()> { - eprintln!("Loading model from {}...", args.model_path); - eprintln!("Replaying {}...\n", args.input); - - let result = replay_and_evaluate(&args)?; - - let total = result.total; - eprintln!("═══ Replay Results ═══════════════════════════════════════"); - eprintln!(" Total requests: {total}"); - eprintln!(" Skipped (parse): {}", result.skipped); - eprintln!(" Allowed: {} ({:.1}%)", result.allowed, pct(result.allowed, total)); - eprintln!(" DDoS blocked: {} ({:.1}%)", result.ddos_blocked, pct(result.ddos_blocked, total)); - if args.rate_limit { - eprintln!(" Rate limited: {} ({:.1}%)", result.rate_limited, pct(result.rate_limited, total)); - } - - if !result.ddos_blocked_ips.is_empty() { - eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────"); - let mut 
sorted: Vec<_> = result.ddos_blocked_ips.iter().collect(); - sorted.sort_by(|a, b| b.1.cmp(a.1)); - for (ip, count) in sorted.iter().take(20) { - eprintln!(" {:<40} {} reqs blocked", ip, count); - } - } - - if !result.rate_limited_ips.is_empty() { - eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────"); - let mut sorted: Vec<_> = result.rate_limited_ips.iter().collect(); - sorted.sort_by(|a, b| b.1.cmp(a.1)); - for (ip, count) in sorted.iter().take(20) { - eprintln!(" {:<40} {} reqs limited", ip, count); - } - } - - eprintln!("\n── False positive check ──────────────────────────────────"); - if result.false_positive_ips == 0 { - eprintln!(" No likely false positives found."); - } else { - eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses", result.false_positive_ips); - } - eprintln!(" True positive IPs: {}", result.true_positive_ips); - - eprintln!("══════════════════════════════════════════════════════════"); - Ok(()) -} - -fn default_rate_limit_config() -> RateLimitConfig { - RateLimitConfig { - enabled: true, - bypass_cidrs: vec![ - "10.0.0.0/8".into(), - "172.16.0.0/12".into(), - "192.168.0.0/16".into(), - "100.64.0.0/10".into(), - "fd00::/8".into(), - ], - eviction_interval_secs: 300, - stale_after_secs: 600, - authenticated: crate::config::BucketConfig { - burst: 200, - rate: 50.0, - }, - unauthenticated: crate::config::BucketConfig { - burst: 60, - rate: 15.0, - }, - } -} - -fn pct(n: u64, total: u64) -> f64 { - if total == 0 { - 0.0 - } else { - (n as f64 / total as f64) * 100.0 - } -} diff --git a/src/ddos/train.rs b/src/ddos/train.rs index 4c3789f..4378883 100644 --- a/src/ddos/train.rs +++ b/src/ddos/train.rs @@ -1,13 +1,32 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use crate::ddos::audit_log::AuditLog; use crate::ddos::audit_log; use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES}; -use crate::ddos::model::{SerializedModel, 
TrafficLabel}; use anyhow::{bail, Context, Result}; use rustc_hash::{FxHashMap, FxHashSet}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::hash::{Hash, Hasher}; use std::io::BufRead; +/// Legacy KNN training types — kept for the `train-ddos` CLI command +/// which produces bincode model files for offline evaluation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum TrafficLabel { + Normal, + Attack, +} + +#[derive(Serialize, Deserialize)] +pub struct SerializedModel { + pub points: Vec, + pub labels: Vec, + pub norm_params: NormParams, + pub k: usize, + pub threshold: f64, +} + #[derive(Deserialize)] pub struct HeuristicThresholds { /// Requests/second above which an IP is labeled attack @@ -255,12 +274,12 @@ pub fn parse_logs(input: &str) -> Result> { state.statuses.push(entry.fields.status); state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32); state.content_lengths.push(entry.fields.content_length.min(u32::MAX as u64) as u32); - state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false)); + state.has_cookies.push(entry.fields.has_cookies); state.has_referer.push( - entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false), + !entry.fields.referer.is_empty() && entry.fields.referer != "-", ); state.has_accept_language.push( - entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false), + !entry.fields.accept_language.is_empty() && entry.fields.accept_language != "-", ); state.suspicious_paths.push( crate::ddos::features::is_suspicious_path(&entry.fields.path), diff --git a/src/ensemble/ddos.rs b/src/ensemble/ddos.rs index 2b82e1b..c59593b 100644 --- a/src/ensemble/ddos.rs +++ b/src/ensemble/ddos.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use crate::ddos::model::DDoSAction; use super::gen::ddos_weights; use super::mlp::mlp_predict_32; @@ -80,59 +83,46 @@ mod tests { use super::*; #[test] - fn 
test_tree_allow_path() { - // All zeros → feature 4 (request_rate) = 0.0 <= 0.70 → left (node 1) - // feature 10 (cookie_ratio) = 0.0 <= 0.30 → left (node 3) → Allow + fn test_tree_block_path() { + // Tree: root splits on feature 10 (cookie_ratio) at 0.14. + // All zeros → cookie_ratio normalized = 0.0 <= 0.14 → Block (node 1) let raw = [0.0f32; 14]; let v = ddos_ensemble_predict(&raw); + assert_eq!(v.action, DDoSAction::Block); + assert_eq!(v.path, DDoSEnsemblePath::TreeBlock); + } + + #[test] + fn test_tree_allow_path() { + // Tree: feature 10 (cookie_ratio) > 0.14 → node 2 (Allow leaf) + // feature 10 range [0, 1], raw 0.5 → normalized 0.5 > 0.14 → Allow + let mut raw = [0.0f32; 14]; + raw[10] = 0.5; + let v = ddos_ensemble_predict(&raw); assert_eq!(v.action, DDoSAction::Allow); assert_eq!(v.path, DDoSEnsemblePath::TreeAllow); assert_eq!(v.reason, "ensemble:tree_allow"); } #[test] - fn test_tree_block_path() { - // Need: feature 4 (request_rate) > 0.70 normalized → right (node 2) - // feature 12 (accept_language_ratio) > 0.25 normalized → right (node 6) → Block - // feature 4 max = 500, so raw 400 → normalized 0.8 > 0.70 ✓ - // feature 12 max = 1.0, so raw 0.5 → normalized 0.5 > 0.25 ✓ - let mut raw = [0.0f32; 14]; - raw[4] = 400.0; - raw[12] = 0.5; - let v = ddos_ensemble_predict(&raw); - assert_eq!(v.action, DDoSAction::Block); - assert_eq!(v.path, DDoSEnsemblePath::TreeBlock); - } - - #[test] - fn test_mlp_path() { - // Need: feature 4 > 0.70 normalized → right (node 2) - // feature 12 <= 0.25 normalized → left (node 5) → Defer - // feature 4 max = 500, raw 400 → 0.8 > 0.70 ✓ - // feature 12 max = 1.0, raw 0.1 → 0.1 <= 0.25 ✓ - let mut raw = [0.0f32; 14]; - raw[4] = 400.0; - raw[12] = 0.1; - let v = ddos_ensemble_predict(&raw); - assert_eq!(v.path, DDoSEnsemblePath::Mlp); - assert_eq!(v.reason, "ensemble:mlp"); - assert!(v.score >= 0.0 && v.score <= 1.0); - } - - #[test] - fn test_defer_then_mlp_allow() { - // Same Defer path as above — verify the MLP 
produces a valid action - let mut raw = [0.0f32; 14]; - raw[4] = 400.0; - raw[12] = 0.1; - let v = ddos_ensemble_predict(&raw); - assert!(matches!(v.action, DDoSAction::Allow | DDoSAction::Block)); + fn test_mlp_direct() { + // Current tree has no Defer leaves, so test MLP inference directly. + let input = [0.5f32; 14]; + let score = mlp_predict_32::<14>( + &ddos_weights::W1, + &ddos_weights::B1, + &ddos_weights::W2, + ddos_weights::B2, + &input, + ); + assert!(score >= 0.0 && score <= 1.0); } #[test] fn test_normalize_clamps_high() { + // feature 0 max = 10000.0, raw 999999 → clamped to 1.0 let mut raw = [0.0f32; 14]; - raw[0] = 999.0; // max is 100 + raw[0] = 999999.0; let normed = normalize(&raw); assert!((normed[0] - 1.0).abs() < f32::EPSILON); } @@ -140,7 +130,7 @@ mod tests { #[test] fn test_normalize_clamps_low() { let mut raw = [0.0f32; 14]; - raw[1] = -500.0; // min is 0 + raw[1] = -500.0; // min is 1.0 let normed = normalize(&raw); assert!((normed[1] - 0.0).abs() < f32::EPSILON); } diff --git a/src/ensemble/gen/ddos_weights.rs b/src/ensemble/gen/ddos_weights.rs index e67b64d..058a71d 100644 --- a/src/ensemble/gen/ddos_weights.rs +++ b/src/ensemble/gen/ddos_weights.rs @@ -1,71 +1,74 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! Auto-generated weights for the ddos ensemble. //! DO NOT EDIT — regenerate with `cargo run --features training -- train-ddos-mlp`. 
pub const THRESHOLD: f32 = 0.50000000; pub const NORM_MINS: [f32; 14] = [ - 0.08778746, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.05001374, 0.02000000, + 0.00000000, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00833333, 0.02000000, 0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000 ]; pub const NORM_MAXS: [f32; 14] = [ - 1000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49990082, 500.00000000, 1.00000000, + 10000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49999166, 500.00000000, 1.00000000, 171240.28125000, 30.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000 ]; pub const W1: [[f32; 14]; 32] = [ - [0.57458097, -0.10861993, 0.14037465, -0.23486336, -0.43255216, 0.16347405, 0.71766937, 0.83138502, 0.02852129, 0.56590265, 0.54848498, 0.38098580, 0.82907754, 0.61539698], - [0.06583579, 0.02305713, 0.89706898, 0.42619053, -1.20866120, -0.11974730, 1.70674825, -0.17969023, -0.26867196, 0.60768014, -0.08671998, -0.04825107, 0.58131427, -0.02062579], - [-0.40264100, 0.18836430, 0.08431315, 0.33763552, 0.44880620, -0.40894085, -0.22044741, 0.00533387, -0.61574107, 0.07670992, 0.63528854, 0.48244709, 0.20411402, -1.80697525], - [0.66713083, -0.22220801, 2.11234117, 0.41516641, -0.00165093, 0.65624571, 1.87509167, 0.63406783, -2.54182458, -0.53618753, 2.16407824, -0.61959583, -0.04717547, 0.17551991], - [-0.51027024, -0.60132700, 0.46407551, -0.57346475, -0.30902353, -0.24235034, 0.08087540, -2.14762974, -0.29429656, 0.56257033, -0.26935315, -0.16799171, 0.56852734, 1.93494022], - [-0.24938971, -0.12699288, -0.13746630, 0.64942318, 0.09490766, -0.02158179, 0.72449303, -0.28493983, -0.43053114, -0.01443988, 0.89670080, -0.34539866, -1.47019410, 0.79477930], - [0.62935185, 0.74686801, -0.15527052, -0.06635039, 0.73137009, 0.78417069, -0.06417987, 0.72259408, 0.85131824, 0.00477386, -0.14302900, 0.63481224, 0.92724019, -0.50126070], - [-0.12699059, -0.15016419, -0.48704135, 
0.00581611, 0.75824696, 0.84114397, -0.08958503, 0.18609463, 0.56247348, 0.22239330, 0.43324804, 0.82077771, 0.55714250, -0.56955606], - [0.83457869, 0.40054807, 0.23281574, -0.58521581, 1.18067443, -0.49485078, 0.08600014, 0.99104887, -0.65019566, -0.44594154, 0.64507920, -0.61692268, -0.29301512, -0.11314666], - [-0.07868081, -0.18392175, 0.15165123, 0.35139060, 0.13855398, 0.16470867, 0.21025884, 1.57204449, 0.07827333, 0.05895505, 0.00810917, 1.05159700, 0.04605416, 0.38080546], - [1.47428405, -0.21614535, -0.35385504, 0.46582970, 1.26638246, 0.00133375, -3.85603786, 0.39766011, 1.92816520, 0.47828305, -0.16951409, -0.13771342, 0.49451983, 0.41184473], - [-0.23364748, 0.68134952, 0.36342716, 0.02657196, 0.07550839, 0.94861823, -0.52908695, 0.83652318, -0.05639480, 0.26536962, 0.44137934, 1.20957208, -0.60747981, -0.50647283], - [-0.16961956, -0.49570882, -0.33771378, -0.28554109, 0.95865113, -0.49269623, -0.44559151, 1.28568971, 0.79537493, -0.53175420, -3.19015551, 0.52214253, 0.86517984, 0.62523192], - [-0.16956513, -0.61727583, 0.63967121, 0.96406335, -0.28760204, 0.56459671, 0.78585202, -0.03668134, -0.14773002, -0.35764447, 0.84649116, -0.34540027, -0.12314465, -0.10070048], - [-0.34183556, -0.07760386, 0.70894319, 0.92814171, -0.19357866, 0.41449037, 0.54653358, 0.27682835, 0.81471086, 0.56383932, 0.57456553, -0.61491662, 0.92498505, 0.74495614], - [-0.38917324, -0.29217750, 1.43508542, -0.19152534, -0.18823336, 0.45097819, -0.38063127, -0.40419811, 0.56686693, -0.33231607, -0.19567636, -0.02500075, -0.04762971, 0.44703853], - [1.14234805, -0.62868208, -0.21298689, 0.00263968, -0.66115338, -1.12038326, 0.93599045, 0.77646011, -0.22770278, 1.43982041, 0.96078646, 1.15076077, -0.45110813, 0.83090556], - [0.89638984, -0.69683450, -0.29400119, 0.94997799, 0.90305328, -0.80215877, -0.09983492, -0.90757453, -0.03181892, 1.00702441, -0.97962254, -0.89580274, 0.69299418, -0.75975400], - [-0.75832003, -0.07210776, 0.07825917, 1.51633596, 0.44593197, 0.00936707, 
-0.12142835, -0.09877282, 0.06229200, 1.25678349, 0.25317946, 0.54112315, -0.17941843, 0.93283361], - [0.23085761, 0.53307736, 0.38696140, 0.36798462, 0.38192499, 0.23203450, 0.68225187, 0.47096270, -4.24785280, 0.18062039, 0.60047084, 0.16251479, -0.10811257, 0.48166662], - [0.10870802, 0.01576116, 0.00298645, 0.25878090, -0.16634797, 0.15850464, -0.24267951, 0.87678236, -0.27257833, 0.78637868, -0.00851476, 0.01502728, 0.92175138, -0.81292266], - [-0.74364990, -0.63139439, -0.18314177, -0.36881343, -0.53096825, -0.92442876, -0.05536628, -0.71273297, -0.94937468, -0.03863344, -0.09668982, -1.07886386, 0.58555382, 0.23351164], - [-0.09152136, 0.96538877, -0.11560653, -0.53110164, 0.89070886, 0.05664408, -0.71661353, 0.79684436, -0.00206013, 0.23857179, 0.06074178, -0.67188424, -0.15624331, 0.43436247], - [-0.28189376, -0.00535834, 0.60541785, 0.82968009, -0.21901314, -0.29874969, -0.16872653, 0.45570841, -0.25372767, -0.12359514, -1.10104620, 0.00162374, 0.07622175, 0.60413152], - [-1.13819373, -0.41320390, -5.57348347, 0.40931624, -1.59562767, 0.72510892, 0.03248254, 0.00407641, 0.57557869, 0.53510398, -0.35943517, 0.52707136, 0.61220711, -0.11644226], - [-0.02057049, 0.42545527, 0.24192038, 0.29863021, -0.22839858, -0.25318733, 0.17906551, -0.29471490, -0.04746799, 0.15909556, -0.26826856, -0.06874973, -0.03044286, 0.11770450], - [-0.18060833, -0.06301155, 0.01656315, -0.40476608, -0.35056075, 0.06344713, 0.32273614, -0.04382812, -0.18925793, 0.02124963, -0.23447622, 0.29704437, 0.19138981, -0.04584064], - [0.18248987, 0.05461208, -0.25655189, 0.16673982, 0.03251073, 0.05709980, 0.09135589, 0.06712578, -0.02372392, 0.00487196, -0.11774579, 0.34203079, 0.18477952, 0.09847298], - [-0.08292723, -0.03089223, 0.19555064, -0.18158682, -0.32060555, 0.18836822, -0.14625609, -0.83500093, -0.09893667, 0.02719803, 0.06864946, 0.00156752, 0.04342323, 0.30958080], - [-0.21274266, 0.06035644, 0.27282441, -0.01010289, -0.05599894, 0.27938741, -0.23254848, -0.20086342, 
-0.06775926, -0.18059292, 0.92534143, 0.09500337, 0.11612320, -0.06473339], - [-0.27279299, 0.96252358, 0.67542273, 0.64720130, 0.15221471, 1.67354584, 0.53074431, 0.65513390, 0.79840666, 0.78613347, 0.34742561, -1.83272552, 0.73313516, 0.09797212], - [-0.08888317, 0.14851266, 1.00953877, 0.19915256, -0.10076691, 0.47210938, 0.04427565, 0.19299655, 0.58729172, 0.17481442, -0.57466495, -0.16196120, 0.06293163, 1.73905540], + [0.30719012, 0.54619288, -0.20061351, 0.29659060, 0.34138879, -0.19088192, 0.34866381, -0.24303232, 0.20615512, 0.12656690, -0.16653502, 0.10961400, -0.16814700, 0.09950374], + [0.46002641, 0.49129087, -0.01386960, 0.64490628, 0.51850092, 0.69266915, 0.31095454, 0.26951542, -0.20926359, -0.09662568, 0.19281134, 0.36575633, 0.23089127, -0.03582983], + [0.04547556, -0.04854088, -0.10979871, -0.18705507, -0.44649851, 0.20949614, 0.73240960, 1.34691823, -0.13529004, 0.69439852, -0.40027520, 0.47921708, -0.43529814, 0.48781869], + [0.61547452, 1.52229679, 0.48276836, 1.27171433, -0.36176509, -0.33192506, -0.82673991, 0.67331636, -0.21094124, 1.03067887, -0.09182073, 1.44520211, 0.52611661, 0.61176163], + [-0.37162983, 0.48245564, -0.53393066, -0.20009390, -0.06583384, 0.17612432, 0.59905756, 1.39533114, 0.67457062, 0.06159161, -0.56609136, -0.29591814, 0.55239469, -0.56801152], + [0.75815558, -0.64557153, 0.84678394, 0.41179815, -0.50619060, -0.09139232, -0.64594650, -0.74464273, 0.87102652, 0.81111395, -0.35027400, 0.95135874, 0.85043454, 1.72117484], + [0.98888069, -0.04631047, -0.62931997, -0.37154037, 1.02817857, -1.04121590, 0.74848920, -0.26426360, 0.23142239, -0.17234743, -0.61689568, -0.59363395, -0.85756373, 0.53024006], + [0.59509462, 0.26622522, -0.74383926, 0.48256168, -0.75522244, 0.46806136, -0.62194610, 0.09251838, 0.25921744, 0.72987258, 0.66349596, 0.53999704, -0.25535119, -0.92465514], + [0.50981742, -1.44853806, -0.64814043, -0.78203505, -0.88038790, 0.68278509, 0.58861315, 0.55924416, -0.52396554, 0.45195666, -0.44876143, 
-0.11349974, 0.64508075, 0.06376592], + [0.44494510, 0.79238343, -0.22128101, 3.13757062, 0.54972911, 2.06494117, -0.20301908, 0.48413971, -0.25992882, 0.77544886, -0.18115431, 1.87130582, 0.71965748, 1.95458603], + [1.00518668, 1.80238068, -0.28449696, -0.02740687, -2.51049113, 0.56081659, -0.43591678, 3.59169340, -0.47954431, 1.82556272, 0.64387941, 0.56122434, -0.19696619, 3.49070907], + [-0.32992145, -0.03573111, 2.41438532, -0.00748284, 0.62775159, 1.78909039, 0.25103322, 0.59640545, -0.10183074, 0.83787775, 0.14171274, 0.08816884, 0.16381627, -0.04427620], + [0.09841868, 0.58517164, 0.02630968, 0.65797943, -0.03991833, 0.52833039, 0.37459302, 0.01832970, -1.20483434, 0.76000416, 0.02081347, 1.10453236, 0.46800232, 0.50707549], + [0.13568293, -0.04429439, 0.18404786, 0.74804515, -0.02402807, 0.25729915, 0.64555109, 0.09644510, 0.31338552, 0.62685025, -0.19832127, 1.95116663, 0.66340035, 1.29182386], + [-0.20969683, 0.56657153, -0.08705560, 0.71007556, -0.11011623, 1.16174579, 0.65050489, 1.31441426, 0.72755563, 1.15947676, -0.34925875, 0.01019314, -1.42810500, 0.14942981], + [-0.47017330, 2.62149596, -0.37532449, 1.17488575, 0.62930888, 0.62195790, -0.12959687, 2.36229849, -0.25786853, 0.03494137, 1.70790768, -0.02720823, 0.57822198, 1.57692003], + [-0.68229634, 0.87380433, 1.03171849, 0.35238963, -0.78998542, 0.97562903, -0.80616480, 1.07170749, -0.79917014, -0.43357334, 1.09133816, 0.49446958, 1.07970095, 0.27838916], + [-0.77235895, -0.66010702, -0.09969614, -0.38052577, -0.77211934, -0.73416811, -0.67031443, 0.62016815, 0.97461295, 1.07167208, -0.68821293, 0.51563287, -0.73027885, -0.14203216], + [0.90449816, -0.23423387, 1.11039567, 0.61329746, -0.21385542, 0.52449727, 0.42514217, 0.42172486, -0.33397049, 0.35888657, -0.54074812, 0.48481938, -0.05116262, -0.23848286], + [0.67948169, 0.50562781, 0.45344570, 0.47307885, -0.44913152, -0.11515936, 0.14361705, -0.36479098, -0.32777452, 0.11798909, -0.57137913, 0.30936614, 0.31339252, 0.51131296], + 
[-0.25677630, 0.25580657, -0.12398625, 0.24844812, 0.18556698, 0.21818036, 0.58248550, 0.50517905, 0.34329867, 0.15851928, -0.58440667, 0.33611965, 0.67439252, -0.52770680], + [0.66840053, -0.49819222, 0.29022828, 0.10492916, -0.06216156, 0.37093312, -0.24731418, 0.22893915, 0.32447502, 0.63166237, -0.13788179, 0.52650315, 0.15229015, 0.23656118], + [0.33978519, 0.15498674, -0.25265032, -0.42916322, 0.69121236, 0.20443739, 0.54050952, 0.08900955, -0.13801514, 0.25456557, -0.10714018, 0.08712567, 0.27245566, -0.29683220], + [-0.05526243, -1.17294025, 0.07328646, -0.07892461, 0.31488195, -0.01112767, 0.55462092, 0.65152955, -0.10721418, -0.99451303, 0.00110284, -0.53097665, 0.14362922, 0.17380728], + [-2.95768332, 0.46451911, 0.20220210, 0.76858771, 0.13804838, -0.80371422, -0.11160404, -0.15160939, 0.31488597, -0.10203149, 0.16458754, -0.08558689, -0.27082649, 0.03877234], + [-0.14654562, -0.70086712, 0.09809728, -0.60966188, 0.26278028, 0.07354698, 0.08616283, 0.36018923, 0.07040872, 0.41008693, 0.13071685, 0.18236822, 0.43306109, -0.10742717], + [0.41488791, -0.10255218, 0.10218169, 0.21971215, -0.05527666, -0.50265622, 0.06767768, -0.09040122, 0.16871217, -0.02748547, 0.21738021, 0.21068999, 0.10562737, -0.71913630], + [0.09367306, -0.14113051, -0.44151428, -0.05189204, 0.22411002, -0.09538609, 0.17464676, -0.30709952, 0.21021855, -0.27705607, 0.17645715, 0.19070518, 0.18094100, 0.10115600], + [-0.11084171, -0.60070217, 0.10072551, -0.09865215, 0.19512057, -0.32474023, 0.14499906, 0.06266983, -0.15383074, 0.10347557, -0.10143858, -0.09821036, -0.19187087, -0.21955618], + [0.09774263, 0.21607652, -0.22068830, -0.73502982, 0.14551027, -0.00246539, -0.32017741, -0.14855191, 0.15684886, -0.21544383, -0.36595181, 1.57917106, 0.45341989, 0.64960009], + [0.03760023, 0.12075356, -0.24193284, 0.16418910, -0.13468136, 0.40612614, 0.44222566, 0.17999728, 0.37591749, 0.67439985, -0.29388478, -0.20486754, 0.20614263, -2.63525987], + [-0.10479005, -0.17017230, -0.42374054, 
0.30094361, 0.28561834, 0.40433934, -0.03086211, 1.49869466, -0.41601327, 0.20835553, 1.19875181, -0.00222666, 0.51400107, 0.27829245], ]; pub const B1: [f32; 32] = [ - -0.80723554, 0.54879200, 0.01237706, -0.22279924, 0.93692911, 0.12226531, -0.54665250, -0.49958101, - -0.20918398, -0.48646352, -0.58741039, -0.50572610, -0.04772990, -0.62962151, -0.46279392, 1.14840722, - -0.04871057, -0.31787100, 1.13966286, 0.69543558, -0.17798270, 0.66968435, -0.07442535, -0.70557600, - 0.79021728, 0.65736526, -0.30761406, 0.63242179, 0.83297908, -0.04573143, -0.18454255, -0.30583009 + 0.76754266, -0.52365464, -0.07451479, -0.24194083, 0.81372803, -0.14967601, 0.86968440, -0.11282827, + 0.82378083, 0.03708726, -0.14121835, -0.33332673, -0.24595253, -0.20005627, 0.80769247, 0.67842513, + 0.62225562, 0.55104679, 0.87356585, -0.16369765, 0.83232063, -0.40881905, -0.02851989, -0.04714838, + 0.69236869, -0.30938062, 0.87852216, -0.14689557, -0.52630597, -0.22946648, -0.13811214, -0.41019145 ]; pub const W2: [f32; 32] = [ - 1.09615684, -0.57856798, -0.08730038, -0.06425755, -0.96232760, -2.06290460, 0.70097560, 0.85189444, - -0.10077959, 1.94375157, 0.74497795, 0.88425481, 2.11908054, 0.85526127, 0.61624259, -2.93621016, - 1.52211487, 0.56318259, -3.15219641, -0.55187315, 1.61819077, -0.76258671, -0.09362544, 0.86861998, - -0.79028755, -0.90605170, 0.33475992, -0.79945564, -1.16680586, 0.15120529, 0.17619221, 1.61664009 + -0.84622073, 2.32144451, 0.70330697, 0.89360833, -1.08053613, 0.69213301, -1.07218480, 0.82345659, + -1.11953294, -2.58824420, 0.81520051, 1.19865966, 0.91804677, 1.04554057, -1.03049874, -0.94034135, + -1.66193688, -1.53192282, -1.09629154, -4.07772017, -1.14778209, 1.15202129, 0.42650393, 0.55174673, + -1.28319669, 2.06129408, -1.10220599, 0.09728605, 1.64764512, -0.14975634, 0.79428691, 1.56726408 ]; -pub const B2: f32 = -0.52729088; +pub const B2: f32 = -0.58103424; pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [ - (3, 0.30015790, 1, 2), - (255, 0.00000000, 
0, 0), + (10, 0.13999981, 1, 2), (255, 1.00000000, 0, 0), + (255, 0.00000000, 0, 0), ]; diff --git a/src/ensemble/gen/scanner_weights.rs b/src/ensemble/gen/scanner_weights.rs index f9030c7..999e9e3 100644 --- a/src/ensemble/gen/scanner_weights.rs +++ b/src/ensemble/gen/scanner_weights.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! Auto-generated weights for the scanner ensemble. //! DO NOT EDIT — regenerate with `cargo run --features training -- train-scanner-mlp`. @@ -14,55 +17,55 @@ pub const NORM_MAXS: [f32; 12] = [ ]; pub const W1: [[f32; 12]; 32] = [ - [2.25848985, 1.62502551, 1.05068624, -0.23875977, 1.29692984, -1.34665418, -1.29937541, 1.66119707, -1.43897200, 0.07720046, -1.17116165, 1.96821272], - [2.30885172, 0.02477695, -0.23236598, 1.66507626, -1.41407740, 1.88616431, 1.84703696, -1.46395433, 2.03542018, 1.68318951, 2.01550031, 1.94223917], - [2.29420924, 1.86615539, 1.69271469, 1.42137837, 1.43151915, 1.84876072, 1.09228194, 1.73608077, 0.20805965, 0.52542430, -0.02558800, 0.04718366], - [0.36484259, -0.02785611, -0.01155548, 0.08577330, -0.00468449, -0.07848717, 0.05191587, 0.50796396, 0.40799347, -0.14838840, -0.30566201, 0.00758083], - [0.28191370, 0.20945202, 0.07742970, -0.06654347, 0.17395714, 0.00011351, 0.37079588, 0.41817516, 0.56992871, 0.05705916, 0.22339216, 0.11021475], - [0.06522971, 0.64510870, 0.31671444, 0.34980071, 0.03446164, -0.10592904, -0.21302676, -0.04404496, 0.08638768, 0.04217484, 0.43021953, 0.21055792], - [0.31206250, -0.14565454, 0.38078794, 0.00860748, 0.29409558, -0.11273954, -0.02210701, 0.15525217, 0.09696059, 0.13877581, 0.06483351, 0.10946950], - [0.28374705, -0.02963164, 0.27863786, -0.23428085, 0.12715313, 0.09141072, 0.07769041, 0.01915955, -0.20936646, 0.02813511, -0.03910714, 0.30322370], - [-1.19449413, -0.84935474, -0.32267663, -0.08140022, -0.78729230, 1.58759272, 0.88281459, -0.77263606, 1.55394125, 0.10148179, 1.59524822, -0.75499195], - [-0.97152823, 
-0.12173092, 0.04745778, -0.85466659, 1.57352293, -0.52651149, -0.66270715, 1.32282484, -1.24654925, -0.45822921, -1.10187364, -0.91162699], - [-0.93944395, -0.57891464, -1.12100291, -0.38871467, -0.18780440, -1.11835766, -0.43614236, -1.07918274, -0.09222561, -0.23854440, -0.16720718, 0.03247443], - [0.13319625, 0.87437463, 0.32213065, 0.13902900, 0.64760798, 0.00899744, 0.45325586, -0.14138180, 0.13888212, 0.07780524, -0.12482210, 0.12632932], - [0.57018995, -0.10839911, 0.02787536, 0.16884641, 0.19435850, -0.01189608, 0.13881874, -0.10700739, -0.05463003, 0.01371983, 0.04385772, 0.01100468], - [-0.26600277, -0.11843663, -0.01081531, 0.10785927, -0.18684258, 0.08537511, 0.01054722, -0.01972559, -0.07416820, 0.57192892, 0.37873995, -0.00498434], - [0.72535324, -0.25030360, 0.51470703, -0.16410951, -0.13649474, 0.16246459, -0.27847841, 0.12250750, 0.45576489, -0.18535912, -0.45686084, 0.58293521], - [0.18614589, -0.32835677, -0.08683094, 0.07748202, -0.24785264, -0.16834147, 0.27066526, 0.06058804, 0.01903199, -0.17387865, 0.12752151, -0.03780220], - [-1.22358644, -0.78316134, -0.54068804, -0.07921790, -0.72697675, 1.80127227, 0.14326867, -0.51875746, 1.83125353, -0.02672976, 1.68589675, -0.80162954], - [-0.83690810, -0.12682360, 0.10783038, -0.64648604, 1.50810242, -0.48788729, -0.59418935, 0.94863659, -0.84788662, -0.49779284, -0.96408021, -1.14068258], - [-0.96322638, -0.50503486, -0.87195945, -0.34710455, -0.28645220, -1.10507452, -0.32122782, -0.80753750, -0.00843489, 0.04215550, 0.03197355, 0.05468401], - [-0.17587705, 0.45144933, 0.37954769, -0.15405300, 0.75590396, 0.00346784, 0.62332457, -0.15602241, -0.26471916, -0.19963606, -0.22497311, -0.20784236], - [0.60608941, -0.05316854, 0.03766245, 0.46412235, -0.41121334, -0.01225545, -0.11125158, -0.33533856, -0.04625564, -0.02995013, -0.24979964, -0.35824969], - [0.08163761, 0.04702193, -0.24007457, -0.23439978, 0.27066308, 0.48389259, 0.32692793, -0.23089454, 0.26520243, -0.14099684, 0.06713670, 0.14434725], 
- [-0.50808382, -0.14518137, -0.23912378, 0.33510539, 0.46566108, 0.09035082, -0.12637842, 0.55245715, -0.19972627, 0.24517706, 0.34291887, 0.01936621], - [0.35826349, 0.21200819, 0.65315312, -0.16792546, 0.41378024, 0.32129642, 0.50814188, 0.48289016, 0.06839173, 0.42079177, 0.52295685, 0.26273951], - [0.24575019, 0.10700949, 0.07041252, -0.09410189, 0.18897925, 0.31616825, -0.01306109, 0.33499330, -0.01866218, 0.06233863, 0.15316568, 0.08370106], - [0.17828286, 0.17363867, -0.10626584, 0.06075979, 0.39465010, 0.19557165, 0.30352867, 0.26720291, 0.40256795, 0.13942246, 0.05869288, 0.08310238], - [-0.04834138, 0.29206491, 0.01330532, 0.07626399, -0.17378819, 0.09515948, 0.02298534, 0.41555724, 0.09492048, 0.39422533, 0.39373979, 0.20463347], - [-0.11641891, -0.06529939, -0.18899654, -0.02157970, -0.03554495, 0.10956290, -0.11688691, 0.04077352, 0.34220406, -0.09558969, 0.16150762, 0.25759667], - [-0.17313123, 0.00591523, 0.29443163, 0.08298909, 0.07761172, 0.19023541, 0.23826212, -0.07167042, 0.08753359, 0.17917964, -0.03248737, 0.28516129], - [0.13091524, 0.21435370, 0.15093684, 0.30902347, 0.44151527, 0.55901742, 0.19933179, 0.06438518, 0.30585650, -0.34089112, 0.26879075, 0.12928906], - [-0.25311065, -0.09963353, -0.50099874, 0.57481062, 0.38744658, -0.13065037, 0.18897361, 0.49376330, -0.15626629, 0.19911517, 0.06437352, -0.09104283], - [0.35787049, -0.04814727, 0.45446551, -0.15264697, 0.36565515, 0.22795495, 0.24630190, 0.16362202, 0.21044184, 0.53882843, 0.42343852, 0.18454899], + [-0.24775992, -1.12687504, -0.84061003, 1.35276508, 1.04677176, 1.57832956, 1.47995067, 1.38580477, -0.99564040, -1.20309269, -0.24385734, 1.32367671], + [-1.31796920, -0.12229957, 0.89794689, 1.14832735, 1.17210162, 1.32387733, 1.37799740, 1.22984815, 0.92816162, 1.45189691, 0.97822803, -0.89625973], + [-1.46790743, -1.43995631, 1.44276273, 1.30967343, 1.37424576, -1.16613543, 1.46063673, -1.29701447, 0.20172349, -0.05374112, -0.02462341, 0.25295240], + [0.57722956, 0.39603359, 
-0.03453349, -0.06384056, 0.08056496, 0.30015573, -0.19045275, 0.39737019, 0.41070747, 0.03414693, -0.25483453, 0.24018581], + [0.02473495, 0.05047939, 0.59600842, 0.35709319, -0.32843903, 0.17655721, 0.20441811, 0.05118980, 0.24312067, 0.12677322, 0.12109834, -0.09230414], + [-0.03397392, -0.08393703, 0.42762470, 0.21609549, -0.28864771, -0.11998425, -0.05842599, 0.87331069, 0.52526826, 0.35321006, 0.31293729, 0.51823896], + [-0.26906690, -0.04246650, 0.05114691, 0.27629042, -0.37589836, -0.07396606, -0.46509269, 0.26079515, 0.33588448, 0.55919188, 0.58768439, 0.81638134], + [-0.48464304, 0.76111192, 0.06296005, -0.13527128, -0.41344830, -0.19461812, 0.52815431, 0.96815300, 0.47175986, -0.13977771, 0.41216326, -0.03041369], + [-0.05156134, 0.90191734, 0.78513384, -1.25017786, -0.54637259, -0.36098620, -0.59882820, -0.90374511, 0.91336489, 1.05806208, 0.04994302, -0.91028821], + [0.68376511, 0.07305489, 0.44790089, -0.56647295, -0.66570538, -0.93897015, -0.40558589, -1.51070845, 0.46759781, -1.36738360, -0.05270236, 0.98130196], + [0.86788160, 0.69321704, -0.53778958, -1.54257190, -0.44623125, 0.72615588, -0.75269628, 0.81946337, 0.17503875, 0.63745797, 0.48478079, -0.31573632], + [-0.01361719, 0.21524477, -0.10345778, -0.38488832, 0.42967409, 0.75472528, -0.07410870, -0.65231675, 0.42633417, -0.10289414, 0.09583388, -0.29391766], + [-0.06778818, -0.44469842, 0.05952910, -0.55139810, 0.14308600, -0.53731138, -0.07426350, 0.28065708, 0.29584157, 0.47813708, 0.02095048, -0.36458421], + [-0.21983556, 0.55435538, -0.13939659, 0.58281261, -0.20551582, 0.30075905, 0.13396217, -0.18145087, -0.43283740, 0.18541494, 0.07530790, -0.04916608], + [0.46939296, 0.57935077, -0.05478116, -0.01144989, 0.54106784, -0.18313073, 0.12232503, -0.32802504, -0.01167463, -0.13702804, -0.19521871, 0.09115479], + [0.25263065, -0.27172634, -0.12802953, 0.50027740, 0.05213343, 0.49081728, 0.15367918, 0.20471051, -0.22081012, 0.51709008, 0.01776243, 0.22513707], + [-0.20733139, 0.60041994, 
0.05273124, 0.07473211, 0.14580894, -0.72007078, -0.52350652, -0.15482022, 0.19132918, 0.52586436, -0.04793828, -0.00479114], + [0.19872986, -0.19177110, -0.22340146, -0.48786804, -0.51010352, -0.55363113, -0.29520389, -0.21378680, -0.40099174, -0.09184421, -0.08521358, 0.61833692], + [0.21346046, 0.53319895, -0.44765636, -0.04764151, -0.30569363, 0.19765340, -0.41479719, 0.34292534, -0.29234713, 0.54341668, 0.60121793, -0.00226344], + [-0.29598647, -0.37357926, -0.25650844, -0.05165816, 0.55829030, 0.21028350, -0.28581545, -0.37299931, 0.57590896, -0.01573592, -0.19411144, 0.13814686], + [0.10028259, 0.02526089, -0.20488358, 0.25667843, 0.17100072, 0.01034015, -0.32994771, 0.53425753, 0.64935833, 0.30769956, -0.26756367, -0.03389005], + [-0.18916467, 0.38340616, -0.16475976, 0.59811211, 0.12739281, -0.16611671, -0.31913927, 0.07577144, 0.28552490, 0.54843456, 0.40937552, 0.38236183], + [-0.20519450, -0.04122134, -0.20013523, 0.42193425, -0.27304563, -0.21811043, 0.13115846, 0.16724831, 0.13073303, 0.20491999, 0.31806493, 0.13444173], + [0.01762132, 0.32608625, 0.19381267, -0.33404192, -0.46299583, -0.28042898, 0.20772585, 0.20139317, 0.41952321, -0.30363685, 0.20015827, -0.03338646], + [0.13760759, 0.07168494, 0.26161709, 0.41468662, -0.03778528, 0.38290465, 0.48780030, 0.39562985, 0.24758396, -0.05975538, -0.22738078, 0.27877593], + [0.07016940, -0.03804595, -0.08812129, 0.19664441, 0.13347355, 0.50309300, 0.26076415, 0.19044210, -0.20414594, 0.64333421, 0.15160090, 0.16449226], + [0.31039700, -0.01906084, 0.25622010, 0.10707659, 0.54883337, 0.19277412, 0.42004701, -0.09319381, 0.19968294, 0.07109389, -0.28979829, 0.12353907], + [0.28500485, 0.01991569, 0.05190456, 0.29366553, 0.01045146, -0.02013574, -0.01796320, 0.13775185, 0.11095868, -0.25678155, 0.10733776, -0.07584792], + [0.12738188, 0.07762879, -0.06429479, 0.39944342, 0.07958066, 0.46697047, -0.10674930, -0.12212183, -0.01540831, 0.08788434, 0.17299946, 0.25846422], + [0.26692817, 0.00930361, 0.24862845, 
0.02167275, -0.09902105, -0.35391217, -0.41734406, 0.44949567, 0.46330830, 0.40603620, 0.08397861, 0.39809385], + [-0.30756459, -0.43368185, -0.00478506, 0.45611116, -0.05069341, 0.21090019, 0.28219289, 0.07687758, 0.54915971, 0.46933413, 0.35599890, 0.17573997], + [-0.19320646, 0.44751191, -0.14140815, 0.00427075, -0.19792002, -0.19400074, 0.19292155, 0.39845818, 0.21028778, -0.10284913, 0.31191504, -0.36995885], ]; pub const B1: [f32; 32] = [ - 1.12135851, 0.64268047, 0.44761124, -0.28471574, 0.70866716, -0.25293177, -0.19119856, 0.39284116, - -0.20628852, -0.29301032, -0.08837436, 0.92048728, 0.91167349, -0.33615190, -0.06016272, 0.79141164, - -0.43257964, 0.48180589, 0.70891160, -0.24290052, 0.83115542, 0.69964927, 0.97887653, 1.34517038, - 1.10292709, 0.42009205, 1.07155228, 0.61349720, 0.46157768, 1.01911950, 0.51159418, 0.60460496 + -0.24357778, -0.24826238, -0.03415382, 0.00968227, 0.51550633, 0.45242083, 0.60654080, 0.25456131, + -0.36509025, -0.22825000, 0.03829522, 0.65561563, -0.19379658, -0.25716159, 0.45115772, 0.73442084, + 0.61352992, 0.59502298, 0.32757106, 0.28512844, 0.26663530, 0.27169749, 0.33571365, -0.34503689, + -0.08054741, -0.06313029, 0.43629149, 0.35936099, 0.39375633, -0.19984132, 0.49092621, -0.27418151 ]; pub const W2: [f32; 32] = [ - 1.55191231, 1.27754235, 0.43588921, 0.10868450, 0.55931729, -1.46911597, -0.54461092, 0.78240824, - -1.25938582, -0.06287600, -1.02053738, 1.07076716, 1.58776867, -0.03168033, -0.11393511, 1.30535436, - -1.46621227, 0.62925971, 0.76781118, -0.74480098, 1.29669034, 0.62078375, 1.64134884, 2.09736991, - 1.52834618, 0.87368065, 1.80090642, 0.89230227, 0.38757962, 1.80718291, 0.64923352, 1.18709576 + -0.04975564, -0.54667705, -0.39323062, 0.72362727, 0.86801738, 1.93621075, 1.01259410, 0.75978750, + -0.67997259, -0.63063931, -0.07149173, 0.81899148, -0.69025612, -0.12359849, 1.09533453, 0.88092262, + 0.89678788, 0.87908030, 1.12460852, 0.76745653, 0.85632098, 0.72992527, 0.93983871, -0.55915666, + 
-0.61104172, -0.56369978, 1.43480921, 0.71174467, 1.03119624, -0.57950914, 0.81188917, -0.78019017 ]; -pub const B2: f32 = 0.23270580; +pub const B2: f32 = 0.24626280; pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [ (3, 0.50000000, 1, 2), diff --git a/src/ensemble/replay.rs b/src/ensemble/replay.rs index 3dd7af0..ec63c0b 100644 --- a/src/ensemble/replay.rs +++ b/src/ensemble/replay.rs @@ -1,6 +1,10 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! Replay audit logs through the ensemble models (scanner + DDoS). -use crate::ddos::audit_log::{self, AuditLog}; +use crate::audit::AuditLogLine; +use crate::ddos::audit_log; use crate::ddos::features::{method_to_u8, LogIpState}; use crate::ddos::model::DDoSAction; use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath}; @@ -26,21 +30,31 @@ pub fn run(args: ReplayEnsembleArgs) -> Result<()> { std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?; let reader = std::io::BufReader::new(file); - // --- Parse all entries --- - let mut entries: Vec = Vec::new(); - let mut parse_errors = 0u64; + // --- Parse all entries, filtering for audit logs only --- + let mut entries: Vec = Vec::new(); + let mut skipped_non_audit = 0u64; + let mut schema_errors = 0u64; for line in reader.lines() { let line = line?; if line.trim().is_empty() { continue; } - match serde_json::from_str::(&line) { - Ok(e) => entries.push(e), - Err(_) => parse_errors += 1, + match AuditLogLine::try_parse(&line) { + Ok(Some(entry)) => entries.push(entry), + Ok(None) => skipped_non_audit += 1, + Err(e) => { + schema_errors += 1; + if schema_errors <= 3 { + eprintln!(" schema error: {e}"); + } + } } } let total = entries.len() as u64; - eprintln!("parsed {} entries ({} parse errors)\n", total, parse_errors); + eprintln!( + "parsed {} audit entries ({} non-audit skipped, {} schema errors)\n", + total, skipped_non_audit, schema_errors, + ); // --- Scanner replay --- eprintln!("═══ Scanner 
Ensemble ═════════════════════════════════════"); @@ -54,7 +68,7 @@ pub fn run(args: ReplayEnsembleArgs) -> Result<()> { Ok(()) } -fn replay_scanner(entries: &[AuditLog]) { +fn replay_scanner(entries: &[AuditLogLine]) { let fragment_hashes: FxHashSet = crate::scanner::train::DEFAULT_FRAGMENTS .iter() .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) @@ -73,23 +87,15 @@ fn replay_scanner(entries: &[AuditLog]) { let mut blocked = 0u64; let mut allowed = 0u64; let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp - let mut blocked_examples: Vec<(String, String, f64)> = Vec::new(); // (path, reason, score) - let mut fp_candidates: Vec<(String, u16, f64)> = Vec::new(); // blocked but had 2xx status + let mut blocked_examples: Vec<(String, String, String, f64)> = Vec::new(); // (path, ua, reason, score) + let mut fp_candidates: Vec<(String, String, u16, f64)> = Vec::new(); // blocked but had 2xx status for e in entries { let f = &e.fields; let host_prefix = f.host.split('.').next().unwrap_or(""); - let has_cookies = f.has_cookies.unwrap_or(false); - let has_referer = f - .referer - .as_ref() - .map(|r| r != "-" && !r.is_empty()) - .unwrap_or(false); - let has_accept_language = f - .accept_language - .as_ref() - .map(|a| a != "-" && !a.is_empty()) - .unwrap_or(false); + let has_cookies = f.has_cookies; + let has_referer = !f.referer.is_empty() && f.referer != "-"; + let has_accept_language = !f.accept_language.is_empty() && f.accept_language != "-"; let feats = features::extract_features_f32( &f.method, @@ -98,7 +104,7 @@ fn replay_scanner(entries: &[AuditLog]) { has_cookies, has_referer, has_accept_language, - "-", + &f.accept, &f.user_agent, f.content_length, &fragment_hashes, @@ -121,12 +127,13 @@ fn replay_scanner(entries: &[AuditLog]) { if blocked_examples.len() < 20 { blocked_examples.push(( f.path.clone(), + f.user_agent.clone(), verdict.reason.to_string(), verdict.score, )); } if (200..400).contains(&f.status) { - 
fp_candidates.push((f.path.clone(), f.status, verdict.score)); + fp_candidates.push((f.path.clone(), f.user_agent.clone(), f.status, verdict.score)); } } ScannerAction::Allow => allowed += 1, @@ -159,8 +166,9 @@ fn replay_scanner(entries: &[AuditLog]) { if !blocked_examples.is_empty() { eprintln!("\n blocked examples (first 20):"); - for (path, reason, score) in &blocked_examples { + for (path, ua, reason, score) in &blocked_examples { eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50)); + eprintln!(" ua: {}", truncate(ua, 72)); } } @@ -170,16 +178,17 @@ fn replay_scanner(entries: &[AuditLog]) { "\n potential false positives (blocked but had 2xx/3xx): {}", fp_count ); - for (path, status, score) in fp_candidates.iter().take(10) { + for (path, ua, status, score) in fp_candidates.iter().take(10) { eprintln!( " {:<50} status={status} score={score:.3}", truncate(path, 50) ); + eprintln!(" ua: {}", truncate(ua, 72)); } } } -fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) { +fn replay_ddos(entries: &[AuditLogLine], window_secs: f64, min_events: usize) { fn fx_hash(s: &str) -> u64 { let mut h = rustc_hash::FxHasher::default(); s.hash(&mut h); @@ -208,18 +217,12 @@ fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) { .push(f.content_length.min(u32::MAX as u64) as u32); state .has_cookies - .push(f.has_cookies.unwrap_or(false)); + .push(f.has_cookies); state.has_referer.push( - f.referer - .as_deref() - .map(|r| r != "-") - .unwrap_or(false), + !f.referer.is_empty() && f.referer != "-", ); state.has_accept_language.push( - f.accept_language - .as_deref() - .map(|a| a != "-") - .unwrap_or(false), + !f.accept_language.is_empty() && f.accept_language != "-", ); state .suspicious_paths diff --git a/src/ensemble/scanner.rs b/src/ensemble/scanner.rs index 6ad7f1d..8828331 100644 --- a/src/ensemble/scanner.rs +++ b/src/ensemble/scanner.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// 
SPDX-License-Identifier: Apache-2.0 + use crate::scanner::model::{ScannerAction, ScannerVerdict}; use super::gen::scanner_weights; use super::mlp::mlp_predict_32; @@ -92,27 +95,11 @@ impl From for ScannerVerdict { mod tests { use super::*; - #[test] - fn test_tree_allow_path() { - // All features at zero → feature 3 (suspicious_ua) = 0.0 <= 0.65 → left (node 1) - // feature 0 (path_depth) = 0.0 <= 0.40 → left (node 3) → Allow leaf - let raw = [0.0f32; 12]; - let v = scanner_ensemble_predict(&raw); - assert_eq!(v.action, ScannerAction::Allow); - assert_eq!(v.path, EnsemblePath::TreeAllow); - assert_eq!(v.reason, "ensemble:tree_allow"); - assert!((v.score - 0.0).abs() < f64::EPSILON); - } - #[test] fn test_tree_block_path() { - // Need: feature 3 (suspicious_ua) > 0.65 (normalized) → right (node 2) - // feature 7 (payload_entropy) > 0.72 (normalized) → right (node 6) → Block - // feature 3 max = 1.0, so raw 0.8 → normalized 0.8 > 0.65 ✓ - // feature 7 max = 8.0, so raw 6.0 → normalized 0.75 > 0.72 ✓ - let mut raw = [0.0f32; 12]; - raw[3] = 0.8; // suspicious_ua: normalized = 0.8/1.0 = 0.8 > 0.65 - raw[7] = 6.0; // payload_entropy: normalized = 6.0/8.0 = 0.75 > 0.72 + // Tree: root splits on feature 7 (ua_category) at 0.75. + // All zeros → ua_category normalized = 0.0 <= 0.75 → Block (node 1) + let raw = [0.0f32; 12]; let v = scanner_ensemble_predict(&raw); assert_eq!(v.action, ScannerAction::Block); assert_eq!(v.path, EnsemblePath::TreeBlock); @@ -120,28 +107,38 @@ mod tests { } #[test] - fn test_mlp_path() { - // Need: feature 3 > 0.65 normalized → right (node 2) - // feature 7 <= 0.72 normalized → left (node 5) → Defer - // Then MLP runs on the normalized input. + fn test_tree_allow_path() { + // Tree: root feature 7 > 0.75 → node 2, checks feature 3 (has_cookies) at 0.25. + // raw[7] = 1.0 → normalized 1.0 > 0.75 → right. + // raw[3] = 1.0 → normalized ~0.7 > 0.25 → right child node 6 → Allow leaf. 
let mut raw = [0.0f32; 12]; - raw[3] = 0.8; // normalized = 0.8 > 0.65 - raw[7] = 4.0; // normalized = 4.0/8.0 = 0.5 <= 0.72 - // Also need feature 2 (query_param_count) to navigate node 5 correctly - // node 5: split on feature 2, threshold 0.55 → left=9(Defer), right=10 - // normalized feature 2 = 0.0/20.0 = 0.0 <= 0.55 → left (node 9) → Defer + raw[7] = 1.0; // ua_category = browser + raw[3] = 1.0; // has_cookies = yes let v = scanner_ensemble_predict(&raw); - assert_eq!(v.path, EnsemblePath::Mlp); - assert_eq!(v.reason, "ensemble:mlp"); - // MLP output is deterministic for these inputs - assert!(v.score >= 0.0 && v.score <= 1.0); + assert_eq!(v.action, ScannerAction::Allow); + assert_eq!(v.path, EnsemblePath::TreeAllow); + assert_eq!(v.reason, "ensemble:tree_allow"); + } + + #[test] + fn test_mlp_direct() { + // Current tree has no Defer leaves, so test MLP inference directly. + let input = [0.5f32; 12]; + let score = mlp_predict_32::<12>( + &scanner_weights::W1, + &scanner_weights::B1, + &scanner_weights::W2, + scanner_weights::B2, + &input, + ); + assert!(score >= 0.0 && score <= 1.0); } #[test] fn test_normalize_clamps() { // Values beyond max should be clamped to 1.0 let mut raw = [0.0f32; 12]; - raw[0] = 100.0; // max is 10.0 + raw[0] = 100.0; let normed = normalize(&raw); assert!((normed[0] - 1.0).abs() < f64::EPSILON as f32); } diff --git a/src/lib.rs b/src/lib.rs index a799ce9..ba4b80d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,12 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + // Library crate root — exports the proxy/config/acme modules so that // integration tests in tests/ can construct and drive a SunbeamProxy // without going through the binary entry point. 
+#![recursion_limit = "256"] pub mod acme; +pub mod audit; pub mod autotune; pub mod cache; pub mod cluster; diff --git a/src/main.rs b/src/main.rs index 04df74d..18f5943 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,12 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + mod cert; mod telemetry; mod watcher; -use sunbeam_proxy::{acme, autotune, config}; +use sunbeam_proxy::{acme, config}; use sunbeam_proxy::proxy::SunbeamProxy; -use sunbeam_proxy::ddos; use sunbeam_proxy::rate_limit; use sunbeam_proxy::scanner; @@ -32,77 +34,18 @@ enum Commands { #[arg(long)] upgrade: bool, }, - /// Replay audit logs through detection models + /// Replay audit logs through ensemble models (scanner + DDoS) Replay { - #[command(subcommand)] - mode: ReplayMode, - }, - /// Train a DDoS detection model from audit logs - TrainDdos { /// Path to audit log JSONL file #[arg(short, long)] input: String, - /// Output model file path - #[arg(short, long)] - output: String, - /// File with known-attack IPs (one per line) - #[arg(long)] - attack_ips: Option, - /// File with known-normal IPs (one per line) - #[arg(long)] - normal_ips: Option, - /// TOML file with heuristic auto-labeling thresholds - #[arg(long)] - heuristics: Option, - /// KNN k parameter - #[arg(long, default_value = "5")] - k: usize, - /// Attack threshold (fraction of k neighbors) - #[arg(long, default_value = "0.6")] - threshold: f64, /// Sliding window size in seconds #[arg(long, default_value = "60")] window_secs: u64, - /// Minimum events per IP to include in training - #[arg(long, default_value = "10")] + /// Minimum events per IP before DDoS classification + #[arg(long, default_value = "5")] min_events: usize, }, - /// Train a per-request scanner detection model from audit logs - TrainScanner { - /// Path to audit log JSONL file - #[arg(short, long)] - input: String, - /// Output model file path - #[arg(short, long, default_value = "scanner_model.bin")] - output: String, - /// 
Directory (or file) containing .txt wordlists of scanner paths - #[arg(long)] - wordlists: Option, - /// Classification threshold - #[arg(long, default_value = "0.5")] - threshold: f64, - /// Include CSIC 2010 dataset as base training data (downloaded from GitHub, cached locally) - #[arg(long)] - csic: bool, - }, - /// Bayesian hyperparameter optimization for DDoS model - AutotuneDdos { - /// Path to audit log JSONL file - #[arg(short, long)] - input: String, - /// Output best model file path - #[arg(short, long, default_value = "ddos_model_best.bin")] - output: String, - /// Number of optimization trials - #[arg(long, default_value = "200")] - trials: usize, - /// F-beta parameter (1.0 = F1, 2.0 = recall-weighted) - #[arg(long, default_value = "1.0")] - beta: f64, - /// JSONL file to log each trial's parameters and results - #[arg(long)] - trial_log: Option, - }, /// Download and cache upstream datasets (CIC-IDS2017) DownloadDatasets, /// Prepare a unified training dataset from multiple sources @@ -125,6 +68,12 @@ enum Commands { /// Path to heuristics.toml for auto-labeling production logs #[arg(long)] heuristics: Option, + /// Inject CSIC 2010 dataset as labeled audit log entries + #[arg(long)] + inject_csic: bool, + /// Inject OWASP ModSec audit log entries (path to .log file) + #[arg(long)] + inject_modsec: Option, }, #[cfg(feature = "training")] /// Train scanner ensemble (decision tree + MLP) from prepared dataset @@ -142,7 +91,7 @@ enum Commands { #[arg(long, default_value = "100")] epochs: usize, /// Learning rate - #[arg(long, default_value = "0.001")] + #[arg(long, default_value = "0.0001")] learning_rate: f64, /// Batch size #[arg(long, default_value = "64")] @@ -153,6 +102,12 @@ enum Commands { /// Min purity for tree leaves (below -> Defer) #[arg(long, default_value = "0.90")] tree_min_purity: f32, + /// Min samples required in a leaf node (higher = less overfitting) + #[arg(long, default_value = "2")] + min_samples_leaf: usize, + /// Weight for 
cookie feature (0.0=ignore, 1.0=full). Controls has_cookies influence. + #[arg(long, default_value = "1.0")] + cookie_weight: f32, }, #[cfg(feature = "training")] /// Train DDoS ensemble (decision tree + MLP) from prepared dataset @@ -165,184 +120,82 @@ enum Commands { hidden_dim: usize, #[arg(long, default_value = "100")] epochs: usize, - #[arg(long, default_value = "0.001")] + #[arg(long, default_value = "0.0001")] learning_rate: f64, #[arg(long, default_value = "64")] batch_size: usize, #[arg(long, default_value = "6")] tree_max_depth: usize, + /// Min purity for tree leaves (below -> Defer) #[arg(long, default_value = "0.90")] tree_min_purity: f32, - }, - /// Bayesian hyperparameter optimization for scanner model - AutotuneScanner { - /// Path to audit log JSONL file - #[arg(short, long)] - input: String, - /// Output best model file path - #[arg(short, long, default_value = "scanner_model_best.bin")] - output: String, - /// Directory (or file) containing .txt wordlists of scanner paths - #[arg(long)] - wordlists: Option, - /// Include CSIC 2010 dataset as base training data - #[arg(long)] - csic: bool, - /// Number of optimization trials - #[arg(long, default_value = "200")] - trials: usize, - /// F-beta parameter (1.0 = F1, 2.0 = recall-weighted) + /// Min samples required in a leaf node (higher = less overfitting) + #[arg(long, default_value = "2")] + min_samples_leaf: usize, + /// Weight for cookie feature (0.0=ignore, 1.0=full). Controls cookie_ratio influence. 
#[arg(long, default_value = "1.0")] - beta: f64, - /// JSONL file to log each trial's parameters and results + cookie_weight: f32, + }, + #[cfg(feature = "training")] + /// Sweep cookie_weight values and report tree structure + validation accuracy for each + SweepCookieWeight { + /// Path to prepared dataset (.bin) + #[arg(short = 'd', long)] + dataset: String, + /// Which detector to sweep: "scanner" or "ddos" + #[arg(long, default_value = "scanner")] + detector: String, + /// Comma-separated cookie_weight values to try (default: 0.0,0.1,0.2,...,1.0) #[arg(long)] - trial_log: Option, + weights: Option, + /// Max tree depth + #[arg(long, default_value = "6")] + tree_max_depth: usize, + /// Min purity for tree leaves + #[arg(long, default_value = "0.90")] + tree_min_purity: f32, + /// Min samples required in a leaf node + #[arg(long, default_value = "2")] + min_samples_leaf: usize, }, } -#[derive(Subcommand)] -enum ReplayMode { - /// Replay through ensemble models (scanner + DDoS) - Ensemble { - /// Path to audit log JSONL file - #[arg(short, long)] - input: String, - /// Sliding window size in seconds - #[arg(long, default_value = "60")] - window_secs: u64, - /// Minimum events per IP before DDoS classification - #[arg(long, default_value = "5")] - min_events: usize, - }, - /// Replay through legacy KNN DDoS detector - Ddos { - /// Path to audit log JSONL file - #[arg(short, long)] - input: String, - /// Path to trained model file - #[arg(short, long, default_value = "ddos_model.bin")] - model: String, - /// Optional config file (for rate limit settings) - #[arg(short, long)] - config: Option, - /// KNN k parameter - #[arg(long, default_value = "5")] - k: usize, - /// Attack threshold - #[arg(long, default_value = "0.6")] - threshold: f64, - /// Sliding window size in seconds - #[arg(long, default_value = "60")] - window_secs: u64, - /// Minimum events per IP before classification - #[arg(long, default_value = "10")] - min_events: usize, - /// Also run rate limiter 
during replay - #[arg(long)] - rate_limit: bool, - }, -} fn main() -> Result<()> { let cli = Cli::parse(); match cli.command.unwrap_or(Commands::Serve { upgrade: false }) { Commands::Serve { upgrade } => run_serve(upgrade), - Commands::Replay { mode } => match mode { - ReplayMode::Ensemble { input, window_secs, min_events } => { - sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs { - input, window_secs, min_events, - }) - } - ReplayMode::Ddos { input, model, config, k, threshold, window_secs, min_events, rate_limit } => { - ddos::replay::run(ddos::replay::ReplayArgs { - input, model_path: model, config_path: config, k, threshold, window_secs, min_events, rate_limit, - }) - } + Commands::Replay { input, window_secs, min_events } => { + sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs { + input, window_secs, min_events, + }) }, - Commands::TrainDdos { - input, - output, - attack_ips, - normal_ips, - heuristics, - k, - threshold, - window_secs, - min_events, - } => ddos::train::run(ddos::train::TrainArgs { - input, - output, - attack_ips, - normal_ips, - heuristics, - k, - threshold, - window_secs, - min_events, - }), - Commands::TrainScanner { - input, - output, - wordlists, - threshold, - csic, - } => scanner::train::run(scanner::train::TrainScannerArgs { - input, - output, - wordlists, - threshold, - csic, - }), Commands::DownloadDatasets => { sunbeam_proxy::dataset::download::download_all() }, - Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics } => { + Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics, inject_csic, inject_modsec } => { sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs { - input, owasp, wordlists, output, seed, heuristics, + input, owasp, wordlists, output, seed, heuristics, inject_csic, inject_modsec, }) }, #[cfg(feature = "training")] - Commands::TrainMlpScanner { dataset, 
output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => { + Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight } => { sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs { - dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, + dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight, }) }, #[cfg(feature = "training")] - Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => { + Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight } => { sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs { - dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, + dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight, }) }, - Commands::AutotuneDdos { - input, - output, - trials, - beta, - trial_log, - } => autotune::ddos::run_autotune(autotune::ddos::AutotuneDdosArgs { - input, - output, - trials, - beta, - trial_log, - }), - Commands::AutotuneScanner { - input, - output, - wordlists, - csic, - trials, - beta, - trial_log, - } => autotune::scanner::run_autotune(autotune::scanner::AutotuneScannerArgs { - input, - output, - wordlists, - csic, - trials, - beta, - trial_log, - }), + #[cfg(feature = "training")] + Commands::SweepCookieWeight { dataset, detector, weights, tree_max_depth, tree_min_purity, min_samples_leaf } => { + 
sunbeam_proxy::training::sweep::run_cookie_sweep( + &dataset, &detector, weights.as_deref(), tree_max_depth, tree_min_purity, min_samples_leaf, + ) + }, } } @@ -363,46 +216,19 @@ fn run_serve(upgrade: bool) -> Result<()> { // 1b. Spawn metrics HTTP server (needs a tokio runtime for the TCP listener). let metrics_port = cfg.telemetry.metrics_port; - // 2. Load DDoS detection model if configured. + // 2. Init DDoS detector if configured (ensemble: compiled-in weights). let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos { if ddos_cfg.enabled { - if ddos_cfg.use_ensemble { - // Ensemble path: compiled-in weights, no model file needed. - // We still need a TrainedModel for the struct, but it won't be used. - let dummy_model = ddos::model::TrainedModel::empty(ddos_cfg.k, ddos_cfg.threshold); - let detector = Arc::new(ddos::detector::DDoSDetector::new_ensemble(dummy_model, ddos_cfg)); - tracing::info!( - k = ddos_cfg.k, - threshold = ddos_cfg.threshold, - "DDoS ensemble detector enabled" - ); - Some(detector) - } else if let Some(ref model_path) = ddos_cfg.model_path { - match ddos::model::TrainedModel::load( - std::path::Path::new(model_path), - Some(ddos_cfg.k), - Some(ddos_cfg.threshold), - ) { - Ok(model) => { - let point_count = model.point_count(); - let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg)); - tracing::info!( - points = point_count, - k = ddos_cfg.k, - threshold = ddos_cfg.threshold, - "DDoS detector loaded" - ); - Some(detector) - } - Err(e) => { - tracing::warn!(error = %e, "failed to load DDoS model; detection disabled"); - None - } - } - } else { - tracing::warn!("DDoS enabled but no model_path and use_ensemble=false; detection disabled"); - None + let detector = Arc::new(sunbeam_proxy::ddos::detector::DDoSDetector::new(ddos_cfg)); + tracing::info!( + threshold = ddos_cfg.threshold, + observe_only = ddos_cfg.observe_only, + "DDoS ensemble detector enabled" + ); + if ddos_cfg.observe_only { + tracing::warn!("DDoS detector in 
OBSERVE-ONLY mode — decisions are logged but traffic is never blocked"); } + Some(detector) } else { None } @@ -435,88 +261,35 @@ fn run_serve(upgrade: bool) -> Result<()> { None }; - // 2c. Load scanner model if configured. + // 2c. Init scanner detector if configured (ensemble: compiled-in weights). let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner { if scanner_cfg.enabled { - if scanner_cfg.use_ensemble { - // Ensemble path: compiled-in weights, no model file needed. - let detector = scanner::detector::ScannerDetector::new_ensemble(&cfg.routes); - let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector)); + let detector = scanner::detector::ScannerDetector::new(&cfg.routes); + let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector)); - // Start bot allowlist if rules are configured. - let bot_allowlist = if !scanner_cfg.allowlist.is_empty() { - let al = scanner::allowlist::BotAllowlist::spawn( - &scanner_cfg.allowlist, - scanner_cfg.bot_cache_ttl_secs, - ); - tracing::info!( - rules = scanner_cfg.allowlist.len(), - "bot allowlist enabled" - ); - Some(al) - } else { - None - }; - - tracing::info!( - threshold = scanner_cfg.threshold, - "scanner ensemble detector enabled" + let bot_allowlist = if !scanner_cfg.allowlist.is_empty() { + let al = scanner::allowlist::BotAllowlist::spawn( + &scanner_cfg.allowlist, + scanner_cfg.bot_cache_ttl_secs, ); - (Some(handle), bot_allowlist) - } else if let Some(ref model_path) = scanner_cfg.model_path { - match scanner::model::ScannerModel::load(std::path::Path::new(model_path)) { - Ok(mut model) => { - let fragment_count = model.fragments.len(); - model.threshold = scanner_cfg.threshold; - let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes); - let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector)); - - // Start bot allowlist if rules are configured. 
- let bot_allowlist = if !scanner_cfg.allowlist.is_empty() { - let al = scanner::allowlist::BotAllowlist::spawn( - &scanner_cfg.allowlist, - scanner_cfg.bot_cache_ttl_secs, - ); - tracing::info!( - rules = scanner_cfg.allowlist.len(), - "bot allowlist enabled" - ); - Some(al) - } else { - None - }; - - // Start background file watcher for hot-reload. - if scanner_cfg.poll_interval_secs > 0 { - let watcher_handle = handle.clone(); - let watcher_model_path = std::path::PathBuf::from(model_path); - let threshold = scanner_cfg.threshold; - let routes = cfg.routes.clone(); - let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs); - std::thread::spawn(move || { - scanner::watcher::watch_scanner_model( - watcher_handle, watcher_model_path, threshold, routes, interval, - ); - }); - } - - tracing::info!( - fragments = fragment_count, - threshold = scanner_cfg.threshold, - poll_interval_secs = scanner_cfg.poll_interval_secs, - "scanner detector loaded" - ); - (Some(handle), bot_allowlist) - } - Err(e) => { - tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled"); - (None, None) - } - } + tracing::info!( + rules = scanner_cfg.allowlist.len(), + "bot allowlist enabled" + ); + Some(al) } else { - tracing::warn!("scanner enabled but no model_path and use_ensemble=false; scanner detection disabled"); - (None, None) + None + }; + + tracing::info!( + threshold = scanner_cfg.threshold, + observe_only = scanner_cfg.observe_only, + "scanner ensemble detector enabled" + ); + if scanner_cfg.observe_only { + tracing::warn!("scanner detector in OBSERVE-ONLY mode — decisions are logged but traffic is never blocked"); } + (Some(handle), bot_allowlist) } else { (None, None) } @@ -617,6 +390,8 @@ fn run_serve(upgrade: bool) -> Result<()> { &cfg.rate_limit.as_ref().map(|rl| rl.bypass_cidrs.clone()).unwrap_or_default(), ), cluster: cluster_handle, + ddos_observe_only: cfg.ddos.as_ref().map(|d| d.observe_only).unwrap_or(false), + 
scanner_observe_only: cfg.scanner.as_ref().map(|s| s.observe_only).unwrap_or(false), }; let mut svc = http_proxy_service(&server.configuration, proxy); diff --git a/src/proxy.rs b/src/proxy.rs index b515a16..0511b75 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use crate::acme::AcmeRoutes; use crate::cluster::ClusterHandle; use crate::config::RouteConfig; @@ -32,9 +35,9 @@ pub struct SunbeamProxy { pub routes: Vec, /// Per-challenge route table populated by the Ingress watcher. pub acme_routes: AcmeRoutes, - /// Optional KNN-based DDoS detector. + /// Optional DDoS detector (ensemble: decision tree + MLP). pub ddos_detector: Option>, - /// Optional per-request scanner detector (hot-reloadable via ArcSwap). + /// Optional per-request scanner detector (ensemble: decision tree + MLP). pub scanner_detector: Option>>, /// Optional verified-bot allowlist (bypasses scanner for known crawlers/agents). pub bot_allowlist: Option>, @@ -48,6 +51,10 @@ pub struct SunbeamProxy { pub pipeline_bypass_cidrs: Vec, /// Optional cluster handle for multi-node bandwidth tracking. pub cluster: Option>, + /// When true, DDoS detector logs decisions but never blocks traffic. + pub ddos_observe_only: bool, + /// When true, scanner detector logs decisions but never blocks traffic. 
+ pub scanner_observe_only: bool, } pub struct RequestCtx { @@ -341,7 +348,7 @@ impl ProxyHttp for SunbeamProxy { metrics::DDOS_DECISIONS.with_label_values(&[decision]).inc(); - if matches!(ddos_action, DDoSAction::Block) { + if matches!(ddos_action, DDoSAction::Block) && !self.ddos_observe_only { let mut resp = ResponseHeader::build(429, None)?; resp.insert_header("Retry-After", "60")?; resp.insert_header("Content-Length", "0")?; @@ -426,7 +433,7 @@ impl ProxyHttp for SunbeamProxy { .with_label_values(&[decision, reason]) .inc(); - if decision == "block" { + if decision == "block" && !self.scanner_observe_only { let mut resp = ResponseHeader::build(403, None)?; resp.insert_header("Content-Length", "0")?; session.write_response_header(Box::new(resp), true).await?; @@ -1150,6 +1157,21 @@ impl ProxyHttp for SunbeamProxy { .and_then(|v| v.to_str().ok()) .unwrap_or("-"); let query = session.req_header().uri.query().unwrap_or(""); + let response_bytes = session.body_bytes_sent(); + let http_version = format!("{:?}", session.req_header().version); + let header_count = session.req_header().headers.len() as u16; + let accept_encoding = session + .req_header() + .headers + .get("accept-encoding") + .and_then(|v| v.to_str().ok()) + .unwrap_or("-"); + let connection = session + .req_header() + .headers + .get("connection") + .and_then(|v| v.to_str().ok()) + .unwrap_or("-"); tracing::info!( target = "audit", @@ -1162,14 +1184,19 @@ impl ProxyHttp for SunbeamProxy { status, duration_ms, content_length, + response_bytes, user_agent, referer, accept_language, accept, + accept_encoding, has_cookies, cf_country, backend, error = error_str, + http_version, + header_count, + connection, "request" ); diff --git a/src/scanner/csic.rs b/src/scanner/csic.rs index 314a7fb..ad5c780 100644 --- a/src/scanner/csic.rs +++ b/src/scanner/csic.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! 
Fetch and convert the CSIC 2010 HTTP dataset into labeled training samples. //! //! The CSIC 2010 dataset contains raw HTTP/1.1 requests (normal + anomalous) @@ -65,6 +68,7 @@ struct ParsedRequest { content_length: u64, referer: String, accept_language: String, + accept: String, } fn parse_csic_content(content: &str) -> Vec { @@ -158,6 +162,7 @@ fn parse_single_request(lines: &[&str]) -> Option { content_length, referer: get_header("Referer").unwrap_or("-").to_string(), accept_language: get_header("Accept-Language").unwrap_or("-").to_string(), + accept: get_header("Accept").unwrap_or("-").to_string(), }) } @@ -219,11 +224,12 @@ fn to_audit_fields( // For anomalous samples, simulate real scanner behavior: // strip cookies/referer/accept-language that CSIC attacks have from their session. let (has_cookies, referer, accept_language, user_agent) = if label != "normal" { - let referer = None; + let referer = "-".to_string(); let accept_language = if rng.next_f64() < 0.8 { - None + "-".to_string() } else { - Some(req.accept_language.clone()).filter(|a| a != "-") + let al = req.accept_language.clone(); + if al == "-" { "-".to_string() } else { al } }; let r = rng.next_f64(); let user_agent = if r < 0.15 { @@ -241,12 +247,26 @@ fn to_audit_fields( } else { ( req.has_cookies, - Some(req.referer.clone()).filter(|r| r != "-"), - Some(req.accept_language.clone()).filter(|a| a != "-"), + if req.referer == "-" { "-".to_string() } else { req.referer.clone() }, + if req.accept_language == "-" { "-".to_string() } else { req.accept_language.clone() }, req.user_agent.clone(), ) }; + // For normal traffic, preserve Accept header from CSIC request. + // For attacks, degrade it to simulate scanner behavior. 
+ let accept = if label == "normal" { + if req.accept == "-" || req.accept.is_empty() { + "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8".to_string() + } else { + req.accept.clone() + } + } else if rng.next_f64() < 0.6 { + "*/*".to_string() + } else { + req.accept.clone() + }; + AuditFields { method: req.method.clone(), host, @@ -263,9 +283,10 @@ fn to_audit_fields( duration_ms: rng.next_usize(50) as u64 + 1, content_length: req.content_length, user_agent, - has_cookies: Some(has_cookies), + has_cookies, referer, accept_language, + accept, backend: if label == "normal" { format!("{host_prefix}-svc:8080") } else { @@ -274,6 +295,7 @@ fn to_audit_fields( label: Some( if label == "normal" { "normal" } else { "attack" }.to_string(), ), + ..AuditFields::default() } } @@ -343,6 +365,7 @@ mod tests { assert_eq!(req.path, "/index.html"); assert!(req.has_cookies); assert_eq!(req.user_agent, "Mozilla/5.0"); + assert_eq!(req.accept, "text/html"); } #[test] @@ -374,11 +397,12 @@ mod tests { content_length: 100, referer: "https://example.com".to_string(), accept_language: "en-US".to_string(), + accept: "text/html".to_string(), }; let mut rng = Rng::new(42); let fields = to_audit_fields(&req, "normal", DEFAULT_HOSTS, &mut rng); assert_eq!(fields.label.as_deref(), Some("normal")); - assert!(fields.has_cookies.unwrap_or(false)); + assert!(fields.has_cookies); assert!(fields.host.ends_with(".sunbeam.pt")); } @@ -393,11 +417,12 @@ mod tests { content_length: 0, referer: "https://example.com".to_string(), accept_language: "en-US".to_string(), + accept: "text/html".to_string(), }; let mut rng = Rng::new(42); let fields = to_audit_fields(&req, "anomalous", DEFAULT_HOSTS, &mut rng); assert_eq!(fields.label.as_deref(), Some("attack")); - assert!(!fields.has_cookies.unwrap_or(true)); + assert!(!fields.has_cookies); } #[test] diff --git a/src/scanner/detector.rs b/src/scanner/detector.rs index aeb3baa..2d7ac67 100644 --- a/src/scanner/detector.rs +++ 
b/src/scanner/detector.rs @@ -1,9 +1,9 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use crate::config::RouteConfig; -use crate::scanner::features::{ - self, fx_hash_bytes, ScannerNormParams, SUSPICIOUS_EXTENSIONS_LIST, NUM_SCANNER_FEATURES, - NUM_SCANNER_WEIGHTS, -}; -use crate::scanner::model::{ScannerAction, ScannerModel, ScannerVerdict}; +use crate::scanner::features::{self, fx_hash_bytes, SUSPICIOUS_EXTENSIONS_LIST}; +use crate::scanner::model::{ScannerAction, ScannerVerdict}; use rustc_hash::FxHashSet; /// Immutable, zero-state per-request scanner detector. @@ -12,44 +12,10 @@ pub struct ScannerDetector { fragment_hashes: FxHashSet, extension_hashes: FxHashSet, configured_hosts: FxHashSet, - weights: [f64; NUM_SCANNER_WEIGHTS], - threshold: f64, - norm_params: ScannerNormParams, - use_ensemble: bool, } impl ScannerDetector { - pub fn new(model: &ScannerModel, routes: &[RouteConfig]) -> Self { - let fragment_hashes: FxHashSet = model - .fragments - .iter() - .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) - .collect(); - - let extension_hashes: FxHashSet = SUSPICIOUS_EXTENSIONS_LIST - .iter() - .map(|e| fx_hash_bytes(e.as_bytes())) - .collect(); - - let configured_hosts: FxHashSet = routes - .iter() - .map(|r| fx_hash_bytes(r.host_prefix.as_bytes())) - .collect(); - - Self { - fragment_hashes, - extension_hashes, - configured_hosts, - weights: model.weights, - threshold: model.threshold, - norm_params: model.norm_params.clone(), - use_ensemble: false, - } - } - - /// Create a detector that uses the ensemble (decision tree + MLP) path - /// instead of the linear model. No model file needed — weights are compiled in. 
- pub fn new_ensemble(routes: &[RouteConfig]) -> Self { + pub fn new(routes: &[RouteConfig]) -> Self { let fragment_hashes: FxHashSet = crate::scanner::train::DEFAULT_FRAGMENTS .iter() .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) @@ -69,13 +35,6 @@ impl ScannerDetector { fragment_hashes, extension_hashes, configured_hosts, - weights: [0.0; NUM_SCANNER_WEIGHTS], - threshold: 0.5, - norm_params: ScannerNormParams { - mins: [0.0; NUM_SCANNER_FEATURES], - maxs: [1.0; NUM_SCANNER_FEATURES], - }, - use_ensemble: true, } } @@ -98,8 +57,6 @@ impl ScannerDetector { content_length: u64, ) -> ScannerVerdict { // Hard allowlist: obviously legitimate traffic bypasses the model. - // This prevents model drift from ever blocking real users and ensures - // the training pipeline always has clean positive labels. let host_known = { let hash = features::fx_hash_bytes(host_prefix.as_bytes()); self.configured_hosts.contains(&hash) @@ -121,95 +78,32 @@ impl ScannerDetector { }; } - if self.use_ensemble { - // Ensemble path: extract f32 features → decision tree + MLP. - let raw_f32 = features::extract_features_f32( - method, path, host_prefix, - has_cookies, has_referer, has_accept_language, - accept, user_agent, content_length, - &self.fragment_hashes, &self.extension_hashes, &self.configured_hosts, - ); - let ev = crate::ensemble::scanner::scanner_ensemble_predict(&raw_f32); - crate::metrics::SCANNER_ENSEMBLE_PATH - .with_label_values(&[match ev.path { - crate::ensemble::scanner::EnsemblePath::TreeBlock => "tree_block", - crate::ensemble::scanner::EnsemblePath::TreeAllow => "tree_allow", - crate::ensemble::scanner::EnsemblePath::Mlp => "mlp", - }]) - .inc(); - return ev.into(); - } - - // 1. 
Extract 12 features - let raw = features::extract_features( - method, - path, - host_prefix, - has_cookies, - has_referer, - has_accept_language, - accept, - user_agent, - content_length, - &self.fragment_hashes, - &self.extension_hashes, - &self.configured_hosts, + // Ensemble path: extract f32 features → decision tree + MLP. + let raw_f32 = features::extract_features_f32( + method, path, host_prefix, + has_cookies, has_referer, has_accept_language, + accept, user_agent, content_length, + &self.fragment_hashes, &self.extension_hashes, &self.configured_hosts, ); - - // 2. Normalize - let f = self.norm_params.normalize(&raw); - - // 3. Compute score = bias + dot(weights, features) + interaction terms - let mut score = self.weights[NUM_SCANNER_FEATURES + 2]; // bias (index 14) - for (i, &fi) in f.iter().enumerate().take(NUM_SCANNER_FEATURES) { - score += self.weights[i] * fi; - } - // Interaction: suspicious_path AND no_cookies - score += self.weights[12] * f[0] * (1.0 - f[3]); - // Interaction: unknown_host AND no_accept_language - score += self.weights[13] * (1.0 - f[9]) * (1.0 - f[5]); - - // 4. 
Threshold - let action = if score > self.threshold { - ScannerAction::Block - } else { - ScannerAction::Allow - }; - - ScannerVerdict { - action, - score, - reason: "model", - } + let ev = crate::ensemble::scanner::scanner_ensemble_predict(&raw_f32); + crate::metrics::SCANNER_ENSEMBLE_PATH + .with_label_values(&[match ev.path { + crate::ensemble::scanner::EnsemblePath::TreeBlock => "tree_block", + crate::ensemble::scanner::EnsemblePath::TreeAllow => "tree_allow", + crate::ensemble::scanner::EnsemblePath::Mlp => "mlp", + }]) + .inc(); + ev.into() } } #[cfg(test)] mod tests { use super::*; - use crate::scanner::features::NUM_SCANNER_FEATURES; + use crate::config::RouteConfig; - fn make_detector(weights: [f64; NUM_SCANNER_WEIGHTS], threshold: f64) -> ScannerDetector { - let model = ScannerModel { - weights, - threshold, - norm_params: ScannerNormParams { - mins: [0.0; NUM_SCANNER_FEATURES], - maxs: [1.0; NUM_SCANNER_FEATURES], - }, - fragments: vec![ - ".env".into(), - "wp-admin".into(), - "wp-login".into(), - "phpinfo".into(), - "phpmyadmin".into(), - ".git".into(), - "cgi-bin".into(), - ".htaccess".into(), - ".htpasswd".into(), - ], - }; - let routes = vec![RouteConfig { + fn test_routes() -> Vec { + vec![RouteConfig { host_prefix: "app".into(), backend: "http://127.0.0.1:8080".into(), websocket: false, @@ -221,35 +115,12 @@ mod tests { body_rewrites: vec![], response_headers: vec![], cache: None, - }]; - ScannerDetector::new(&model, &routes) - } - - /// Weights tuned to block scanner-like requests: - /// High weight on suspicious_path (w[0]), no_cookies interaction (w[12]), - /// has_suspicious_extension (w[2]), traversal (w[11]). - /// Negative weight on has_cookies (w[3]), has_referer (w[4]), - /// accept_quality (w[6]), ua_category (w[7]), host_is_configured (w[9]). 
- fn attack_tuned_weights() -> [f64; NUM_SCANNER_WEIGHTS] { - let mut w = [0.0; NUM_SCANNER_WEIGHTS]; - w[0] = 2.0; // suspicious_path_score - w[2] = 2.0; // has_suspicious_extension - w[3] = -2.0; // has_cookies (negative = good) - w[4] = -1.0; // has_referer (negative = good) - w[5] = -1.0; // has_accept_language (negative = good) - w[6] = -0.5; // accept_quality (negative = good) - w[7] = -1.0; // ua_category (negative = browser is good) - w[9] = -1.5; // host_is_configured (negative = known host is good) - w[11] = 2.0; // path_has_traversal - w[12] = 1.5; // interaction: suspicious_path AND no_cookies - w[13] = 1.0; // interaction: unknown_host AND no_accept_lang - w[14] = 0.5; // bias - w + }] } #[test] fn test_normal_browser_request_allowed() { - let detector = make_detector(attack_tuned_weights(), 0.5); + let detector = ScannerDetector::new(&test_routes()); let verdict = detector.check( "GET", "/blog/hello-world", @@ -267,7 +138,7 @@ mod tests { #[test] fn test_api_client_with_auth_allowed() { - let detector = make_detector(attack_tuned_weights(), 0.5); + let detector = ScannerDetector::new(&test_routes()); let verdict = detector.check( "POST", "/api/v1/data", @@ -285,81 +156,24 @@ mod tests { #[test] fn test_env_probe_blocked() { - let detector = make_detector(attack_tuned_weights(), 0.5); + let detector = ScannerDetector::new(&test_routes()); let verdict = detector.check( "GET", "/.env", "unknown", - false, // no cookies - false, // no referer - false, // no accept-language + false, + false, + false, "*/*", "curl/7.0", 0, ); assert_eq!(verdict.action, ScannerAction::Block); - assert_eq!(verdict.reason, "model"); - } - - #[test] - fn test_wordpress_scan_blocked() { - let detector = make_detector(attack_tuned_weights(), 0.5); - let verdict = detector.check( - "GET", - "/wp-admin/install.php", - "unknown", - false, - false, - false, - "*/*", - "", - 0, - ); - assert_eq!(verdict.action, ScannerAction::Block); - assert_eq!(verdict.reason, "model"); - } - - 
#[test] - fn test_path_traversal_blocked() { - let detector = make_detector(attack_tuned_weights(), 0.5); - let verdict = detector.check( - "GET", - "/etc/../../../passwd", - "unknown", - false, - false, - false, - "*/*", - "python-requests/2.28", - 0, - ); - assert_eq!(verdict.action, ScannerAction::Block); - assert_eq!(verdict.reason, "model"); - } - - #[test] - fn test_legitimate_php_path_allowed() { - let detector = make_detector(attack_tuned_weights(), 0.5); - // "/blog/php-is-dead" — "php-is-dead" is not a known fragment - // has_cookies=true + known host "app" → hits allowlist - let verdict = detector.check( - "GET", - "/blog/php-is-dead", - "app", - true, - true, - true, - "text/html", - "Mozilla/5.0 Chrome/120", - 0, - ); - assert_eq!(verdict.action, ScannerAction::Allow); } #[test] fn test_allowlist_browser_on_known_host() { - let detector = make_detector(attack_tuned_weights(), 0.5); - // No cookies but browser UA + accept-language + known host → allowlist + let detector = ScannerDetector::new(&test_routes()); let verdict = detector.check( "GET", "/", @@ -374,22 +188,4 @@ mod tests { assert_eq!(verdict.action, ScannerAction::Allow); assert_eq!(verdict.reason, "allowlist:host+browser"); } - - #[test] - fn test_model_path_for_non_allowlisted() { - let detector = make_detector(attack_tuned_weights(), 0.5); - // Unknown host, no cookies, curl UA → goes through model - let verdict = detector.check( - "GET", - "/robots.txt", - "unknown", - false, - false, - false, - "*/*", - "curl/7.0", - 0, - ); - assert_eq!(verdict.reason, "model"); - } } diff --git a/src/scanner/mod.rs b/src/scanner/mod.rs index bfe6192..8778b44 100644 --- a/src/scanner/mod.rs +++ b/src/scanner/mod.rs @@ -1,7 +1,9 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + pub mod allowlist; pub mod csic; pub mod detector; pub mod features; pub mod model; pub mod train; -pub mod watcher; diff --git a/src/scanner/model.rs b/src/scanner/model.rs index cdd2bab..9507536 
100644 --- a/src/scanner/model.rs +++ b/src/scanner/model.rs @@ -1,7 +1,5 @@ -use crate::scanner::features::{ScannerNormParams, NUM_SCANNER_WEIGHTS}; -use anyhow::{Context, Result}; -use serde::{Deserialize, Serialize}; -use std::path::Path; +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ScannerAction { @@ -16,74 +14,3 @@ pub struct ScannerVerdict { /// Why this decision was made: "model", "allowlist", etc. pub reason: &'static str, } - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ScannerModel { - pub weights: [f64; NUM_SCANNER_WEIGHTS], - pub threshold: f64, - pub norm_params: ScannerNormParams, - /// Suspicious path fragments used during training — kept for reproducibility. - pub fragments: Vec, -} - -impl ScannerModel { - pub fn save(&self, path: &Path) -> Result<()> { - let data = bincode::serialize(self).context("serializing scanner model")?; - std::fs::write(path, data) - .with_context(|| format!("writing scanner model to {}", path.display()))?; - Ok(()) - } - - pub fn load(path: &Path) -> Result { - let data = std::fs::read(path) - .with_context(|| format!("reading scanner model from {}", path.display()))?; - bincode::deserialize(&data).context("deserializing scanner model") - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::scanner::features::NUM_SCANNER_FEATURES; - - #[test] - fn test_serialization_roundtrip() { - let model = ScannerModel { - weights: [0.1; NUM_SCANNER_WEIGHTS], - threshold: 0.5, - norm_params: ScannerNormParams { - mins: [0.0; NUM_SCANNER_FEATURES], - maxs: [1.0; NUM_SCANNER_FEATURES], - }, - fragments: vec![".env".into(), "wp-admin".into()], - }; - let data = bincode::serialize(&model).unwrap(); - let loaded: ScannerModel = bincode::deserialize(&data).unwrap(); - assert_eq!(loaded.weights, model.weights); - assert_eq!(loaded.threshold, model.threshold); - assert_eq!(loaded.fragments, model.fragments); - } - - #[test] - fn 
test_save_load_file() { - let dir = std::env::temp_dir().join("scanner_model_test"); - std::fs::create_dir_all(&dir).unwrap(); - let path = dir.join("test_model.bin"); - - let model = ScannerModel { - weights: [0.5; NUM_SCANNER_WEIGHTS], - threshold: 0.42, - norm_params: ScannerNormParams { - mins: [0.0; NUM_SCANNER_FEATURES], - maxs: [1.0; NUM_SCANNER_FEATURES], - }, - fragments: vec!["phpinfo".into()], - }; - model.save(&path).unwrap(); - let loaded = ScannerModel::load(&path).unwrap(); - assert_eq!(loaded.threshold, 0.42); - assert_eq!(loaded.fragments, vec!["phpinfo"]); - - let _ = std::fs::remove_dir_all(&dir); - } -} diff --git a/src/scanner/train.rs b/src/scanner/train.rs index 0dc770b..81307bb 100644 --- a/src/scanner/train.rs +++ b/src/scanner/train.rs @@ -1,9 +1,30 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use crate::ddos::audit_log::{AuditLog, AuditFields}; use crate::scanner::features::{ self, fx_hash_bytes, ScannerFeatureVector, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS, }; -use crate::scanner::model::ScannerModel; +use serde::{Deserialize, Serialize}; + +/// Legacy linear scanner model — kept for the `train-scanner` CLI command. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScannerModel { + pub weights: [f64; NUM_SCANNER_WEIGHTS], + pub threshold: f64, + pub norm_params: ScannerNormParams, + pub fragments: Vec, +} + +impl ScannerModel { + pub fn save(&self, path: &Path) -> Result<()> { + let data = bincode::serialize(self).context("serializing scanner model")?; + std::fs::write(path, data) + .with_context(|| format!("writing scanner model to {}", path.display()))?; + Ok(()) + } +} use anyhow::{Context, Result}; use rustc_hash::FxHashSet; use std::io::BufRead; @@ -88,17 +109,9 @@ pub fn train_and_evaluate( } for (fields, host_prefix) in &parsed_entries { - let has_cookies = fields.has_cookies.unwrap_or(false); - let has_referer = fields - .referer - .as_ref() - .map(|r| r != "-" && !r.is_empty()) - .unwrap_or(false); - let has_accept_language = fields - .accept_language - .as_ref() - .map(|a| a != "-" && !a.is_empty()) - .unwrap_or(false); + let has_cookies = fields.has_cookies; + let has_referer = !fields.referer.is_empty() && fields.referer != "-"; + let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-"; let feats = features::extract_features( &fields.method, @@ -149,17 +162,9 @@ pub fn train_and_evaluate( log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes())); } for (fields, host_prefix) in &csic_entries { - let has_cookies = fields.has_cookies.unwrap_or(false); - let has_referer = fields - .referer - .as_ref() - .map(|r| r != "-" && !r.is_empty()) - .unwrap_or(false); - let has_accept_language = fields - .accept_language - .as_ref() - .map(|a| a != "-" && !a.is_empty()) - .unwrap_or(false); + let has_cookies = fields.has_cookies; + let has_referer = !fields.referer.is_empty() && fields.referer != "-"; + let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-"; let feats = features::extract_features( &fields.method, @@ -288,17 +293,9 @@ pub fn run(args: TrainScannerArgs) -> Result<()> { } for 
(fields, host_prefix) in &parsed_entries { - let has_cookies = fields.has_cookies.unwrap_or(false); - let has_referer = fields - .referer - .as_ref() - .map(|r| r != "-" && !r.is_empty()) - .unwrap_or(false); - let has_accept_language = fields - .accept_language - .as_ref() - .map(|a| a != "-" && !a.is_empty()) - .unwrap_or(false); + let has_cookies = fields.has_cookies; + let has_referer = !fields.referer.is_empty() && fields.referer != "-"; + let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-"; let feats = features::extract_features( &fields.method, @@ -352,17 +349,9 @@ pub fn run(args: TrainScannerArgs) -> Result<()> { log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes())); } for (fields, host_prefix) in &csic_entries { - let has_cookies = fields.has_cookies.unwrap_or(false); - let has_referer = fields - .referer - .as_ref() - .map(|r| r != "-" && !r.is_empty()) - .unwrap_or(false); - let has_accept_language = fields - .accept_language - .as_ref() - .map(|a| a != "-" && !a.is_empty()) - .unwrap_or(false); + let has_cookies = fields.has_cookies; + let has_referer = !fields.referer.is_empty() && fields.referer != "-"; + let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-"; let feats = features::extract_features( &fields.method, diff --git a/src/scanner/watcher.rs b/src/scanner/watcher.rs deleted file mode 100644 index 5660f99..0000000 --- a/src/scanner/watcher.rs +++ /dev/null @@ -1,51 +0,0 @@ -use crate::config::RouteConfig; -use crate::scanner::detector::ScannerDetector; -use crate::scanner::model::ScannerModel; -use arc_swap::ArcSwap; -use std::path::PathBuf; -use std::sync::Arc; -use std::time::Duration; - -/// Poll the scanner model file for mtime changes and hot-swap the detector. -/// Runs forever on a dedicated OS thread — never returns. 
-pub fn watch_scanner_model( - handle: Arc>, - model_path: PathBuf, - threshold: f64, - routes: Vec, - poll_interval: Duration, -) { - let mut last_mtime = std::fs::metadata(&model_path) - .and_then(|m| m.modified()) - .ok(); - - loop { - std::thread::sleep(poll_interval); - - let current_mtime = match std::fs::metadata(&model_path).and_then(|m| m.modified()) { - Ok(t) => t, - Err(_) => continue, - }; - - if Some(current_mtime) == last_mtime { - continue; - } - - match ScannerModel::load(&model_path) { - Ok(mut model) => { - model.threshold = threshold; - let fragment_count = model.fragments.len(); - let detector = ScannerDetector::new(&model, &routes); - handle.store(Arc::new(detector)); - last_mtime = Some(current_mtime); - tracing::info!( - fragments = fragment_count, - "scanner model hot-reloaded" - ); - } - Err(e) => { - tracing::warn!(error = %e, "failed to reload scanner model; keeping current"); - } - } - } -} diff --git a/src/training/batch.rs b/src/training/batch.rs new file mode 100644 index 0000000..830234e --- /dev/null +++ b/src/training/batch.rs @@ -0,0 +1,112 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + +//! Shared training infrastructure: Dataset adapter, Batcher, and batch types +//! for use with burn's SupervisedTraining. + +use crate::dataset::sample::TrainingSample; + +use burn::data::dataloader::batcher::Batcher; +use burn::data::dataloader::Dataset; +use burn::prelude::*; + +/// A single normalized training item ready for batching. +#[derive(Clone, Debug)] +pub struct TrainingItem { + pub features: Vec, + pub label: i32, + pub weight: f32, +} + +/// A batch of training items as tensors. +#[derive(Clone, Debug)] +pub struct TrainingBatch { + pub features: Tensor, + pub labels: Tensor, + pub weights: Tensor, +} + +/// Wraps a `Vec` as a burn `Dataset`, applying min-max +/// normalization to features at construction time. 
+#[derive(Clone)] +pub struct SampleDataset { + items: Vec, +} + +impl SampleDataset { + pub fn new(samples: &[TrainingSample], mins: &[f32], maxs: &[f32]) -> Self { + let items = samples + .iter() + .map(|s| { + let features: Vec = s + .features + .iter() + .enumerate() + .map(|(i, &v)| { + let range = maxs[i] - mins[i]; + if range > f32::EPSILON { + ((v - mins[i]) / range).clamp(0.0, 1.0) + } else { + 0.0 + } + }) + .collect(); + TrainingItem { + features, + label: if s.label >= 0.5 { 1 } else { 0 }, + weight: s.weight, + } + }) + .collect(); + Self { items } + } +} + +impl Dataset for SampleDataset { + fn get(&self, index: usize) -> Option { + self.items.get(index).cloned() + } + + fn len(&self) -> usize { + self.items.len() + } +} + +/// Converts a `Vec` into a `TrainingBatch` of tensors. +#[derive(Clone)] +pub struct SampleBatcher; + +impl SampleBatcher { + pub fn new() -> Self { + Self + } +} + +impl Batcher> for SampleBatcher { + fn batch(&self, items: Vec, device: &B::Device) -> TrainingBatch { + let batch_size = items.len(); + let num_features = items[0].features.len(); + + let flat_features: Vec = items + .iter() + .flat_map(|item| item.features.iter().copied()) + .collect(); + + let labels: Vec = items.iter().map(|item| item.label).collect(); + let weights: Vec = items.iter().map(|item| item.weight).collect(); + + let features = Tensor::::from_floats(flat_features.as_slice(), device) + .reshape([batch_size, num_features]); + + let labels = Tensor::::from_ints(labels.as_slice(), device); + + let weights = Tensor::::from_floats(weights.as_slice(), device) + .reshape([batch_size, 1]); + + TrainingBatch { + features, + labels, + weights, + } + } +} diff --git a/src/training/export.rs b/src/training/export.rs index 1b0af93..108c0cb 100644 --- a/src/training/export.rs +++ b/src/training/export.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! 
Weight export: converts trained models into standalone Rust `const` arrays //! and optionally Lean 4 definitions. //! @@ -54,7 +57,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String { writeln!(s).unwrap(); // Threshold. - writeln!(s, "pub const THRESHOLD: f32 = {:.8};", model.threshold).unwrap(); + writeln!(s, "pub const THRESHOLD: f32 = {:.8};", sanitize(model.threshold)).unwrap(); writeln!(s).unwrap(); // Normalization params. @@ -74,7 +77,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String { if i > 0 { write!(s, ", ").unwrap(); } - write!(s, "{:.8}", v).unwrap(); + write!(s, "{:.8}", sanitize(*v)).unwrap(); } writeln!(s, "],").unwrap(); } @@ -88,7 +91,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String { write_f32_array(&mut s, "W2", &model.w2); // B2. - writeln!(s, "pub const B2: f32 = {:.8};", model.b2).unwrap(); + writeln!(s, "pub const B2: f32 = {:.8};", sanitize(model.b2)).unwrap(); writeln!(s).unwrap(); // Tree nodes. @@ -207,6 +210,11 @@ pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> { // Helpers // --------------------------------------------------------------------------- +/// Sanitize a float for Rust source: replace NaN/Inf with 0.0. +fn sanitize(v: f32) -> f32 { + if v.is_finite() { v } else { 0.0 } +} + fn write_f32_array(s: &mut String, name: &str, values: &[f32]) { writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap(); write!(s, " ").unwrap(); @@ -218,7 +226,7 @@ fn write_f32_array(s: &mut String, name: &str, values: &[f32]) { if i > 0 && i % 8 == 0 { write!(s, "\n ").unwrap(); } - write!(s, "{:.8}", v).unwrap(); + write!(s, "{:.8}", sanitize(*v)).unwrap(); } writeln!(s, "\n];").unwrap(); writeln!(s).unwrap(); diff --git a/src/training/mlp.rs b/src/training/mlp.rs index 0172492..1a4fc5d 100644 --- a/src/training/mlp.rs +++ b/src/training/mlp.rs @@ -1,11 +1,18 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! 
burn-rs MLP model definition for ensemble training. //! //! A two-layer network (linear -> ReLU -> linear -> sigmoid) used as the //! "uncertain region" classifier in the tree+MLP ensemble. +use crate::training::batch::TrainingBatch; + use burn::module::Module; use burn::nn::{Linear, LinearConfig}; use burn::prelude::*; +use burn::tensor::backend::AutodiffBackend; +use burn::train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep}; /// Two-layer MLP: input -> hidden (ReLU) -> output (sigmoid). #[derive(Module, Debug)] @@ -34,24 +41,79 @@ impl MlpConfig { } impl MlpModel { - /// Forward pass: ReLU hidden activation, sigmoid output. + /// Forward pass returning raw logits (pre-sigmoid). /// /// Input shape: `[batch, input_dim]` /// Output shape: `[batch, 1]` - pub fn forward(&self, x: Tensor) -> Tensor { + pub fn forward_logits(&self, x: Tensor) -> Tensor { let h = self.linear1.forward(x); let h = burn::tensor::activation::relu(h); - let out = self.linear2.forward(h); - burn::tensor::activation::sigmoid(out) + self.linear2.forward(h) + } + + /// Forward pass with sigmoid activation for inference/export. + /// + /// Input shape: `[batch, input_dim]` + /// Output shape: `[batch, 1]` (values in [0, 1]) + pub fn forward(&self, x: Tensor) -> Tensor { + burn::tensor::activation::sigmoid(self.forward_logits(x)) + } + + /// Forward pass returning a `ClassificationOutput` for burn's training loop. + /// + /// Uses raw logits for BCE (which applies sigmoid internally) and converts + /// to two-column format `[1-p, p]` for AccuracyMetric (which uses argmax). + pub fn forward_classification( + &self, + batch: TrainingBatch, + ) -> ClassificationOutput { + let logits = self.forward_logits(batch.features); // [batch, 1] + let logits_1d = logits.clone().squeeze::<1>(); // [batch] + + // Numerically stable BCE from logits: + // loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|)) + // This avoids log(0) and exp(large) overflow. 
+ let targets_float = batch.labels.clone().float(); // [batch] + let zeros = Tensor::zeros_like(&logits_1d); + let relu_logits = logits_1d.clone().max_pair(zeros); // max(logits, 0) + let neg_abs = logits_1d.clone().abs().neg(); // -|logits| + let log_term = neg_abs.exp().log1p(); // log(1 + exp(-|logits|)) + let per_sample = relu_logits - logits_1d.clone() * targets_float + log_term; + let loss = per_sample.mean(); // scalar [1] + + // AccuracyMetric expects [batch, num_classes] and uses argmax. + let neg_logits = logits.clone().neg(); + let output_2col = Tensor::cat(vec![neg_logits, logits], 1); // [batch, 2] + + ClassificationOutput::new(loss, output_2col, batch.labels) + } +} + +impl TrainStep for MlpModel { + type Input = TrainingBatch; + type Output = ClassificationOutput; + + fn step(&self, batch: Self::Input) -> TrainOutput { + let item = self.forward_classification(batch); + TrainOutput::new(self, item.loss.backward(), item) + } +} + +impl InferenceStep for MlpModel { + type Input = TrainingBatch; + type Output = ClassificationOutput; + + fn step(&self, batch: Self::Input) -> Self::Output { + self.forward_classification(batch) } } #[cfg(test)] mod tests { use super::*; - use burn::backend::NdArray; + use burn::backend::Wgpu; - type TestBackend = NdArray; + type TestBackend = Wgpu; #[test] fn test_forward_pass_shape() { @@ -80,7 +142,6 @@ mod tests { }; let model = config.init::(&device); - // Random-ish input values. 
let input = Tensor::::from_data( [[1.0, -2.0, 0.5, 3.0], [0.0, 0.0, 0.0, 0.0]], &device, diff --git a/src/training/mod.rs b/src/training/mod.rs index 1e82804..609b8fe 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -1,5 +1,10 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + pub mod tree; pub mod mlp; +pub mod batch; pub mod export; pub mod train_scanner; pub mod train_ddos; +pub mod sweep; diff --git a/src/training/sweep.rs b/src/training/sweep.rs new file mode 100644 index 0000000..911a297 --- /dev/null +++ b/src/training/sweep.rs @@ -0,0 +1,103 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + +//! Cookie weight sweep: trains full tree+MLP ensembles (GPU via wgpu) across a +//! range of cookie_weight values and reports accuracy metrics for each. + +use anyhow::{Context, Result}; +use std::path::Path; + +use crate::dataset::sample::load_dataset; +use crate::training::train_scanner::TrainScannerMlpArgs; +use crate::training::train_ddos::TrainDdosMlpArgs; + +/// Run a sweep across cookie_weight values for either scanner or ddos. +/// +/// Each trial does a full GPU training run (tree + MLP) with the specified +/// cookie_weight, writing artifacts to a temp directory. +pub fn run_cookie_sweep( + dataset_path: &str, + detector: &str, + weights_csv: Option<&str>, + tree_max_depth: usize, + tree_min_purity: f32, + min_samples_leaf: usize, +) -> Result<()> { + // Validate dataset exists and has samples. + let manifest = load_dataset(Path::new(dataset_path)) + .context("loading dataset manifest")?; + + let (cookie_idx, sample_count) = match detector { + "scanner" => (3usize, manifest.scanner_samples.len()), + "ddos" => (10usize, manifest.ddos_samples.len()), + other => anyhow::bail!("unknown detector '{}', expected 'scanner' or 'ddos'", other), + }; + + anyhow::ensure!(sample_count > 0, "no {} samples in dataset", detector); + drop(manifest); // Free memory before training loop. 
+ + let weights: Vec = if let Some(csv) = weights_csv { + csv.split(',') + .map(|s| s.trim().parse::()) + .collect::, _>>() + .context("parsing --weights as comma-separated floats")? + } else { + (0..=10).map(|i| i as f32 / 10.0).collect() + }; + + println!( + "[sweep] {} detector, {} samples, cookie feature index: {}", + detector, sample_count, cookie_idx, + ); + println!("[sweep] training {} trials with full tree+MLP (wgpu)\n", weights.len()); + + let sweep_dir = tempfile::tempdir().context("creating temp dir for sweep")?; + + for (trial, &cw) in weights.iter().enumerate() { + let trial_dir = sweep_dir.path().join(format!("trial_{}", trial)); + std::fs::create_dir_all(&trial_dir)?; + let trial_dir_str = trial_dir.to_string_lossy().to_string(); + + println!("━━━ Trial {}/{}: cookie_weight={:.2} ━━━", trial + 1, weights.len(), cw); + + match detector { + "scanner" => { + crate::training::train_scanner::run(TrainScannerMlpArgs { + dataset_path: dataset_path.to_string(), + output_dir: trial_dir_str, + hidden_dim: 32, + epochs: 100, + learning_rate: 0.0001, + batch_size: 64, + tree_max_depth, + tree_min_purity, + min_samples_leaf, + cookie_weight: cw, + })?; + } + "ddos" => { + crate::training::train_ddos::run(TrainDdosMlpArgs { + dataset_path: dataset_path.to_string(), + output_dir: trial_dir_str, + hidden_dim: 32, + epochs: 100, + learning_rate: 0.0001, + batch_size: 64, + tree_max_depth, + tree_min_purity, + min_samples_leaf, + cookie_weight: cw, + })?; + } + _ => unreachable!(), + } + + println!(); + } + + println!("[sweep] All {} trials complete.", weights.len()); + println!("[sweep] Tip: compare tree structures and validation accuracy above."); + println!("[sweep] Look for a cookie_weight where FP rate drops without FN rate spiking."); + + Ok(()) +} diff --git a/src/training/train_ddos.rs b/src/training/train_ddos.rs index b14fb6d..f143a9c 100644 --- a/src/training/train_ddos.rs +++ b/src/training/train_ddos.rs @@ -1,19 +1,27 @@ -//! DDoS MLP+tree training loop. 
+// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + +//! DDoS MLP+tree training loop using burn's SupervisedTraining. //! -//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP, -//! then exports the combined ensemble weights as a Rust source file that can -//! be dropped into `src/ensemble/gen/ddos_weights.rs`. +//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP +//! with cosine annealing + early stopping, then exports the combined ensemble +//! weights as a Rust source file for `src/ensemble/gen/ddos_weights.rs`. use anyhow::{Context, Result}; use std::path::Path; -use burn::backend::ndarray::NdArray; use burn::backend::Autodiff; -use burn::module::AutodiffModule; -use burn::optim::{AdamConfig, GradientsParams, Optimizer}; +use burn::backend::Wgpu; +use burn::data::dataloader::DataLoaderBuilder; +use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig; +use burn::optim::AdamConfig; use burn::prelude::*; +use burn::record::CompactRecorder; +use burn::train::metric::{AccuracyMetric, LossMetric}; +use burn::train::{Learner, SupervisedTraining}; use crate::dataset::sample::{load_dataset, TrainingSample}; +use crate::training::batch::{SampleBatcher, SampleDataset}; use crate::training::export::{export_to_file, ExportedModel}; use crate::training::mlp::MlpConfig; use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision}; @@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision}; /// Number of DDoS features (matches `crate::ddos::features::NUM_FEATURES`). const NUM_FEATURES: usize = 14; -type TrainBackend = Autodiff>; +type TrainBackend = Autodiff>; /// Arguments for the DDoS MLP training command. pub struct TrainDdosMlpArgs { @@ -37,10 +45,14 @@ pub struct TrainDdosMlpArgs { pub learning_rate: f64, /// Mini-batch size (default 64). pub batch_size: usize, - /// CART max depth (default 6). + /// CART max depth (default 8). 
pub tree_max_depth: usize, - /// CART leaf purity threshold (default 0.90). + /// CART leaf purity threshold (default 0.98). pub tree_min_purity: f32, + /// Minimum samples in a leaf node (default 2). + pub min_samples_leaf: usize, + /// Weight for cookie feature (feature 10: cookie_ratio). 0.0 = ignore, 1.0 = full weight. + pub cookie_weight: f32, } impl Default for TrainDdosMlpArgs { @@ -50,14 +62,19 @@ impl Default for TrainDdosMlpArgs { output_dir: ".".into(), hidden_dim: 32, epochs: 100, - learning_rate: 0.001, + learning_rate: 0.0001, batch_size: 64, - tree_max_depth: 6, - tree_min_purity: 0.90, + tree_max_depth: 8, + tree_min_purity: 0.98, + min_samples_leaf: 2, + cookie_weight: 1.0, } } } +/// Index of the cookie_ratio feature in the DDoS feature vector. +const COOKIE_FEATURE_IDX: usize = 10; + /// Entry point: train DDoS ensemble and export weights. pub fn run(args: TrainDdosMlpArgs) -> Result<()> { // 1. Load dataset. @@ -86,6 +103,23 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> { // 2. Compute normalization params from training data. let (norm_mins, norm_maxs) = compute_norm_params(samples); + if args.cookie_weight < 1.0 - f32::EPSILON { + println!( + "[ddos] cookie_weight={:.2} (feature {} influence reduced)", + args.cookie_weight, COOKIE_FEATURE_IDX, + ); + } + + // MLP norm adjustment: scale cookie feature's normalization range. + let mut mlp_norm_maxs = norm_maxs.clone(); + if args.cookie_weight < 1.0 - f32::EPSILON { + let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX]; + if range > f32::EPSILON && args.cookie_weight > f32::EPSILON { + mlp_norm_maxs[COOKIE_FEATURE_IDX] = + range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX]; + } + } + // 3. Stratified 80/20 split. let (train_set, val_set) = stratified_split(samples, 0.8); println!( @@ -94,15 +128,16 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> { val_set.len() ); - // 4. Train CART tree. + // 4. 
Train CART tree (with cookie feature masking for reduced weight). + let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight); let tree_config = TreeConfig { max_depth: args.tree_max_depth, - min_samples_leaf: 5, + min_samples_leaf: args.min_samples_leaf, min_purity: args.tree_min_purity, num_features: NUM_FEATURES, }; - let tree_nodes = train_tree(&train_set, &tree_config); - println!("[ddos] CART tree: {} nodes", tree_nodes.len()); + let tree_nodes = train_tree(&tree_train_set, &tree_config); + println!("[ddos] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth); // Evaluate tree on validation set. let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs); @@ -112,23 +147,27 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> { tree_deferred * 100.0, ); - // 5. Train MLP on the full training set. + // 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling). let device = Default::default(); let mlp_config = MlpConfig { input_dim: NUM_FEATURES, hidden_dim: args.hidden_dim, }; + let artifact_dir = Path::new(&args.output_dir).join("ddos_artifacts"); + std::fs::create_dir_all(&artifact_dir).ok(); + let model = train_mlp( &train_set, &val_set, &mlp_config, &norm_mins, - &norm_maxs, + &mlp_norm_maxs, args.epochs, args.learning_rate, args.batch_size, &device, + &artifact_dir, ); // 6. Extract weights from trained model. 
@@ -136,9 +175,9 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> { &model, "ddos", &tree_nodes, - 0.5, // threshold + 0.5, &norm_mins, - &norm_maxs, + &mlp_norm_maxs, &device, ); @@ -153,6 +192,37 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> { Ok(()) } +// --------------------------------------------------------------------------- +// Cookie feature masking for CART trees +// --------------------------------------------------------------------------- + +fn mask_cookie_feature( + samples: &[TrainingSample], + cookie_idx: usize, + cookie_weight: f32, +) -> Vec { + if cookie_weight >= 1.0 - f32::EPSILON { + return samples.to_vec(); + } + samples + .iter() + .enumerate() + .map(|(i, s)| { + let mut s2 = s.clone(); + if cookie_weight < f32::EPSILON { + s2.features[cookie_idx] = 0.5; + } else { + let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42); + let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32; + if r > cookie_weight { + s2.features[cookie_idx] = 0.5; + } + } + s2 + }) + .collect() +} + // --------------------------------------------------------------------------- // Normalization // --------------------------------------------------------------------------- @@ -170,21 +240,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec, Vec) { (mins, maxs) } -fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec { - features - .iter() - .enumerate() - .map(|(i, &v)| { - let range = maxs[i] - mins[i]; - if range > f32::EPSILON { - ((v - mins[i]) / range).clamp(0.0, 1.0) - } else { - 0.0 - } - }) - .collect() -} - // --------------------------------------------------------------------------- // Stratified split // --------------------------------------------------------------------------- @@ -272,8 +327,23 @@ fn eval_tree( (accuracy, defer_rate) } +fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec { + features + .iter() + .enumerate() + .map(|(i, &v)| { + let range = maxs[i] - mins[i]; 
+ if range > f32::EPSILON { + ((v - mins[i]) / range).clamp(0.0, 1.0) + } else { + 0.0 + } + }) + .collect() +} + // --------------------------------------------------------------------------- -// MLP training +// MLP training via SupervisedTraining // --------------------------------------------------------------------------- fn train_mlp( @@ -286,117 +356,47 @@ fn train_mlp( learning_rate: f64, batch_size: usize, device: &::Device, -) -> crate::training::mlp::MlpModel> { - let mut model = config.init::(device); - let mut optim = AdamConfig::new().init(); + artifact_dir: &Path, +) -> crate::training::mlp::MlpModel> { + let model = config.init::(device); - // Pre-normalize all training data. - let train_features: Vec> = train_set - .iter() - .map(|s| normalize_features(&s.features, mins, maxs)) - .collect(); - let train_labels: Vec = train_set.iter().map(|s| s.label).collect(); - let train_weights: Vec = train_set.iter().map(|s| s.weight).collect(); + let train_dataset = SampleDataset::new(train_set, mins, maxs); + let val_dataset = SampleDataset::new(val_set, mins, maxs); - let n = train_features.len(); + let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new()) + .batch_size(batch_size) + .shuffle(42) + .num_workers(1) + .build(train_dataset); - for epoch in 0..epochs { - let mut epoch_loss = 0.0f32; - let mut batches = 0usize; + let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new()) + .batch_size(batch_size) + .num_workers(1) + .build(val_dataset); - let mut offset = 0; - while offset < n { - let end = (offset + batch_size).min(n); - let batch_n = end - offset; + // Cosine annealing: initial_lr must be in (0.0, 1.0]. + let lr = learning_rate.min(1.0); + let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs) + .init() + .expect("valid cosine annealing config"); - // Build input tensor [batch, features]. 
- let flat: Vec = train_features[offset..end] - .iter() - .flat_map(|f| f.iter().copied()) - .collect(); - let x = Tensor::::from_floats(flat.as_slice(), device) - .reshape([batch_n, NUM_FEATURES]); + let learner = Learner::new( + model, + AdamConfig::new().init(), + lr_scheduler, + ); - // Labels [batch, 1]. - let y = Tensor::::from_floats( - &train_labels[offset..end], - device, - ) - .reshape([batch_n, 1]); + let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .num_epochs(epochs) + .summary() + .launch(learner); - // Sample weights [batch, 1]. - let w = Tensor::::from_floats( - &train_weights[offset..end], - device, - ) - .reshape([batch_n, 1]); - - // Forward pass. - let pred = model.forward(x); - - // Binary cross-entropy with sample weights. - let eps = 1e-7; - let pred_clamped = pred.clone().clamp(eps, 1.0 - eps); - let bce = (y.clone() * pred_clamped.clone().log() - + (y.clone().neg().add_scalar(1.0)) - * pred_clamped.neg().add_scalar(1.0).log()) - .neg(); - let weighted_bce = bce * w; - let loss = weighted_bce.mean(); - - epoch_loss += loss.clone().into_scalar().elem::(); - batches += 1; - - // Backward + optimizer step. 
- let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &model); - model = optim.step(learning_rate, model, grads); - - offset = end; - } - - if (epoch + 1) % 10 == 0 || epoch == 0 { - let avg_loss = epoch_loss / batches as f32; - let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device); - println!( - "[ddos] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}", - epoch + 1, - epochs, - avg_loss, - val_acc, - ); - } - } - - model.valid() -} - -fn eval_mlp_accuracy( - model: &crate::training::mlp::MlpModel, - val_set: &[TrainingSample], - mins: &[f32], - maxs: &[f32], - device: &::Device, -) -> f64 { - let flat: Vec = val_set - .iter() - .flat_map(|s| normalize_features(&s.features, mins, maxs)) - .collect(); - let x = Tensor::::from_floats(flat.as_slice(), device) - .reshape([val_set.len(), NUM_FEATURES]); - - let pred = model.forward(x); - let pred_data: Vec = pred.to_data().to_vec().expect("flat vec"); - - let mut correct = 0usize; - for (i, s) in val_set.iter().enumerate() { - let p = pred_data[i]; - let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 }; - if (predicted_label - s.label).abs() < 0.1 { - correct += 1; - } - } - correct as f64 / val_set.len() as f64 + result.model } // --------------------------------------------------------------------------- @@ -404,13 +404,13 @@ fn eval_mlp_accuracy( // --------------------------------------------------------------------------- fn extract_weights( - model: &crate::training::mlp::MlpModel>, + model: &crate::training::mlp::MlpModel>, name: &str, tree_nodes: &[(u8, f32, u16, u16)], threshold: f32, norm_mins: &[f32], norm_maxs: &[f32], - _device: & as Backend>::Device, + _device: & as Backend>::Device, ) -> ExportedModel { let w1_tensor = model.linear1.weight.val(); let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val(); diff --git a/src/training/train_scanner.rs b/src/training/train_scanner.rs index f5e42fd..1f59579 100644 --- a/src/training/train_scanner.rs +++ 
b/src/training/train_scanner.rs
@@ -1,19 +1,27 @@
-//! Scanner MLP+tree training loop.
+// Copyright Sunbeam Studios 2026
+// SPDX-License-Identifier: Apache-2.0
+
+//! Scanner MLP+tree training loop using burn's SupervisedTraining.
 //!
-//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
-//! then exports the combined ensemble weights as a Rust source file that can
-//! be dropped into `src/ensemble/gen/scanner_weights.rs`.
+//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP
+//! with cosine annealing + early stopping, then exports the combined ensemble
+//! weights as a Rust source file for `src/ensemble/gen/scanner_weights.rs`.
 
 use anyhow::{Context, Result};
 use std::path::Path;
 
-use burn::backend::ndarray::NdArray;
 use burn::backend::Autodiff;
-use burn::module::AutodiffModule;
-use burn::optim::{AdamConfig, GradientsParams, Optimizer};
+use burn::backend::Wgpu;
+use burn::data::dataloader::DataLoaderBuilder;
+use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
+use burn::optim::AdamConfig;
 use burn::prelude::*;
+use burn::record::CompactRecorder;
+use burn::train::metric::{AccuracyMetric, LossMetric};
+use burn::train::{Learner, SupervisedTraining};
 
 use crate::dataset::sample::{load_dataset, TrainingSample};
+use crate::training::batch::{SampleBatcher, SampleDataset};
 use crate::training::export::{export_to_file, ExportedModel};
 use crate::training::mlp::MlpConfig;
 use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
 
@@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
 /// Number of scanner features (matches `crate::scanner::features::NUM_SCANNER_FEATURES`).
 const NUM_FEATURES: usize = 12;
 
-type TrainBackend = Autodiff<NdArray>;
+type TrainBackend = Autodiff<Wgpu>;
 
 /// Arguments for the scanner MLP training command.
pub struct TrainScannerMlpArgs { @@ -37,10 +45,14 @@ pub struct TrainScannerMlpArgs { pub learning_rate: f64, /// Mini-batch size (default 64). pub batch_size: usize, - /// CART max depth (default 6). + /// CART max depth (default 8). pub tree_max_depth: usize, - /// CART leaf purity threshold (default 0.90). + /// CART leaf purity threshold (default 0.98). pub tree_min_purity: f32, + /// Minimum samples in a leaf node (default 2). + pub min_samples_leaf: usize, + /// Weight for cookie feature (feature 3: has_cookies). 0.0 = ignore, 1.0 = full weight. + pub cookie_weight: f32, } impl Default for TrainScannerMlpArgs { @@ -50,14 +62,19 @@ impl Default for TrainScannerMlpArgs { output_dir: ".".into(), hidden_dim: 32, epochs: 100, - learning_rate: 0.001, + learning_rate: 0.0001, batch_size: 64, - tree_max_depth: 6, - tree_min_purity: 0.90, + tree_max_depth: 8, + tree_min_purity: 0.98, + min_samples_leaf: 2, + cookie_weight: 1.0, } } } +/// Index of the has_cookies feature in the scanner feature vector. +const COOKIE_FEATURE_IDX: usize = 3; + /// Entry point: train scanner ensemble and export weights. pub fn run(args: TrainScannerMlpArgs) -> Result<()> { // 1. Load dataset. @@ -86,6 +103,27 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> { // 2. Compute normalization params from training data. let (norm_mins, norm_maxs) = compute_norm_params(samples); + // Apply cookie_weight: for the MLP, we scale the normalization range so + // the feature contributes less gradient signal. For the CART tree, scaling + // doesn't help (the tree just adjusts its threshold), so we mask the feature + // to a constant on a fraction of training samples to degrade its Gini gain. + if args.cookie_weight < 1.0 - f32::EPSILON { + println!( + "[scanner] cookie_weight={:.2} (feature {} influence reduced)", + args.cookie_weight, COOKIE_FEATURE_IDX, + ); + } + + // MLP norm adjustment: scale the cookie feature's normalization range. 
+ let mut mlp_norm_maxs = norm_maxs.clone(); + if args.cookie_weight < 1.0 - f32::EPSILON { + let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX]; + if range > f32::EPSILON && args.cookie_weight > f32::EPSILON { + mlp_norm_maxs[COOKIE_FEATURE_IDX] = + range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX]; + } + } + // 3. Stratified 80/20 split. let (train_set, val_set) = stratified_split(samples, 0.8); println!( @@ -94,17 +132,18 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> { val_set.len() ); - // 4. Train CART tree. + // 4. Train CART tree (with cookie feature masking for reduced weight). + let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight); let tree_config = TreeConfig { max_depth: args.tree_max_depth, - min_samples_leaf: 5, + min_samples_leaf: args.min_samples_leaf, min_purity: args.tree_min_purity, num_features: NUM_FEATURES, }; - let tree_nodes = train_tree(&train_set, &tree_config); - println!("[scanner] CART tree: {} nodes", tree_nodes.len()); + let tree_nodes = train_tree(&tree_train_set, &tree_config); + println!("[scanner] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth); - // Evaluate tree on validation set. + // Evaluate tree on validation set (use original norms — tree learned on masked features). let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs); println!( "[scanner] tree validation: {:.2}% correct (of decided), {:.1}% deferred", @@ -112,35 +151,38 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> { tree_deferred * 100.0, ); - // 5. Train MLP on the full training set (the MLP only fires on Defer - // at inference time, but we train it on all data so it learns the - // full decision boundary). + // 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling). 
let device = Default::default(); let mlp_config = MlpConfig { input_dim: NUM_FEATURES, hidden_dim: args.hidden_dim, }; + let artifact_dir = Path::new(&args.output_dir).join("scanner_artifacts"); + std::fs::create_dir_all(&artifact_dir).ok(); + let model = train_mlp( &train_set, &val_set, &mlp_config, &norm_mins, - &norm_maxs, + &mlp_norm_maxs, args.epochs, args.learning_rate, args.batch_size, &device, + &artifact_dir, ); - // 6. Extract weights from trained model. + // 6. Extract weights from trained model (export mlp_norm_maxs so inference + // automatically applies the same cookie scaling). let exported = extract_weights( &model, "scanner", &tree_nodes, - 0.5, // threshold + 0.5, &norm_mins, - &norm_maxs, + &mlp_norm_maxs, &device, ); @@ -155,6 +197,46 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> { Ok(()) } +// --------------------------------------------------------------------------- +// Cookie feature masking for CART trees +// --------------------------------------------------------------------------- + +/// Mask the cookie feature to reduce its influence on CART tree training. +/// +/// Scaling a binary feature doesn't reduce its Gini gain — the tree just adjusts +/// the split threshold. Instead, we mask (set to 0.5) a fraction of samples so +/// the feature's apparent class-separation degrades. 
+/// +/// - `cookie_weight = 0.0` → fully masked (feature is constant 0.5, zero info gain) +/// - `cookie_weight = 0.5` → 50% of samples masked (noisy, reduced gain) +/// - `cookie_weight = 1.0` → no masking (full feature) +fn mask_cookie_feature( + samples: &[TrainingSample], + cookie_idx: usize, + cookie_weight: f32, +) -> Vec { + if cookie_weight >= 1.0 - f32::EPSILON { + return samples.to_vec(); + } + samples + .iter() + .enumerate() + .map(|(i, s)| { + let mut s2 = s.clone(); + if cookie_weight < f32::EPSILON { + s2.features[cookie_idx] = 0.5; + } else { + let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42); + let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32; + if r > cookie_weight { + s2.features[cookie_idx] = 0.5; + } + } + s2 + }) + .collect() +} + // --------------------------------------------------------------------------- // Normalization // --------------------------------------------------------------------------- @@ -172,21 +254,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec, Vec) { (mins, maxs) } -fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec { - features - .iter() - .enumerate() - .map(|(i, &v)| { - let range = maxs[i] - mins[i]; - if range > f32::EPSILON { - ((v - mins[i]) / range).clamp(0.0, 1.0) - } else { - 0.0 - } - }) - .collect() -} - // --------------------------------------------------------------------------- // Stratified split // --------------------------------------------------------------------------- @@ -195,7 +262,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec = samples.iter().filter(|s| s.label >= 0.5).collect(); let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect(); - // Deterministic shuffle using a simple index permutation seeded by length. 
deterministic_shuffle(&mut attacks); deterministic_shuffle(&mut normals); @@ -224,7 +290,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec(items: &mut [T]) { - // Simple Fisher-Yates with a fixed LCG seed for reproducibility. let mut rng = 42u64; for i in (1..items.len()).rev() { rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); @@ -276,8 +341,23 @@ fn eval_tree( (accuracy, defer_rate) } +fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec { + features + .iter() + .enumerate() + .map(|(i, &v)| { + let range = maxs[i] - mins[i]; + if range > f32::EPSILON { + ((v - mins[i]) / range).clamp(0.0, 1.0) + } else { + 0.0 + } + }) + .collect() +} + // --------------------------------------------------------------------------- -// MLP training +// MLP training via SupervisedTraining // --------------------------------------------------------------------------- fn train_mlp( @@ -290,119 +370,47 @@ fn train_mlp( learning_rate: f64, batch_size: usize, device: &::Device, -) -> crate::training::mlp::MlpModel> { - let mut model = config.init::(device); - let mut optim = AdamConfig::new().init(); + artifact_dir: &Path, +) -> crate::training::mlp::MlpModel> { + let model = config.init::(device); - // Pre-normalize all training data. 
- let train_features: Vec> = train_set - .iter() - .map(|s| normalize_features(&s.features, mins, maxs)) - .collect(); - let train_labels: Vec = train_set.iter().map(|s| s.label).collect(); - let train_weights: Vec = train_set.iter().map(|s| s.weight).collect(); + let train_dataset = SampleDataset::new(train_set, mins, maxs); + let val_dataset = SampleDataset::new(val_set, mins, maxs); - let n = train_features.len(); + let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new()) + .batch_size(batch_size) + .shuffle(42) + .num_workers(1) + .build(train_dataset); - for epoch in 0..epochs { - let mut epoch_loss = 0.0f32; - let mut batches = 0usize; + let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new()) + .batch_size(batch_size) + .num_workers(1) + .build(val_dataset); - let mut offset = 0; - while offset < n { - let end = (offset + batch_size).min(n); - let batch_n = end - offset; + // Cosine annealing: initial_lr must be in (0.0, 1.0]. + let lr = learning_rate.min(1.0); + let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs) + .init() + .expect("valid cosine annealing config"); - // Build input tensor [batch, features]. - let flat: Vec = train_features[offset..end] - .iter() - .flat_map(|f| f.iter().copied()) - .collect(); - let x = Tensor::::from_floats(flat.as_slice(), device) - .reshape([batch_n, NUM_FEATURES]); + let learner = Learner::new( + model, + AdamConfig::new().init(), + lr_scheduler, + ); - // Labels [batch, 1]. - let y = Tensor::::from_floats( - &train_labels[offset..end], - device, - ) - .reshape([batch_n, 1]); + let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid) + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .num_epochs(epochs) + .summary() + .launch(learner); - // Sample weights [batch, 1]. 
- let w = Tensor::::from_floats( - &train_weights[offset..end], - device, - ) - .reshape([batch_n, 1]); - - // Forward pass. - let pred = model.forward(x); - - // Binary cross-entropy with sample weights: - // loss = -w * [y * log(p) + (1-y) * log(1-p)] - let eps = 1e-7; - let pred_clamped = pred.clone().clamp(eps, 1.0 - eps); - let bce = (y.clone() * pred_clamped.clone().log() - + (y.clone().neg().add_scalar(1.0)) - * pred_clamped.neg().add_scalar(1.0).log()) - .neg(); - let weighted_bce = bce * w; - let loss = weighted_bce.mean(); - - epoch_loss += loss.clone().into_scalar().elem::(); - batches += 1; - - // Backward + optimizer step. - let grads = loss.backward(); - let grads = GradientsParams::from_grads(grads, &model); - model = optim.step(learning_rate, model, grads); - - offset = end; - } - - if (epoch + 1) % 10 == 0 || epoch == 0 { - let avg_loss = epoch_loss / batches as f32; - let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device); - println!( - "[scanner] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}", - epoch + 1, - epochs, - avg_loss, - val_acc, - ); - } - } - - // Return the inner (non-autodiff) model for weight extraction. 
- model.valid() -} - -fn eval_mlp_accuracy( - model: &crate::training::mlp::MlpModel, - val_set: &[TrainingSample], - mins: &[f32], - maxs: &[f32], - device: &::Device, -) -> f64 { - let flat: Vec = val_set - .iter() - .flat_map(|s| normalize_features(&s.features, mins, maxs)) - .collect(); - let x = Tensor::::from_floats(flat.as_slice(), device) - .reshape([val_set.len(), NUM_FEATURES]); - - let pred = model.forward(x); - let pred_data: Vec = pred.to_data().to_vec().expect("flat vec"); - - let mut correct = 0usize; - for (i, s) in val_set.iter().enumerate() { - let p = pred_data[i]; - let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 }; - if (predicted_label - s.label).abs() < 0.1 { - correct += 1; - } - } - correct as f64 / val_set.len() as f64 + result.model } // --------------------------------------------------------------------------- @@ -410,19 +418,14 @@ fn eval_mlp_accuracy( // --------------------------------------------------------------------------- fn extract_weights( - model: &crate::training::mlp::MlpModel>, + model: &crate::training::mlp::MlpModel>, name: &str, tree_nodes: &[(u8, f32, u16, u16)], threshold: f32, norm_mins: &[f32], norm_maxs: &[f32], - _device: & as Backend>::Device, + _device: & as Backend>::Device, ) -> ExportedModel { - // Extract weight tensors from the model. - // linear1.weight: [hidden_dim, input_dim] - // linear1.bias: [hidden_dim] - // linear2.weight: [1, hidden_dim] - // linear2.bias: [1] let w1_tensor = model.linear1.weight.val(); let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val(); let w2_tensor = model.linear2.weight.val(); @@ -436,7 +439,6 @@ fn extract_weights( let hidden_dim = b1_data.len(); let input_dim = w1_data.len() / hidden_dim; - // Reshape W1 into [hidden_dim][input_dim]. 
let w1: Vec> = (0..hidden_dim) .map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec()) .collect(); @@ -485,9 +487,8 @@ mod tests { let train_attacks = train.iter().filter(|s| s.label >= 0.5).count(); let val_attacks = val.iter().filter(|s| s.label >= 0.5).count(); - // Should preserve the 80/20 attack ratio approximately. - assert_eq!(train_attacks, 16); // 80% of 20 - assert_eq!(val_attacks, 4); // 20% of 20 + assert_eq!(train_attacks, 16); + assert_eq!(val_attacks, 4); assert_eq!(train.len() + val.len(), 100); } @@ -503,13 +504,4 @@ mod tests { assert_eq!(mins[1], 10.0); assert_eq!(maxs[1], 20.0); } - - #[test] - fn test_normalize_features() { - let mins = vec![0.0, 10.0]; - let maxs = vec![1.0, 20.0]; - let normed = normalize_features(&[0.5, 15.0], &mins, &maxs); - assert!((normed[0] - 0.5).abs() < 1e-6); - assert!((normed[1] - 0.5).abs() < 1e-6); - } }