diff --git a/src/ddos/replay.rs b/src/ddos/replay.rs index 7e816b9..f7f26fc 100644 --- a/src/ddos/replay.rs +++ b/src/ddos/replay.rs @@ -21,38 +21,39 @@ pub struct ReplayArgs { pub rate_limit: bool, } -struct ReplayStats { - total: u64, - skipped: u64, - ddos_blocked: u64, - rate_limited: u64, - allowed: u64, - ddos_blocked_ips: FxHashMap, - rate_limited_ips: FxHashMap, +pub struct ReplayResult { + pub total: u64, + pub skipped: u64, + pub ddos_blocked: u64, + pub rate_limited: u64, + pub allowed: u64, + pub ddos_blocked_ips: FxHashMap, + pub rate_limited_ips: FxHashMap, + pub false_positive_ips: usize, + pub true_positive_ips: usize, } -pub fn run(args: ReplayArgs) -> Result<()> { - eprintln!("Loading model from {}...", args.model_path); +/// Core replay pipeline: load model, replay logs, compute stats including false positive analysis. +pub fn replay_and_evaluate(args: &ReplayArgs) -> Result { let model = TrainedModel::load( std::path::Path::new(&args.model_path), Some(args.k), Some(args.threshold), ) .with_context(|| format!("loading model from {}", args.model_path))?; - eprintln!(" {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold); let ddos_cfg = DDoSConfig { - model_path: args.model_path.clone(), + model_path: Some(args.model_path.clone()), k: args.k, threshold: args.threshold, window_secs: args.window_secs, window_capacity: 1000, min_events: args.min_events, enabled: true, + use_ensemble: false, }; let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg)); - // Optionally set up rate limiter let rate_limiter = if args.rate_limit { let rl_cfg = if let Some(cfg_path) = &args.config_path { let cfg = crate::config::Config::load(cfg_path)?; @@ -60,61 +61,49 @@ pub fn run(args: ReplayArgs) -> Result<()> { } else { default_rate_limit_config() }; - eprintln!( - " Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s", - rl_cfg.authenticated.burst, - rl_cfg.authenticated.rate, - rl_cfg.unauthenticated.burst, - rl_cfg.unauthenticated.rate, - ); Some(RateLimiter::new(&rl_cfg)) } else { None }; - eprintln!("Replaying {}...\n", args.input); - let file = std::fs::File::open(&args.input) .with_context(|| format!("opening {}", args.input))?; let reader = std::io::BufReader::new(file); - let mut stats = ReplayStats { - total: 0, - skipped: 0, - ddos_blocked: 0, - rate_limited: 0, - allowed: 0, - ddos_blocked_ips: FxHashMap::default(), - rate_limited_ips: FxHashMap::default(), - }; + let mut total = 0u64; + let mut skipped = 0u64; + let mut ddos_blocked = 0u64; + let mut rate_limited = 0u64; + let mut allowed = 0u64; + let mut ddos_blocked_ips: FxHashMap = FxHashMap::default(); + let mut rate_limited_ips: FxHashMap = FxHashMap::default(); for line in reader.lines() { let line = line?; let entry: AuditLog = match serde_json::from_str(&line) { Ok(e) => e, Err(_) => { - stats.skipped += 1; + skipped += 1; continue; } }; if entry.fields.method.is_empty() { - stats.skipped += 1; + skipped += 1; continue; } - stats.total += 1; + total += 1; let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string(); let ip: IpAddr = match ip_str.parse() { Ok(ip) => ip, Err(_) => { - stats.skipped += 1; + skipped += 1; continue; } }; - // DDoS check let has_cookies = entry.fields.has_cookies.unwrap_or(false); let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false); let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false); @@ -131,78 +120,57 @@ pub fn run(args: ReplayArgs) -> Result<()> { ); if ddos_action == DDoSAction::Block { - stats.ddos_blocked += 1; - *stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1; + ddos_blocked += 1; + *ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1; continue; } - // Rate limit check if let Some(limiter) = &rate_limiter { - // Audit logs don't have auth headers, so all traffic is keyed by IP let rl_key = RateLimitKey::Ip(ip); if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) { - stats.rate_limited += 1; - *stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1; + rate_limited += 1; + *rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1; continue; } } - stats.allowed += 1; + allowed += 1; } - // Report - let total = stats.total; - eprintln!("═══ Replay Results ═══════════════════════════════════════"); - eprintln!(" Total requests: {total}"); - eprintln!(" Skipped (parse): {}", stats.skipped); - eprintln!(" Allowed: {} ({:.1}%)", stats.allowed, pct(stats.allowed, total)); - eprintln!(" DDoS blocked: {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total)); - if rate_limiter.is_some() { - eprintln!(" Rate limited: {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total)); - } + // Compute false positive / true positive counts + let (false_positive_ips, true_positive_ips) = + count_fp_tp(&args.input, &ddos_blocked_ips, &rate_limited_ips)?; - if !stats.ddos_blocked_ips.is_empty() { - eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────"); - let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect(); - sorted.sort_by(|a, b| b.1.cmp(a.1)); - for (ip, count) in sorted.iter().take(20) { - eprintln!(" {:<40} {} reqs blocked", ip, count); - } - } - - if !stats.rate_limited_ips.is_empty() { - eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────"); - let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect(); - sorted.sort_by(|a, b| b.1.cmp(a.1)); - for (ip, count) in sorted.iter().take(20) { - eprintln!(" {:<40} {} reqs limited", ip, count); - } - } - - // Check for false positives: IPs that were blocked but had 2xx statuses in the original logs - eprintln!("\n── False positive check ──────────────────────────────────"); - check_false_positives(&args.input, &stats)?; - - eprintln!("══════════════════════════════════════════════════════════"); - Ok(()) + Ok(ReplayResult { + total, + skipped, + ddos_blocked, + rate_limited, + allowed, + ddos_blocked_ips, + rate_limited_ips, + false_positive_ips, + true_positive_ips, + }) } -/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally -/// (i.e. they were legitimate traffic that the model would incorrectly block). -fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> { - let blocked_ips: rustc_hash::FxHashSet<&str> = stats - .ddos_blocked_ips +/// Count false-positive and true-positive IPs from blocked set. +/// FP = blocked IP where >60% of original responses were 2xx/3xx. +fn count_fp_tp( + input: &str, + ddos_blocked_ips: &FxHashMap, + rate_limited_ips: &FxHashMap, +) -> Result<(usize, usize)> { + let blocked_ips: rustc_hash::FxHashSet<&str> = ddos_blocked_ips .keys() - .chain(stats.rate_limited_ips.keys()) + .chain(rate_limited_ips.keys()) .map(|s| s.as_str()) .collect(); if blocked_ips.is_empty() { - eprintln!(" No blocked IPs — nothing to check."); - return Ok(()); + return Ok((0, 0)); } - // Collect original status codes for blocked IPs let file = std::fs::File::open(input)?; let reader = std::io::BufReader::new(file); let mut ip_statuses: FxHashMap> = FxHashMap::default(); @@ -215,47 +183,69 @@ fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> { }; let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string(); if blocked_ips.contains(ip_str.as_str()) { - ip_statuses - .entry(ip_str) - .or_default() - .push(entry.fields.status); + ip_statuses.entry(ip_str).or_default().push(entry.fields.status); } } - let mut suspects = Vec::new(); - for (ip, statuses) in &ip_statuses { + let mut fp = 0usize; + let mut tp = 0usize; + for (_ip, statuses) in &ip_statuses { let total = statuses.len(); let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count(); - let ok_pct = (ok_count as f64 / total as f64) * 100.0; - // If >60% of original responses were 2xx/3xx, this might be a false positive + let ok_pct = ok_count as f64 / total as f64 * 100.0; if ok_pct > 60.0 { - let blocked = stats - .ddos_blocked_ips - .get(ip) - .copied() - .unwrap_or(0) - + stats - .rate_limited_ips - .get(ip) - .copied() - .unwrap_or(0); - suspects.push((ip.clone(), total, ok_pct, blocked)); + fp += 1; + } else { + tp += 1; } } - if suspects.is_empty() { + Ok((fp, tp)) +} + +pub fn run(args: ReplayArgs) -> Result<()> { + eprintln!("Loading model from {}...", args.model_path); + eprintln!("Replaying {}...\n", args.input); + + let result = replay_and_evaluate(&args)?; + + let total = result.total; + eprintln!("═══ Replay Results ═══════════════════════════════════════"); + eprintln!(" Total requests: {total}"); + eprintln!(" Skipped (parse): {}", result.skipped); + eprintln!(" Allowed: {} ({:.1}%)", result.allowed, pct(result.allowed, total)); + eprintln!(" DDoS blocked: {} ({:.1}%)", result.ddos_blocked, pct(result.ddos_blocked, total)); + if args.rate_limit { + eprintln!(" Rate limited: {} ({:.1}%)", result.rate_limited, pct(result.rate_limited, total)); + } + + if !result.ddos_blocked_ips.is_empty() { + eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────"); + let mut sorted: Vec<_> = result.ddos_blocked_ips.iter().collect(); + sorted.sort_by(|a, b| b.1.cmp(a.1)); + for (ip, count) in sorted.iter().take(20) { + eprintln!(" {:<40} {} reqs blocked", ip, count); + } + } + + if !result.rate_limited_ips.is_empty() { + eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────"); + let mut sorted: Vec<_> = result.rate_limited_ips.iter().collect(); + sorted.sort_by(|a, b| b.1.cmp(a.1)); + for (ip, count) in sorted.iter().take(20) { + eprintln!(" {:<40} {} reqs limited", ip, count); + } + } + + eprintln!("\n── False positive check ──────────────────────────────────"); + if result.false_positive_ips == 0 { eprintln!(" No likely false positives found."); } else { - suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap()); - eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses:", suspects.len()); - for (ip, total, ok_pct, blocked) in suspects.iter().take(15) { - eprintln!( - " {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked", - ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked, - ); - } + eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses", result.false_positive_ips); } + eprintln!(" True positive IPs: {}", result.true_positive_ips); + eprintln!("══════════════════════════════════════════════════════════"); Ok(()) } diff --git a/src/ensemble/replay.rs b/src/ensemble/replay.rs new file mode 100644 index 0000000..3dd7af0 --- /dev/null +++ b/src/ensemble/replay.rs @@ -0,0 +1,316 @@ +//! Replay audit logs through the ensemble models (scanner + DDoS). + +use crate::ddos::audit_log::{self, AuditLog}; +use crate::ddos::features::{method_to_u8, LogIpState}; +use crate::ddos::model::DDoSAction; +use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath}; +use crate::ensemble::scanner::{scanner_ensemble_predict, EnsemblePath}; +use crate::scanner::features::{self, fx_hash_bytes}; +use crate::scanner::model::ScannerAction; + +use anyhow::{Context, Result}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::hash::{Hash, Hasher}; +use std::io::BufRead; + +pub struct ReplayEnsembleArgs { + pub input: String, + pub window_secs: u64, + pub min_events: usize, +} + +pub fn run(args: ReplayEnsembleArgs) -> Result<()> { + eprintln!("replaying {} through ensemble models...\n", args.input); + + let file = + std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?; + let reader = std::io::BufReader::new(file); + + // --- Parse all entries --- + let mut entries: Vec = Vec::new(); + let mut parse_errors = 0u64; + for line in reader.lines() { + let line = line?; + if line.trim().is_empty() { + continue; + } + match serde_json::from_str::(&line) { + Ok(e) => entries.push(e), + Err(_) => parse_errors += 1, + } + } + let total = entries.len() as u64; + eprintln!("parsed {} entries ({} parse errors)\n", total, parse_errors); + + // --- Scanner replay --- + eprintln!("═══ Scanner Ensemble ═════════════════════════════════════"); + replay_scanner(&entries); + + // --- DDoS replay --- + eprintln!("\n═══ DDoS Ensemble ═══════════════════════════════════════"); + replay_ddos(&entries, args.window_secs as f64, args.min_events); + + eprintln!("\n══════════════════════════════════════════════════════════"); + Ok(()) +} + +fn replay_scanner(entries: &[AuditLog]) { + let fragment_hashes: FxHashSet = crate::scanner::train::DEFAULT_FRAGMENTS + .iter() + .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) + .collect(); + let extension_hashes: FxHashSet = features::SUSPICIOUS_EXTENSIONS_LIST + .iter() + .map(|e| fx_hash_bytes(e.as_bytes())) + .collect(); + let mut log_hosts: FxHashSet = FxHashSet::default(); + for e in entries { + let prefix = e.fields.host.split('.').next().unwrap_or(""); + log_hosts.insert(fx_hash_bytes(prefix.as_bytes())); + } + + let mut total = 0u64; + let mut blocked = 0u64; + let mut allowed = 0u64; + let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp + let mut blocked_examples: Vec<(String, String, f64)> = Vec::new(); // (path, reason, score) + let mut fp_candidates: Vec<(String, u16, f64)> = Vec::new(); // blocked but had 2xx status + + for e in entries { + let f = &e.fields; + let host_prefix = f.host.split('.').next().unwrap_or(""); + let has_cookies = f.has_cookies.unwrap_or(false); + let has_referer = f + .referer + .as_ref() + .map(|r| r != "-" && !r.is_empty()) + .unwrap_or(false); + let has_accept_language = f + .accept_language + .as_ref() + .map(|a| a != "-" && !a.is_empty()) + .unwrap_or(false); + + let feats = features::extract_features_f32( + &f.method, + &f.path, + host_prefix, + has_cookies, + has_referer, + has_accept_language, + "-", + &f.user_agent, + f.content_length, + &fragment_hashes, + &extension_hashes, + &log_hosts, + ); + + let verdict = scanner_ensemble_predict(&feats); + total += 1; + + match verdict.path { + EnsemblePath::TreeBlock => path_counts[0] += 1, + EnsemblePath::TreeAllow => path_counts[1] += 1, + EnsemblePath::Mlp => path_counts[2] += 1, + } + + match verdict.action { + ScannerAction::Block => { + blocked += 1; + if blocked_examples.len() < 20 { + blocked_examples.push(( + f.path.clone(), + verdict.reason.to_string(), + verdict.score, + )); + } + if (200..400).contains(&f.status) { + fp_candidates.push((f.path.clone(), f.status, verdict.score)); + } + } + ScannerAction::Allow => allowed += 1, + } + } + + let pct = |n: u64| { + if total == 0 { + 0.0 + } else { + n as f64 / total as f64 * 100.0 + } + }; + + eprintln!(" total: {total}"); + eprintln!( + " blocked: {} ({:.1}%)", + blocked, + pct(blocked) + ); + eprintln!( + " allowed: {} ({:.1}%)", + allowed, + pct(allowed) + ); + eprintln!( + " paths: tree_block={} tree_allow={} mlp={}", + path_counts[0], path_counts[1], path_counts[2] + ); + + if !blocked_examples.is_empty() { + eprintln!("\n blocked examples (first 20):"); + for (path, reason, score) in &blocked_examples { + eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50)); + } + } + + let fp_count = fp_candidates.len(); + if fp_count > 0 { + eprintln!( + "\n potential false positives (blocked but had 2xx/3xx): {}", + fp_count + ); + for (path, status, score) in fp_candidates.iter().take(10) { + eprintln!( + " {:<50} status={status} score={score:.3}", + truncate(path, 50) + ); + } + } +} + +fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) { + fn fx_hash(s: &str) -> u64 { + let mut h = rustc_hash::FxHasher::default(); + s.hash(&mut h); + h.finish() + } + + // Aggregate per-IP state. + let mut ip_states: FxHashMap = FxHashMap::default(); + + for e in entries { + let f = &e.fields; + let ip = audit_log::strip_port(&f.client_ip).to_string(); + let state = ip_states.entry(ip).or_default(); + let ts = state.timestamps.len() as f64; + state.timestamps.push(ts); + state.methods.push(method_to_u8(&f.method)); + state.path_hashes.push(fx_hash(&f.path)); + state.host_hashes.push(fx_hash(&f.host)); + state.user_agent_hashes.push(fx_hash(&f.user_agent)); + state.statuses.push(f.status); + state + .durations + .push(f.duration_ms.min(u32::MAX as u64) as u32); + state + .content_lengths + .push(f.content_length.min(u32::MAX as u64) as u32); + state + .has_cookies + .push(f.has_cookies.unwrap_or(false)); + state.has_referer.push( + f.referer + .as_deref() + .map(|r| r != "-") + .unwrap_or(false), + ); + state.has_accept_language.push( + f.accept_language + .as_deref() + .map(|a| a != "-") + .unwrap_or(false), + ); + state + .suspicious_paths + .push(crate::ddos::features::is_suspicious_path(&f.path)); + } + + let mut total_ips = 0u64; + let mut blocked_ips = 0u64; + let mut allowed_ips = 0u64; + let mut skipped_ips = 0u64; + let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp + let mut blocked_details: Vec<(String, usize, f64, &'static str)> = Vec::new(); + + for (ip, state) in &ip_states { + let n = state.timestamps.len(); + if n < min_events { + skipped_ips += 1; + continue; + } + total_ips += 1; + + let fv = state.extract_features_for_window(0, n, window_secs); + let fv_f32: [f32; 14] = { + let mut arr = [0.0f32; 14]; + for i in 0..14 { + arr[i] = fv[i] as f32; + } + arr + }; + + let verdict = ddos_ensemble_predict(&fv_f32); + + match verdict.path { + DDoSEnsemblePath::TreeBlock => path_counts[0] += 1, + DDoSEnsemblePath::TreeAllow => path_counts[1] += 1, + DDoSEnsemblePath::Mlp => path_counts[2] += 1, + } + + match verdict.action { + DDoSAction::Block => { + blocked_ips += 1; + if blocked_details.len() < 30 { + blocked_details.push((ip.clone(), n, verdict.score, verdict.reason)); + } + } + DDoSAction::Allow => allowed_ips += 1, + } + } + + let pct = |n: u64, d: u64| { + if d == 0 { + 0.0 + } else { + n as f64 / d as f64 * 100.0 + } + }; + + eprintln!(" unique IPs: {} ({} skipped < {} events)", ip_states.len(), skipped_ips, min_events); + eprintln!(" evaluated: {total_ips}"); + eprintln!( + " blocked: {} ({:.1}%)", + blocked_ips, + pct(blocked_ips, total_ips) + ); + eprintln!( + " allowed: {} ({:.1}%)", + allowed_ips, + pct(allowed_ips, total_ips) + ); + eprintln!( + " paths: tree_block={} tree_allow={} mlp={}", + path_counts[0], path_counts[1], path_counts[2] + ); + + if !blocked_details.is_empty() { + eprintln!("\n blocked IPs (up to 30):"); + let mut sorted = blocked_details; + sorted.sort_by(|a, b| b.1.cmp(&a.1)); + for (ip, reqs, score, reason) in &sorted { + eprintln!( + " {:<40} {} reqs score={:.3} {reason}", + ip, reqs, score + ); + } + } +} + +fn truncate(s: &str, max: usize) -> String { + if s.len() <= max { + s.to_string() + } else { + format!("{}...", &s[..max - 3]) + } +} diff --git a/src/main.rs b/src/main.rs index 3e10f1d..04df74d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ mod cert; mod telemetry; mod watcher; -use sunbeam_proxy::{acme, config}; +use sunbeam_proxy::{acme, autotune, config}; use sunbeam_proxy::proxy::SunbeamProxy; use sunbeam_proxy::ddos; use sunbeam_proxy::rate_limit; @@ -32,32 +32,10 @@ enum Commands { #[arg(long)] upgrade: bool, }, - /// Replay audit logs through the DDoS detector and rate limiter - ReplayDdos { - /// Path to audit log JSONL file - #[arg(short, long)] - input: String, - /// Path to trained model file - #[arg(short, long, default_value = "ddos_model.bin")] - model: String, - /// Optional config file (for rate limit settings) - #[arg(short, long)] - config: Option, - /// KNN k parameter - #[arg(long, default_value = "5")] - k: usize, - /// Attack threshold - #[arg(long, default_value = "0.6")] - threshold: f64, - /// Sliding window size in seconds - #[arg(long, default_value = "60")] - window_secs: u64, - /// Minimum events per IP before classification - #[arg(long, default_value = "10")] - min_events: usize, - /// Also run rate limiter during replay - #[arg(long)] - rate_limit: bool, + /// Replay audit logs through detection models + Replay { + #[command(subcommand)] + mode: ReplayMode, }, /// Train a DDoS detection model from audit logs TrainDdos { @@ -107,31 +85,180 @@ enum Commands { #[arg(long)] csic: bool, }, + /// Bayesian hyperparameter optimization for DDoS model + AutotuneDdos { + /// Path to audit log JSONL file + #[arg(short, long)] + input: String, + /// Output best model file path + #[arg(short, long, default_value = "ddos_model_best.bin")] + output: String, + /// Number of optimization trials + #[arg(long, default_value = "200")] + trials: usize, + /// F-beta parameter (1.0 = F1, 2.0 = recall-weighted) + #[arg(long, default_value = "1.0")] + beta: f64, + /// JSONL file to log each trial's parameters and results + #[arg(long)] + trial_log: Option, + }, + /// Download and cache upstream datasets (CIC-IDS2017) + DownloadDatasets, + /// Prepare a unified training dataset from multiple sources + PrepareDataset { + /// Path to audit log JSONL file + #[arg(short, long)] + input: String, + /// Path to OWASP ModSecurity audit log file (optional extra data) + #[arg(long)] + owasp: Option, + /// Directory containing .txt wordlists (optional, enhances synthetic scanner) + #[arg(long)] + wordlists: Option, + /// Output dataset file path + #[arg(short, long, default_value = "dataset.bin")] + output: String, + /// Random seed + #[arg(long, default_value = "42")] + seed: u64, + /// Path to heuristics.toml for auto-labeling production logs + #[arg(long)] + heuristics: Option, + }, + #[cfg(feature = "training")] + /// Train scanner ensemble (decision tree + MLP) from prepared dataset + TrainMlpScanner { + /// Path to prepared dataset (.bin) + #[arg(short = 'd', long)] + dataset: String, + /// Output directory for generated weight files + #[arg(short, long, default_value = "src/ensemble/gen")] + output_dir: String, + /// Hidden layer dimension + #[arg(long, default_value = "32")] + hidden_dim: usize, + /// Training epochs + #[arg(long, default_value = "100")] + epochs: usize, + /// Learning rate + #[arg(long, default_value = "0.001")] + learning_rate: f64, + /// Batch size + #[arg(long, default_value = "64")] + batch_size: usize, + /// Max tree depth + #[arg(long, default_value = "6")] + tree_max_depth: usize, + /// Min purity for tree leaves (below -> Defer) + #[arg(long, default_value = "0.90")] + tree_min_purity: f32, + }, + #[cfg(feature = "training")] + /// Train DDoS ensemble (decision tree + MLP) from prepared dataset + TrainMlpDdos { + #[arg(short = 'd', long)] + dataset: String, + #[arg(short, long, default_value = "src/ensemble/gen")] + output_dir: String, + #[arg(long, default_value = "32")] + hidden_dim: usize, + #[arg(long, default_value = "100")] + epochs: usize, + #[arg(long, default_value = "0.001")] + learning_rate: f64, + #[arg(long, default_value = "64")] + batch_size: usize, + #[arg(long, default_value = "6")] + tree_max_depth: usize, + #[arg(long, default_value = "0.90")] + tree_min_purity: f32, + }, + /// Bayesian hyperparameter optimization for scanner model + AutotuneScanner { + /// Path to audit log JSONL file + #[arg(short, long)] + input: String, + /// Output best model file path + #[arg(short, long, default_value = "scanner_model_best.bin")] + output: String, + /// Directory (or file) containing .txt wordlists of scanner paths + #[arg(long)] + wordlists: Option, + /// Include CSIC 2010 dataset as base training data + #[arg(long)] + csic: bool, + /// Number of optimization trials + #[arg(long, default_value = "200")] + trials: usize, + /// F-beta parameter (1.0 = F1, 2.0 = recall-weighted) + #[arg(long, default_value = "1.0")] + beta: f64, + /// JSONL file to log each trial's parameters and results + #[arg(long)] + trial_log: Option, + }, +} + +#[derive(Subcommand)] +enum ReplayMode { + /// Replay through ensemble models (scanner + DDoS) + Ensemble { + /// Path to audit log JSONL file + #[arg(short, long)] + input: String, + /// Sliding window size in seconds + #[arg(long, default_value = "60")] + window_secs: u64, + /// Minimum events per IP before DDoS classification + #[arg(long, default_value = "5")] + min_events: usize, + }, + /// Replay through legacy KNN DDoS detector + Ddos { + /// Path to audit log JSONL file + #[arg(short, long)] + input: String, + /// Path to trained model file + #[arg(short, long, default_value = "ddos_model.bin")] + model: String, + /// Optional config file (for rate limit settings) + #[arg(short, long)] + config: Option, + /// KNN k parameter + #[arg(long, default_value = "5")] + k: usize, + /// Attack threshold + #[arg(long, default_value = "0.6")] + threshold: f64, + /// Sliding window size in seconds + #[arg(long, default_value = "60")] + window_secs: u64, + /// Minimum events per IP before classification + #[arg(long, default_value = "10")] + min_events: usize, + /// Also run rate limiter during replay + #[arg(long)] + rate_limit: bool, + }, } fn main() -> Result<()> { let cli = Cli::parse(); match cli.command.unwrap_or(Commands::Serve { upgrade: false }) { Commands::Serve { upgrade } => run_serve(upgrade), - Commands::ReplayDdos { - input, - model, - config, - k, - threshold, - window_secs, - min_events, - rate_limit, - } => ddos::replay::run(ddos::replay::ReplayArgs { - input, - model_path: model, - config_path: config, - k, - threshold, - window_secs, - min_events, - rate_limit, - }), + Commands::Replay { mode } => match mode { + ReplayMode::Ensemble { input, window_secs, min_events } => { + sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs { + input, window_secs, min_events, + }) + } + ReplayMode::Ddos { input, model, config, k, threshold, window_secs, min_events, rate_limit } => { + ddos::replay::run(ddos::replay::ReplayArgs { + input, model_path: model, config_path: config, k, threshold, window_secs, min_events, rate_limit, + }) + } + }, Commands::TrainDdos { input, output, @@ -166,6 +293,56 @@ fn main() -> Result<()> { threshold, csic, }), + Commands::DownloadDatasets => { + sunbeam_proxy::dataset::download::download_all() + }, + Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics } => { + sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs { + input, owasp, wordlists, output, seed, heuristics, + }) + }, + #[cfg(feature = "training")] + Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => { + sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs { + dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, + }) + }, + #[cfg(feature = "training")] + Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => { + sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs { + dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, + }) + }, + Commands::AutotuneDdos { + input, + output, + trials, + beta, + trial_log, + } => autotune::ddos::run_autotune(autotune::ddos::AutotuneDdosArgs { + input, + output, + trials, + beta, + trial_log, + }), + Commands::AutotuneScanner { + input, + output, + wordlists, + csic, + trials, + beta, + trial_log, + } => autotune::scanner::run_autotune(autotune::scanner::AutotuneScannerArgs { + input, + output, + wordlists, + csic, + trials, + beta, + trial_log, + }), } } @@ -189,26 +366,42 @@ fn run_serve(upgrade: bool) -> Result<()> { // 2. Load DDoS detection model if configured. let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos { if ddos_cfg.enabled { - match ddos::model::TrainedModel::load( - std::path::Path::new(&ddos_cfg.model_path), - Some(ddos_cfg.k), - Some(ddos_cfg.threshold), - ) { - Ok(model) => { - let point_count = model.point_count(); - let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg)); - tracing::info!( - points = point_count, - k = ddos_cfg.k, - threshold = ddos_cfg.threshold, - "DDoS detector loaded" - ); - Some(detector) - } - Err(e) => { - tracing::warn!(error = %e, "failed to load DDoS model; detection disabled"); - None + if ddos_cfg.use_ensemble { + // Ensemble path: compiled-in weights, no model file needed. + // We still need a TrainedModel for the struct, but it won't be used. + let dummy_model = ddos::model::TrainedModel::empty(ddos_cfg.k, ddos_cfg.threshold); + let detector = Arc::new(ddos::detector::DDoSDetector::new_ensemble(dummy_model, ddos_cfg)); + tracing::info!( + k = ddos_cfg.k, + threshold = ddos_cfg.threshold, + "DDoS ensemble detector enabled" + ); + Some(detector) + } else if let Some(ref model_path) = ddos_cfg.model_path { + match ddos::model::TrainedModel::load( + std::path::Path::new(model_path), + Some(ddos_cfg.k), + Some(ddos_cfg.threshold), + ) { + Ok(model) => { + let point_count = model.point_count(); + let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg)); + tracing::info!( + points = point_count, + k = ddos_cfg.k, + threshold = ddos_cfg.threshold, + "DDoS detector loaded" + ); + Some(detector) + } + Err(e) => { + tracing::warn!(error = %e, "failed to load DDoS model; detection disabled"); + None + } } + } else { + tracing::warn!("DDoS enabled but no model_path and use_ensemble=false; detection disabled"); + None } } else { None @@ -245,54 +438,84 @@ fn run_serve(upgrade: bool) -> Result<()> { // 2c. Load scanner model if configured. let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner { if scanner_cfg.enabled { - match scanner::model::ScannerModel::load(std::path::Path::new(&scanner_cfg.model_path)) { - Ok(mut model) => { - let fragment_count = model.fragments.len(); - model.threshold = scanner_cfg.threshold; - let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes); - let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector)); + if scanner_cfg.use_ensemble { + // Ensemble path: compiled-in weights, no model file needed. + let detector = scanner::detector::ScannerDetector::new_ensemble(&cfg.routes); + let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector)); - // Start bot allowlist if rules are configured. - let bot_allowlist = if !scanner_cfg.allowlist.is_empty() { - let al = scanner::allowlist::BotAllowlist::spawn( - &scanner_cfg.allowlist, - scanner_cfg.bot_cache_ttl_secs, - ); - tracing::info!( - rules = scanner_cfg.allowlist.len(), - "bot allowlist enabled" - ); - Some(al) - } else { - None - }; - - // Start background file watcher for hot-reload. - if scanner_cfg.poll_interval_secs > 0 { - let watcher_handle = handle.clone(); - let model_path = std::path::PathBuf::from(&scanner_cfg.model_path); - let threshold = scanner_cfg.threshold; - let routes = cfg.routes.clone(); - let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs); - std::thread::spawn(move || { - scanner::watcher::watch_scanner_model( - watcher_handle, model_path, threshold, routes, interval, - ); - }); - } - - tracing::info!( - fragments = fragment_count, - threshold = scanner_cfg.threshold, - poll_interval_secs = scanner_cfg.poll_interval_secs, - "scanner detector loaded" + // Start bot allowlist if rules are configured. + let bot_allowlist = if !scanner_cfg.allowlist.is_empty() { + let al = scanner::allowlist::BotAllowlist::spawn( + &scanner_cfg.allowlist, + scanner_cfg.bot_cache_ttl_secs, ); - (Some(handle), bot_allowlist) - } - Err(e) => { - tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled"); - (None, None) + tracing::info!( + rules = scanner_cfg.allowlist.len(), + "bot allowlist enabled" + ); + Some(al) + } else { + None + }; + + tracing::info!( + threshold = scanner_cfg.threshold, + "scanner ensemble detector enabled" + ); + (Some(handle), bot_allowlist) + } else if let Some(ref model_path) = scanner_cfg.model_path { + match scanner::model::ScannerModel::load(std::path::Path::new(model_path)) { + Ok(mut model) => { + let fragment_count = model.fragments.len(); + model.threshold = scanner_cfg.threshold; + let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes); + let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector)); + + // Start bot allowlist if rules are configured. + let bot_allowlist = if !scanner_cfg.allowlist.is_empty() { + let al = scanner::allowlist::BotAllowlist::spawn( + &scanner_cfg.allowlist, + scanner_cfg.bot_cache_ttl_secs, + ); + tracing::info!( + rules = scanner_cfg.allowlist.len(), + "bot allowlist enabled" + ); + Some(al) + } else { + None + }; + + // Start background file watcher for hot-reload. + if scanner_cfg.poll_interval_secs > 0 { + let watcher_handle = handle.clone(); + let watcher_model_path = std::path::PathBuf::from(model_path); + let threshold = scanner_cfg.threshold; + let routes = cfg.routes.clone(); + let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs); + std::thread::spawn(move || { + scanner::watcher::watch_scanner_model( + watcher_handle, watcher_model_path, threshold, routes, interval, + ); + }); + } + + tracing::info!( + fragments = fragment_count, + threshold = scanner_cfg.threshold, + poll_interval_secs = scanner_cfg.poll_interval_secs, + "scanner detector loaded" + ); + (Some(handle), bot_allowlist) + } + Err(e) => { + tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled"); + (None, None) + } } + } else { + tracing::warn!("scanner enabled but no model_path and use_ensemble=false; scanner detection disabled"); + (None, None) } } else { (None, None)