feat(cli): restructure replay as subcommand with ensemble and ddos modes

Add `replay ensemble` (runs logs through compiled-in tree+MLP for both
scanner and DDoS) and `replay ddos` (legacy KNN). Also adds CLI commands
for download-datasets, prepare-dataset, train-mlp-scanner, train-mlp-ddos,
autotune-ddos, and autotune-scanner.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:21 +00:00
parent 565ea4cde4
commit 905fd78299
3 changed files with 753 additions and 224 deletions

View File

@@ -21,38 +21,39 @@ pub struct ReplayArgs {
pub rate_limit: bool,
}
struct ReplayStats {
total: u64,
skipped: u64,
ddos_blocked: u64,
rate_limited: u64,
allowed: u64,
ddos_blocked_ips: FxHashMap<String, u64>,
rate_limited_ips: FxHashMap<String, u64>,
pub struct ReplayResult {
pub total: u64,
pub skipped: u64,
pub ddos_blocked: u64,
pub rate_limited: u64,
pub allowed: u64,
pub ddos_blocked_ips: FxHashMap<String, u64>,
pub rate_limited_ips: FxHashMap<String, u64>,
pub false_positive_ips: usize,
pub true_positive_ips: usize,
}
pub fn run(args: ReplayArgs) -> Result<()> {
eprintln!("Loading model from {}...", args.model_path);
/// Core replay pipeline: load model, replay logs, compute stats including false positive analysis.
pub fn replay_and_evaluate(args: &ReplayArgs) -> Result<ReplayResult> {
let model = TrainedModel::load(
std::path::Path::new(&args.model_path),
Some(args.k),
Some(args.threshold),
)
.with_context(|| format!("loading model from {}", args.model_path))?;
eprintln!(" {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold);
let ddos_cfg = DDoSConfig {
model_path: args.model_path.clone(),
model_path: Some(args.model_path.clone()),
k: args.k,
threshold: args.threshold,
window_secs: args.window_secs,
window_capacity: 1000,
min_events: args.min_events,
enabled: true,
use_ensemble: false,
};
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
// Optionally set up rate limiter
let rate_limiter = if args.rate_limit {
let rl_cfg = if let Some(cfg_path) = &args.config_path {
let cfg = crate::config::Config::load(cfg_path)?;
@@ -60,61 +61,49 @@ pub fn run(args: ReplayArgs) -> Result<()> {
} else {
default_rate_limit_config()
};
eprintln!(
" Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s",
rl_cfg.authenticated.burst,
rl_cfg.authenticated.rate,
rl_cfg.unauthenticated.burst,
rl_cfg.unauthenticated.rate,
);
Some(RateLimiter::new(&rl_cfg))
} else {
None
};
eprintln!("Replaying {}...\n", args.input);
let file = std::fs::File::open(&args.input)
.with_context(|| format!("opening {}", args.input))?;
let reader = std::io::BufReader::new(file);
let mut stats = ReplayStats {
total: 0,
skipped: 0,
ddos_blocked: 0,
rate_limited: 0,
allowed: 0,
ddos_blocked_ips: FxHashMap::default(),
rate_limited_ips: FxHashMap::default(),
};
let mut total = 0u64;
let mut skipped = 0u64;
let mut ddos_blocked = 0u64;
let mut rate_limited = 0u64;
let mut allowed = 0u64;
let mut ddos_blocked_ips: FxHashMap<String, u64> = FxHashMap::default();
let mut rate_limited_ips: FxHashMap<String, u64> = FxHashMap::default();
for line in reader.lines() {
let line = line?;
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => {
stats.skipped += 1;
skipped += 1;
continue;
}
};
if entry.fields.method.is_empty() {
stats.skipped += 1;
skipped += 1;
continue;
}
stats.total += 1;
total += 1;
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
let ip: IpAddr = match ip_str.parse() {
Ok(ip) => ip,
Err(_) => {
stats.skipped += 1;
skipped += 1;
continue;
}
};
// DDoS check
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
@@ -131,78 +120,57 @@ pub fn run(args: ReplayArgs) -> Result<()> {
);
if ddos_action == DDoSAction::Block {
stats.ddos_blocked += 1;
*stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
ddos_blocked += 1;
*ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
continue;
}
// Rate limit check
if let Some(limiter) = &rate_limiter {
// Audit logs don't have auth headers, so all traffic is keyed by IP
let rl_key = RateLimitKey::Ip(ip);
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
stats.rate_limited += 1;
*stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
rate_limited += 1;
*rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
continue;
}
}
stats.allowed += 1;
allowed += 1;
}
// Report
let total = stats.total;
eprintln!("═══ Replay Results ═══════════════════════════════════════");
eprintln!(" Total requests: {total}");
eprintln!(" Skipped (parse): {}", stats.skipped);
eprintln!(" Allowed: {} ({:.1}%)", stats.allowed, pct(stats.allowed, total));
eprintln!(" DDoS blocked: {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total));
if rate_limiter.is_some() {
eprintln!(" Rate limited: {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total));
// Compute false positive / true positive counts
let (false_positive_ips, true_positive_ips) =
count_fp_tp(&args.input, &ddos_blocked_ips, &rate_limited_ips)?;
Ok(ReplayResult {
total,
skipped,
ddos_blocked,
rate_limited,
allowed,
ddos_blocked_ips,
rate_limited_ips,
false_positive_ips,
true_positive_ips,
})
}
if !stats.ddos_blocked_ips.is_empty() {
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs blocked", ip, count);
}
}
if !stats.rate_limited_ips.is_empty() {
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs limited", ip, count);
}
}
// Check for false positives: IPs that were blocked but had 2xx statuses in the original logs
eprintln!("\n── False positive check ──────────────────────────────────");
check_false_positives(&args.input, &stats)?;
eprintln!("══════════════════════════════════════════════════════════");
Ok(())
}
/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally
/// (i.e. they were legitimate traffic that the model would incorrectly block).
fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
let blocked_ips: rustc_hash::FxHashSet<&str> = stats
.ddos_blocked_ips
/// Count false-positive and true-positive IPs from blocked set.
/// FP = blocked IP where >60% of original responses were 2xx/3xx.
fn count_fp_tp(
input: &str,
ddos_blocked_ips: &FxHashMap<String, u64>,
rate_limited_ips: &FxHashMap<String, u64>,
) -> Result<(usize, usize)> {
let blocked_ips: rustc_hash::FxHashSet<&str> = ddos_blocked_ips
.keys()
.chain(stats.rate_limited_ips.keys())
.chain(rate_limited_ips.keys())
.map(|s| s.as_str())
.collect();
if blocked_ips.is_empty() {
eprintln!(" No blocked IPs — nothing to check.");
return Ok(());
return Ok((0, 0));
}
// Collect original status codes for blocked IPs
let file = std::fs::File::open(input)?;
let reader = std::io::BufReader::new(file);
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
@@ -215,47 +183,69 @@ fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
};
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
if blocked_ips.contains(ip_str.as_str()) {
ip_statuses
.entry(ip_str)
.or_default()
.push(entry.fields.status);
ip_statuses.entry(ip_str).or_default().push(entry.fields.status);
}
}
let mut suspects = Vec::new();
for (ip, statuses) in &ip_statuses {
let mut fp = 0usize;
let mut tp = 0usize;
for (_ip, statuses) in &ip_statuses {
let total = statuses.len();
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
let ok_pct = (ok_count as f64 / total as f64) * 100.0;
// If >60% of original responses were 2xx/3xx, this might be a false positive
let ok_pct = ok_count as f64 / total as f64 * 100.0;
if ok_pct > 60.0 {
let blocked = stats
.ddos_blocked_ips
.get(ip)
.copied()
.unwrap_or(0)
+ stats
.rate_limited_ips
.get(ip)
.copied()
.unwrap_or(0);
suspects.push((ip.clone(), total, ok_pct, blocked));
fp += 1;
} else {
tp += 1;
}
}
if suspects.is_empty() {
Ok((fp, tp))
}
pub fn run(args: ReplayArgs) -> Result<()> {
eprintln!("Loading model from {}...", args.model_path);
eprintln!("Replaying {}...\n", args.input);
let result = replay_and_evaluate(&args)?;
let total = result.total;
eprintln!("═══ Replay Results ═══════════════════════════════════════");
eprintln!(" Total requests: {total}");
eprintln!(" Skipped (parse): {}", result.skipped);
eprintln!(" Allowed: {} ({:.1}%)", result.allowed, pct(result.allowed, total));
eprintln!(" DDoS blocked: {} ({:.1}%)", result.ddos_blocked, pct(result.ddos_blocked, total));
if args.rate_limit {
eprintln!(" Rate limited: {} ({:.1}%)", result.rate_limited, pct(result.rate_limited, total));
}
if !result.ddos_blocked_ips.is_empty() {
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = result.ddos_blocked_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs blocked", ip, count);
}
}
if !result.rate_limited_ips.is_empty() {
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = result.rate_limited_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs limited", ip, count);
}
}
eprintln!("\n── False positive check ──────────────────────────────────");
if result.false_positive_ips == 0 {
eprintln!(" No likely false positives found.");
} else {
suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
eprintln!("{} IPs were blocked but had mostly successful responses:", suspects.len());
for (ip, total, ok_pct, blocked) in suspects.iter().take(15) {
eprintln!(
" {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked",
ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked,
);
}
eprintln!("{} IPs were blocked but had mostly successful responses", result.false_positive_ips);
}
eprintln!(" True positive IPs: {}", result.true_positive_ips);
eprintln!("══════════════════════════════════════════════════════════");
Ok(())
}

316
src/ensemble/replay.rs Normal file
View File

@@ -0,0 +1,316 @@
//! Replay audit logs through the ensemble models (scanner + DDoS).
use crate::ddos::audit_log::{self, AuditLog};
use crate::ddos::features::{method_to_u8, LogIpState};
use crate::ddos::model::DDoSAction;
use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath};
use crate::ensemble::scanner::{scanner_ensemble_predict, EnsemblePath};
use crate::scanner::features::{self, fx_hash_bytes};
use crate::scanner::model::ScannerAction;
use anyhow::{Context, Result};
use rustc_hash::{FxHashMap, FxHashSet};
use std::hash::{Hash, Hasher};
use std::io::BufRead;
/// Arguments for the `replay ensemble` CLI subcommand.
pub struct ReplayEnsembleArgs {
    /// Path to the audit log JSONL file to replay.
    pub input: String,
    /// Sliding window size in seconds used for DDoS feature extraction.
    pub window_secs: u64,
    /// Minimum events an IP must have before the DDoS model classifies it.
    pub min_events: usize,
}
/// Entry point for `replay ensemble`: parse the audit log once, then run the
/// scanner and DDoS ensemble replays over the parsed entries.
///
/// Reads `args.input` as JSONL; malformed lines are counted as parse errors
/// and blank lines are ignored. All output goes to stderr.
pub fn run(args: ReplayEnsembleArgs) -> Result<()> {
    eprintln!("replaying {} through ensemble models...\n", args.input);
    let log_file =
        std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?;

    // --- Parse all entries ---
    let mut entries: Vec<AuditLog> = Vec::new();
    let mut parse_errors = 0u64;
    for line in std::io::BufReader::new(log_file).lines() {
        let raw = line?;
        if raw.trim().is_empty() {
            // Blank lines are neither entries nor errors.
            continue;
        }
        if let Ok(entry) = serde_json::from_str::<AuditLog>(&raw) {
            entries.push(entry);
        } else {
            parse_errors += 1;
        }
    }
    eprintln!(
        "parsed {} entries ({} parse errors)\n",
        entries.len() as u64,
        parse_errors
    );

    // --- Scanner replay ---
    eprintln!("═══ Scanner Ensemble ═════════════════════════════════════");
    replay_scanner(&entries);

    // --- DDoS replay ---
    eprintln!("\n═══ DDoS Ensemble ═══════════════════════════════════════");
    replay_ddos(&entries, args.window_secs as f64, args.min_events);

    eprintln!("\n══════════════════════════════════════════════════════════");
    Ok(())
}
/// Replay every parsed entry through the compiled-in scanner ensemble
/// (decision-tree fast path + MLP fallback) and print a summary report to
/// stderr: block/allow counts, which ensemble path decided each request,
/// example blocked paths, and likely false positives (blocked requests whose
/// original response was 2xx/3xx).
fn replay_scanner(entries: &[AuditLog]) {
    // Precompute the hash sets used by feature extraction once, not per entry.
    let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
        .iter()
        .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
        .collect();
    let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
        .iter()
        .map(|e| fx_hash_bytes(e.as_bytes()))
        .collect();
    // Hashes of every first host label (text before the first '.') seen in the
    // log, so the feature extractor can distinguish hosts present in this log.
    let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
    for e in entries {
        let prefix = e.fields.host.split('.').next().unwrap_or("");
        log_hosts.insert(fx_hash_bytes(prefix.as_bytes()));
    }
    let mut total = 0u64;
    let mut blocked = 0u64;
    let mut allowed = 0u64;
    let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
    let mut blocked_examples: Vec<(String, String, f64)> = Vec::new(); // (path, reason, score)
    let mut fp_candidates: Vec<(String, u16, f64)> = Vec::new(); // blocked but had 2xx status
    for e in entries {
        let f = &e.fields;
        let host_prefix = f.host.split('.').next().unwrap_or("");
        // "-" is the audit-log placeholder for an absent header, so treat it
        // the same as missing.
        let has_cookies = f.has_cookies.unwrap_or(false);
        let has_referer = f
            .referer
            .as_ref()
            .map(|r| r != "-" && !r.is_empty())
            .unwrap_or(false);
        let has_accept_language = f
            .accept_language
            .as_ref()
            .map(|a| a != "-" && !a.is_empty())
            .unwrap_or(false);
        let feats = features::extract_features_f32(
            &f.method,
            &f.path,
            host_prefix,
            has_cookies,
            has_referer,
            has_accept_language,
            "-",
            &f.user_agent,
            f.content_length,
            &fragment_hashes,
            &extension_hashes,
            &log_hosts,
        );
        let verdict = scanner_ensemble_predict(&feats);
        total += 1;
        match verdict.path {
            EnsemblePath::TreeBlock => path_counts[0] += 1,
            EnsemblePath::TreeAllow => path_counts[1] += 1,
            EnsemblePath::Mlp => path_counts[2] += 1,
        }
        match verdict.action {
            ScannerAction::Block => {
                blocked += 1;
                // Keep only the first 20 examples to bound the report size.
                if blocked_examples.len() < 20 {
                    blocked_examples.push((
                        f.path.clone(),
                        verdict.reason.to_string(),
                        verdict.score,
                    ));
                }
                // A blocked request that originally got a 2xx/3xx response is
                // a likely false positive.
                if (200..400).contains(&f.status) {
                    fp_candidates.push((f.path.clone(), f.status, verdict.score));
                }
            }
            ScannerAction::Allow => allowed += 1,
        }
    }
    // Percentage helper guarding against divide-by-zero on an empty log.
    let pct = |n: u64| {
        if total == 0 {
            0.0
        } else {
            n as f64 / total as f64 * 100.0
        }
    };
    eprintln!(" total: {total}");
    eprintln!(
        " blocked: {} ({:.1}%)",
        blocked,
        pct(blocked)
    );
    eprintln!(
        " allowed: {} ({:.1}%)",
        allowed,
        pct(allowed)
    );
    eprintln!(
        " paths: tree_block={} tree_allow={} mlp={}",
        path_counts[0], path_counts[1], path_counts[2]
    );
    if !blocked_examples.is_empty() {
        eprintln!("\n blocked examples (first 20):");
        for (path, reason, score) in &blocked_examples {
            eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50));
        }
    }
    let fp_count = fp_candidates.len();
    if fp_count > 0 {
        eprintln!(
            "\n potential false positives (blocked but had 2xx/3xx): {}",
            fp_count
        );
        // Show at most 10 false-positive candidates.
        for (path, status, score) in fp_candidates.iter().take(10) {
            eprintln!(
                " {:<50} status={status} score={score:.3}",
                truncate(path, 50)
            );
        }
    }
}
/// Replay the log through the compiled-in DDoS ensemble.
///
/// Unlike the scanner replay (which is per request), this aggregates all of
/// an IP's requests into a `LogIpState`, extracts a single feature vector
/// over the IP's whole history, and classifies each IP once. IPs with fewer
/// than `min_events` requests are skipped. Report goes to stderr.
fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) {
    // Local FxHasher helper for hashing strings into numeric feature inputs.
    fn fx_hash(s: &str) -> u64 {
        let mut h = rustc_hash::FxHasher::default();
        s.hash(&mut h);
        h.finish()
    }
    // Aggregate per-IP state.
    let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
    for e in entries {
        let f = &e.fields;
        let ip = audit_log::strip_port(&f.client_ip).to_string();
        let state = ip_states.entry(ip).or_default();
        // NOTE(review): timestamps are synthetic (the event index), so timing
        // features see uniform spacing rather than real arrival times —
        // confirm this matches how the ensemble was trained.
        let ts = state.timestamps.len() as f64;
        state.timestamps.push(ts);
        state.methods.push(method_to_u8(&f.method));
        state.path_hashes.push(fx_hash(&f.path));
        state.host_hashes.push(fx_hash(&f.host));
        state.user_agent_hashes.push(fx_hash(&f.user_agent));
        state.statuses.push(f.status);
        // Clamp 64-bit log values into the u32 fields the state stores.
        state
            .durations
            .push(f.duration_ms.min(u32::MAX as u64) as u32);
        state
            .content_lengths
            .push(f.content_length.min(u32::MAX as u64) as u32);
        state
            .has_cookies
            .push(f.has_cookies.unwrap_or(false));
        // "-" is the audit-log placeholder for an absent header.
        state.has_referer.push(
            f.referer
                .as_deref()
                .map(|r| r != "-")
                .unwrap_or(false),
        );
        state.has_accept_language.push(
            f.accept_language
                .as_deref()
                .map(|a| a != "-")
                .unwrap_or(false),
        );
        state
            .suspicious_paths
            .push(crate::ddos::features::is_suspicious_path(&f.path));
    }
    let mut total_ips = 0u64;
    let mut blocked_ips = 0u64;
    let mut allowed_ips = 0u64;
    let mut skipped_ips = 0u64;
    let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
    let mut blocked_details: Vec<(String, usize, f64, &'static str)> = Vec::new();
    for (ip, state) in &ip_states {
        let n = state.timestamps.len();
        // Too few events to classify reliably; count as skipped, not allowed.
        if n < min_events {
            skipped_ips += 1;
            continue;
        }
        total_ips += 1;
        // One feature vector over the IP's full history (window [0, n)).
        let fv = state.extract_features_for_window(0, n, window_secs);
        // The ensemble expects f32; narrow the f64 feature vector.
        let fv_f32: [f32; 14] = {
            let mut arr = [0.0f32; 14];
            for i in 0..14 {
                arr[i] = fv[i] as f32;
            }
            arr
        };
        let verdict = ddos_ensemble_predict(&fv_f32);
        match verdict.path {
            DDoSEnsemblePath::TreeBlock => path_counts[0] += 1,
            DDoSEnsemblePath::TreeAllow => path_counts[1] += 1,
            DDoSEnsemblePath::Mlp => path_counts[2] += 1,
        }
        match verdict.action {
            DDoSAction::Block => {
                blocked_ips += 1;
                // Keep at most 30 detail rows to bound the report size.
                if blocked_details.len() < 30 {
                    blocked_details.push((ip.clone(), n, verdict.score, verdict.reason));
                }
            }
            DDoSAction::Allow => allowed_ips += 1,
        }
    }
    // Percentage helper guarding against divide-by-zero.
    let pct = |n: u64, d: u64| {
        if d == 0 {
            0.0
        } else {
            n as f64 / d as f64 * 100.0
        }
    };
    eprintln!(" unique IPs: {} ({} skipped < {} events)", ip_states.len(), skipped_ips, min_events);
    eprintln!(" evaluated: {total_ips}");
    eprintln!(
        " blocked: {} ({:.1}%)",
        blocked_ips,
        pct(blocked_ips, total_ips)
    );
    eprintln!(
        " allowed: {} ({:.1}%)",
        allowed_ips,
        pct(allowed_ips, total_ips)
    );
    eprintln!(
        " paths: tree_block={} tree_allow={} mlp={}",
        path_counts[0], path_counts[1], path_counts[2]
    );
    if !blocked_details.is_empty() {
        eprintln!("\n blocked IPs (up to 30):");
        // Sort by request count, highest first.
        let mut sorted = blocked_details;
        sorted.sort_by(|a, b| b.1.cmp(&a.1));
        for (ip, reqs, score, reason) in &sorted {
            eprintln!(
                " {:<40} {} reqs score={:.3} {reason}",
                ip, reqs, score
            );
        }
    }
}
/// Truncate `s` to at most `max` bytes for display, appending "..." when
/// shortened.
///
/// Fixes two panics in the naive `&s[..max - 3]`:
/// - slicing a multi-byte UTF-8 string (e.g. a request path containing
///   non-ASCII bytes) on a non-character boundary panics, so we walk the cut
///   point back to the nearest `char` boundary;
/// - `max < 3` would underflow, so the subtraction saturates to 0.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Reserve 3 bytes for the ellipsis, then back up to a char boundary so
    // the slice below cannot panic.
    let mut end = max.saturating_sub(3).min(s.len());
    while end > 0 && !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}

View File

@@ -2,7 +2,7 @@ mod cert;
mod telemetry;
mod watcher;
use sunbeam_proxy::{acme, config};
use sunbeam_proxy::{acme, autotune, config};
use sunbeam_proxy::proxy::SunbeamProxy;
use sunbeam_proxy::ddos;
use sunbeam_proxy::rate_limit;
@@ -32,32 +32,10 @@ enum Commands {
#[arg(long)]
upgrade: bool,
},
/// Replay audit logs through the DDoS detector and rate limiter
ReplayDdos {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Path to trained model file
#[arg(short, long, default_value = "ddos_model.bin")]
model: String,
/// Optional config file (for rate limit settings)
#[arg(short, long)]
config: Option<String>,
/// KNN k parameter
#[arg(long, default_value = "5")]
k: usize,
/// Attack threshold
#[arg(long, default_value = "0.6")]
threshold: f64,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP before classification
#[arg(long, default_value = "10")]
min_events: usize,
/// Also run rate limiter during replay
#[arg(long)]
rate_limit: bool,
/// Replay audit logs through detection models
Replay {
#[command(subcommand)]
mode: ReplayMode,
},
/// Train a DDoS detection model from audit logs
TrainDdos {
@@ -107,31 +85,180 @@ enum Commands {
#[arg(long)]
csic: bool,
},
/// Bayesian hyperparameter optimization for DDoS model
AutotuneDdos {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output best model file path
#[arg(short, long, default_value = "ddos_model_best.bin")]
output: String,
/// Number of optimization trials
#[arg(long, default_value = "200")]
trials: usize,
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
#[arg(long, default_value = "1.0")]
beta: f64,
/// JSONL file to log each trial's parameters and results
#[arg(long)]
trial_log: Option<String>,
},
/// Download and cache upstream datasets (CIC-IDS2017)
DownloadDatasets,
/// Prepare a unified training dataset from multiple sources
PrepareDataset {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Path to OWASP ModSecurity audit log file (optional extra data)
#[arg(long)]
owasp: Option<String>,
/// Directory containing .txt wordlists (optional, enhances synthetic scanner)
#[arg(long)]
wordlists: Option<String>,
/// Output dataset file path
#[arg(short, long, default_value = "dataset.bin")]
output: String,
/// Random seed
#[arg(long, default_value = "42")]
seed: u64,
/// Path to heuristics.toml for auto-labeling production logs
#[arg(long)]
heuristics: Option<String>,
},
#[cfg(feature = "training")]
/// Train scanner ensemble (decision tree + MLP) from prepared dataset
TrainMlpScanner {
/// Path to prepared dataset (.bin)
#[arg(short = 'd', long)]
dataset: String,
/// Output directory for generated weight files
#[arg(short, long, default_value = "src/ensemble/gen")]
output_dir: String,
/// Hidden layer dimension
#[arg(long, default_value = "32")]
hidden_dim: usize,
/// Training epochs
#[arg(long, default_value = "100")]
epochs: usize,
/// Learning rate
#[arg(long, default_value = "0.001")]
learning_rate: f64,
/// Batch size
#[arg(long, default_value = "64")]
batch_size: usize,
/// Max tree depth
#[arg(long, default_value = "6")]
tree_max_depth: usize,
/// Min purity for tree leaves (below -> Defer)
#[arg(long, default_value = "0.90")]
tree_min_purity: f32,
},
#[cfg(feature = "training")]
/// Train DDoS ensemble (decision tree + MLP) from prepared dataset
TrainMlpDdos {
#[arg(short = 'd', long)]
dataset: String,
#[arg(short, long, default_value = "src/ensemble/gen")]
output_dir: String,
#[arg(long, default_value = "32")]
hidden_dim: usize,
#[arg(long, default_value = "100")]
epochs: usize,
#[arg(long, default_value = "0.001")]
learning_rate: f64,
#[arg(long, default_value = "64")]
batch_size: usize,
#[arg(long, default_value = "6")]
tree_max_depth: usize,
#[arg(long, default_value = "0.90")]
tree_min_purity: f32,
},
/// Bayesian hyperparameter optimization for scanner model
AutotuneScanner {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output best model file path
#[arg(short, long, default_value = "scanner_model_best.bin")]
output: String,
/// Directory (or file) containing .txt wordlists of scanner paths
#[arg(long)]
wordlists: Option<String>,
/// Include CSIC 2010 dataset as base training data
#[arg(long)]
csic: bool,
/// Number of optimization trials
#[arg(long, default_value = "200")]
trials: usize,
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
#[arg(long, default_value = "1.0")]
beta: f64,
/// JSONL file to log each trial's parameters and results
#[arg(long)]
trial_log: Option<String>,
},
}
#[derive(Subcommand)]
enum ReplayMode {
/// Replay through ensemble models (scanner + DDoS)
Ensemble {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP before DDoS classification
#[arg(long, default_value = "5")]
min_events: usize,
},
/// Replay through legacy KNN DDoS detector
Ddos {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Path to trained model file
#[arg(short, long, default_value = "ddos_model.bin")]
model: String,
/// Optional config file (for rate limit settings)
#[arg(short, long)]
config: Option<String>,
/// KNN k parameter
#[arg(long, default_value = "5")]
k: usize,
/// Attack threshold
#[arg(long, default_value = "0.6")]
threshold: f64,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP before classification
#[arg(long, default_value = "10")]
min_events: usize,
/// Also run rate limiter during replay
#[arg(long)]
rate_limit: bool,
},
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
Commands::Serve { upgrade } => run_serve(upgrade),
Commands::ReplayDdos {
input,
model,
config,
k,
threshold,
window_secs,
min_events,
rate_limit,
} => ddos::replay::run(ddos::replay::ReplayArgs {
input,
model_path: model,
config_path: config,
k,
threshold,
window_secs,
min_events,
rate_limit,
}),
Commands::Replay { mode } => match mode {
ReplayMode::Ensemble { input, window_secs, min_events } => {
sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs {
input, window_secs, min_events,
})
}
ReplayMode::Ddos { input, model, config, k, threshold, window_secs, min_events, rate_limit } => {
ddos::replay::run(ddos::replay::ReplayArgs {
input, model_path: model, config_path: config, k, threshold, window_secs, min_events, rate_limit,
})
}
},
Commands::TrainDdos {
input,
output,
@@ -166,6 +293,56 @@ fn main() -> Result<()> {
threshold,
csic,
}),
Commands::DownloadDatasets => {
sunbeam_proxy::dataset::download::download_all()
},
Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics } => {
sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs {
input, owasp, wordlists, output, seed, heuristics,
})
},
#[cfg(feature = "training")]
Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs {
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
})
},
#[cfg(feature = "training")]
Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs {
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
})
},
Commands::AutotuneDdos {
input,
output,
trials,
beta,
trial_log,
} => autotune::ddos::run_autotune(autotune::ddos::AutotuneDdosArgs {
input,
output,
trials,
beta,
trial_log,
}),
Commands::AutotuneScanner {
input,
output,
wordlists,
csic,
trials,
beta,
trial_log,
} => autotune::scanner::run_autotune(autotune::scanner::AutotuneScannerArgs {
input,
output,
wordlists,
csic,
trials,
beta,
trial_log,
}),
}
}
@@ -189,8 +366,20 @@ fn run_serve(upgrade: bool) -> Result<()> {
// 2. Load DDoS detection model if configured.
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
if ddos_cfg.enabled {
if ddos_cfg.use_ensemble {
// Ensemble path: compiled-in weights, no model file needed.
// We still need a TrainedModel for the struct, but it won't be used.
let dummy_model = ddos::model::TrainedModel::empty(ddos_cfg.k, ddos_cfg.threshold);
let detector = Arc::new(ddos::detector::DDoSDetector::new_ensemble(dummy_model, ddos_cfg));
tracing::info!(
k = ddos_cfg.k,
threshold = ddos_cfg.threshold,
"DDoS ensemble detector enabled"
);
Some(detector)
} else if let Some(ref model_path) = ddos_cfg.model_path {
match ddos::model::TrainedModel::load(
std::path::Path::new(&ddos_cfg.model_path),
std::path::Path::new(model_path),
Some(ddos_cfg.k),
Some(ddos_cfg.threshold),
) {
@@ -210,6 +399,10 @@ fn run_serve(upgrade: bool) -> Result<()> {
None
}
}
} else {
tracing::warn!("DDoS enabled but no model_path and use_ensemble=false; detection disabled");
None
}
} else {
None
}
@@ -245,7 +438,33 @@ fn run_serve(upgrade: bool) -> Result<()> {
// 2c. Load scanner model if configured.
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
if scanner_cfg.enabled {
match scanner::model::ScannerModel::load(std::path::Path::new(&scanner_cfg.model_path)) {
if scanner_cfg.use_ensemble {
// Ensemble path: compiled-in weights, no model file needed.
let detector = scanner::detector::ScannerDetector::new_ensemble(&cfg.routes);
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
// Start bot allowlist if rules are configured.
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
let al = scanner::allowlist::BotAllowlist::spawn(
&scanner_cfg.allowlist,
scanner_cfg.bot_cache_ttl_secs,
);
tracing::info!(
rules = scanner_cfg.allowlist.len(),
"bot allowlist enabled"
);
Some(al)
} else {
None
};
tracing::info!(
threshold = scanner_cfg.threshold,
"scanner ensemble detector enabled"
);
(Some(handle), bot_allowlist)
} else if let Some(ref model_path) = scanner_cfg.model_path {
match scanner::model::ScannerModel::load(std::path::Path::new(model_path)) {
Ok(mut model) => {
let fragment_count = model.fragments.len();
model.threshold = scanner_cfg.threshold;
@@ -270,13 +489,13 @@ fn run_serve(upgrade: bool) -> Result<()> {
// Start background file watcher for hot-reload.
if scanner_cfg.poll_interval_secs > 0 {
let watcher_handle = handle.clone();
let model_path = std::path::PathBuf::from(&scanner_cfg.model_path);
let watcher_model_path = std::path::PathBuf::from(model_path);
let threshold = scanner_cfg.threshold;
let routes = cfg.routes.clone();
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
std::thread::spawn(move || {
scanner::watcher::watch_scanner_model(
watcher_handle, model_path, threshold, routes, interval,
watcher_handle, watcher_model_path, threshold, routes, interval,
);
});
}
@@ -294,6 +513,10 @@ fn run_serve(upgrade: bool) -> Result<()> {
(None, None)
}
}
} else {
tracing::warn!("scanner enabled but no model_path and use_ensemble=false; scanner detection disabled");
(None, None)
}
} else {
(None, None)
}