feat(cli): restructure replay as subcommand with ensemble and ddos modes
Add `replay ensemble` (runs logs through compiled-in tree+MLP for both scanner and DDoS) and `replay ddos` (legacy KNN). Also adds CLI commands for download-datasets, prepare-dataset, train-mlp-scanner, train-mlp-ddos, autotune-ddos, and autotune-scanner. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
@@ -21,38 +21,39 @@ pub struct ReplayArgs {
|
||||
pub rate_limit: bool,
|
||||
}
|
||||
|
||||
struct ReplayStats {
|
||||
total: u64,
|
||||
skipped: u64,
|
||||
ddos_blocked: u64,
|
||||
rate_limited: u64,
|
||||
allowed: u64,
|
||||
ddos_blocked_ips: FxHashMap<String, u64>,
|
||||
rate_limited_ips: FxHashMap<String, u64>,
|
||||
pub struct ReplayResult {
|
||||
pub total: u64,
|
||||
pub skipped: u64,
|
||||
pub ddos_blocked: u64,
|
||||
pub rate_limited: u64,
|
||||
pub allowed: u64,
|
||||
pub ddos_blocked_ips: FxHashMap<String, u64>,
|
||||
pub rate_limited_ips: FxHashMap<String, u64>,
|
||||
pub false_positive_ips: usize,
|
||||
pub true_positive_ips: usize,
|
||||
}
|
||||
|
||||
pub fn run(args: ReplayArgs) -> Result<()> {
|
||||
eprintln!("Loading model from {}...", args.model_path);
|
||||
/// Core replay pipeline: load model, replay logs, compute stats including false positive analysis.
|
||||
pub fn replay_and_evaluate(args: &ReplayArgs) -> Result<ReplayResult> {
|
||||
let model = TrainedModel::load(
|
||||
std::path::Path::new(&args.model_path),
|
||||
Some(args.k),
|
||||
Some(args.threshold),
|
||||
)
|
||||
.with_context(|| format!("loading model from {}", args.model_path))?;
|
||||
eprintln!(" {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold);
|
||||
|
||||
let ddos_cfg = DDoSConfig {
|
||||
model_path: args.model_path.clone(),
|
||||
model_path: Some(args.model_path.clone()),
|
||||
k: args.k,
|
||||
threshold: args.threshold,
|
||||
window_secs: args.window_secs,
|
||||
window_capacity: 1000,
|
||||
min_events: args.min_events,
|
||||
enabled: true,
|
||||
use_ensemble: false,
|
||||
};
|
||||
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
|
||||
|
||||
// Optionally set up rate limiter
|
||||
let rate_limiter = if args.rate_limit {
|
||||
let rl_cfg = if let Some(cfg_path) = &args.config_path {
|
||||
let cfg = crate::config::Config::load(cfg_path)?;
|
||||
@@ -60,61 +61,49 @@ pub fn run(args: ReplayArgs) -> Result<()> {
|
||||
} else {
|
||||
default_rate_limit_config()
|
||||
};
|
||||
eprintln!(
|
||||
" Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s",
|
||||
rl_cfg.authenticated.burst,
|
||||
rl_cfg.authenticated.rate,
|
||||
rl_cfg.unauthenticated.burst,
|
||||
rl_cfg.unauthenticated.rate,
|
||||
);
|
||||
Some(RateLimiter::new(&rl_cfg))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
eprintln!("Replaying {}...\n", args.input);
|
||||
|
||||
let file = std::fs::File::open(&args.input)
|
||||
.with_context(|| format!("opening {}", args.input))?;
|
||||
let reader = std::io::BufReader::new(file);
|
||||
|
||||
let mut stats = ReplayStats {
|
||||
total: 0,
|
||||
skipped: 0,
|
||||
ddos_blocked: 0,
|
||||
rate_limited: 0,
|
||||
allowed: 0,
|
||||
ddos_blocked_ips: FxHashMap::default(),
|
||||
rate_limited_ips: FxHashMap::default(),
|
||||
};
|
||||
let mut total = 0u64;
|
||||
let mut skipped = 0u64;
|
||||
let mut ddos_blocked = 0u64;
|
||||
let mut rate_limited = 0u64;
|
||||
let mut allowed = 0u64;
|
||||
let mut ddos_blocked_ips: FxHashMap<String, u64> = FxHashMap::default();
|
||||
let mut rate_limited_ips: FxHashMap<String, u64> = FxHashMap::default();
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||
Ok(e) => e,
|
||||
Err(_) => {
|
||||
stats.skipped += 1;
|
||||
skipped += 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if entry.fields.method.is_empty() {
|
||||
stats.skipped += 1;
|
||||
skipped += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
stats.total += 1;
|
||||
total += 1;
|
||||
|
||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||
let ip: IpAddr = match ip_str.parse() {
|
||||
Ok(ip) => ip,
|
||||
Err(_) => {
|
||||
stats.skipped += 1;
|
||||
skipped += 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// DDoS check
|
||||
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
|
||||
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
|
||||
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
|
||||
@@ -131,78 +120,57 @@ pub fn run(args: ReplayArgs) -> Result<()> {
|
||||
);
|
||||
|
||||
if ddos_action == DDoSAction::Block {
|
||||
stats.ddos_blocked += 1;
|
||||
*stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||
ddos_blocked += 1;
|
||||
*ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Rate limit check
|
||||
if let Some(limiter) = &rate_limiter {
|
||||
// Audit logs don't have auth headers, so all traffic is keyed by IP
|
||||
let rl_key = RateLimitKey::Ip(ip);
|
||||
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
|
||||
stats.rate_limited += 1;
|
||||
*stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||
rate_limited += 1;
|
||||
*rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
stats.allowed += 1;
|
||||
allowed += 1;
|
||||
}
|
||||
|
||||
// Report
|
||||
let total = stats.total;
|
||||
eprintln!("═══ Replay Results ═══════════════════════════════════════");
|
||||
eprintln!(" Total requests: {total}");
|
||||
eprintln!(" Skipped (parse): {}", stats.skipped);
|
||||
eprintln!(" Allowed: {} ({:.1}%)", stats.allowed, pct(stats.allowed, total));
|
||||
eprintln!(" DDoS blocked: {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total));
|
||||
if rate_limiter.is_some() {
|
||||
eprintln!(" Rate limited: {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total));
|
||||
// Compute false positive / true positive counts
|
||||
let (false_positive_ips, true_positive_ips) =
|
||||
count_fp_tp(&args.input, &ddos_blocked_ips, &rate_limited_ips)?;
|
||||
|
||||
Ok(ReplayResult {
|
||||
total,
|
||||
skipped,
|
||||
ddos_blocked,
|
||||
rate_limited,
|
||||
allowed,
|
||||
ddos_blocked_ips,
|
||||
rate_limited_ips,
|
||||
false_positive_ips,
|
||||
true_positive_ips,
|
||||
})
|
||||
}
|
||||
|
||||
if !stats.ddos_blocked_ips.is_empty() {
|
||||
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
|
||||
let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
for (ip, count) in sorted.iter().take(20) {
|
||||
eprintln!(" {:<40} {} reqs blocked", ip, count);
|
||||
}
|
||||
}
|
||||
|
||||
if !stats.rate_limited_ips.is_empty() {
|
||||
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
|
||||
let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
for (ip, count) in sorted.iter().take(20) {
|
||||
eprintln!(" {:<40} {} reqs limited", ip, count);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for false positives: IPs that were blocked but had 2xx statuses in the original logs
|
||||
eprintln!("\n── False positive check ──────────────────────────────────");
|
||||
check_false_positives(&args.input, &stats)?;
|
||||
|
||||
eprintln!("══════════════════════════════════════════════════════════");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally
|
||||
/// (i.e. they were legitimate traffic that the model would incorrectly block).
|
||||
fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
|
||||
let blocked_ips: rustc_hash::FxHashSet<&str> = stats
|
||||
.ddos_blocked_ips
|
||||
/// Count false-positive and true-positive IPs from blocked set.
|
||||
/// FP = blocked IP where >60% of original responses were 2xx/3xx.
|
||||
fn count_fp_tp(
|
||||
input: &str,
|
||||
ddos_blocked_ips: &FxHashMap<String, u64>,
|
||||
rate_limited_ips: &FxHashMap<String, u64>,
|
||||
) -> Result<(usize, usize)> {
|
||||
let blocked_ips: rustc_hash::FxHashSet<&str> = ddos_blocked_ips
|
||||
.keys()
|
||||
.chain(stats.rate_limited_ips.keys())
|
||||
.chain(rate_limited_ips.keys())
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
|
||||
if blocked_ips.is_empty() {
|
||||
eprintln!(" No blocked IPs — nothing to check.");
|
||||
return Ok(());
|
||||
return Ok((0, 0));
|
||||
}
|
||||
|
||||
// Collect original status codes for blocked IPs
|
||||
let file = std::fs::File::open(input)?;
|
||||
let reader = std::io::BufReader::new(file);
|
||||
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
|
||||
@@ -215,47 +183,69 @@ fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
|
||||
};
|
||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||
if blocked_ips.contains(ip_str.as_str()) {
|
||||
ip_statuses
|
||||
.entry(ip_str)
|
||||
.or_default()
|
||||
.push(entry.fields.status);
|
||||
ip_statuses.entry(ip_str).or_default().push(entry.fields.status);
|
||||
}
|
||||
}
|
||||
|
||||
let mut suspects = Vec::new();
|
||||
for (ip, statuses) in &ip_statuses {
|
||||
let mut fp = 0usize;
|
||||
let mut tp = 0usize;
|
||||
for (_ip, statuses) in &ip_statuses {
|
||||
let total = statuses.len();
|
||||
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
|
||||
let ok_pct = (ok_count as f64 / total as f64) * 100.0;
|
||||
// If >60% of original responses were 2xx/3xx, this might be a false positive
|
||||
let ok_pct = ok_count as f64 / total as f64 * 100.0;
|
||||
if ok_pct > 60.0 {
|
||||
let blocked = stats
|
||||
.ddos_blocked_ips
|
||||
.get(ip)
|
||||
.copied()
|
||||
.unwrap_or(0)
|
||||
+ stats
|
||||
.rate_limited_ips
|
||||
.get(ip)
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
suspects.push((ip.clone(), total, ok_pct, blocked));
|
||||
fp += 1;
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if suspects.is_empty() {
|
||||
Ok((fp, tp))
|
||||
}
|
||||
|
||||
pub fn run(args: ReplayArgs) -> Result<()> {
|
||||
eprintln!("Loading model from {}...", args.model_path);
|
||||
eprintln!("Replaying {}...\n", args.input);
|
||||
|
||||
let result = replay_and_evaluate(&args)?;
|
||||
|
||||
let total = result.total;
|
||||
eprintln!("═══ Replay Results ═══════════════════════════════════════");
|
||||
eprintln!(" Total requests: {total}");
|
||||
eprintln!(" Skipped (parse): {}", result.skipped);
|
||||
eprintln!(" Allowed: {} ({:.1}%)", result.allowed, pct(result.allowed, total));
|
||||
eprintln!(" DDoS blocked: {} ({:.1}%)", result.ddos_blocked, pct(result.ddos_blocked, total));
|
||||
if args.rate_limit {
|
||||
eprintln!(" Rate limited: {} ({:.1}%)", result.rate_limited, pct(result.rate_limited, total));
|
||||
}
|
||||
|
||||
if !result.ddos_blocked_ips.is_empty() {
|
||||
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
|
||||
let mut sorted: Vec<_> = result.ddos_blocked_ips.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
for (ip, count) in sorted.iter().take(20) {
|
||||
eprintln!(" {:<40} {} reqs blocked", ip, count);
|
||||
}
|
||||
}
|
||||
|
||||
if !result.rate_limited_ips.is_empty() {
|
||||
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
|
||||
let mut sorted: Vec<_> = result.rate_limited_ips.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
for (ip, count) in sorted.iter().take(20) {
|
||||
eprintln!(" {:<40} {} reqs limited", ip, count);
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!("\n── False positive check ──────────────────────────────────");
|
||||
if result.false_positive_ips == 0 {
|
||||
eprintln!(" No likely false positives found.");
|
||||
} else {
|
||||
suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
|
||||
eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses:", suspects.len());
|
||||
for (ip, total, ok_pct, blocked) in suspects.iter().take(15) {
|
||||
eprintln!(
|
||||
" {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked",
|
||||
ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked,
|
||||
);
|
||||
}
|
||||
eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses", result.false_positive_ips);
|
||||
}
|
||||
eprintln!(" True positive IPs: {}", result.true_positive_ips);
|
||||
|
||||
eprintln!("══════════════════════════════════════════════════════════");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
316
src/ensemble/replay.rs
Normal file
316
src/ensemble/replay.rs
Normal file
@@ -0,0 +1,316 @@
|
||||
//! Replay audit logs through the ensemble models (scanner + DDoS).
|
||||
|
||||
use crate::ddos::audit_log::{self, AuditLog};
|
||||
use crate::ddos::features::{method_to_u8, LogIpState};
|
||||
use crate::ddos::model::DDoSAction;
|
||||
use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath};
|
||||
use crate::ensemble::scanner::{scanner_ensemble_predict, EnsemblePath};
|
||||
use crate::scanner::features::{self, fx_hash_bytes};
|
||||
use crate::scanner::model::ScannerAction;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::io::BufRead;
|
||||
|
||||
/// Arguments for replaying an audit log through the ensemble models (see `run`).
pub struct ReplayEnsembleArgs {
|
||||
/// Path of the audit log to replay; `run` opens it and parses each non-blank line as a JSON `AuditLog`.
pub input: String,
|
||||
/// Window size in seconds; `run` casts it to `f64` and forwards it to the DDoS replay.
pub window_secs: u64,
|
||||
/// Minimum number of logged requests an IP needs before the DDoS replay evaluates it;
/// IPs with fewer are counted as skipped.
pub min_events: usize,
|
||||
}
|
||||
|
||||
pub fn run(args: ReplayEnsembleArgs) -> Result<()> {
|
||||
eprintln!("replaying {} through ensemble models...\n", args.input);
|
||||
|
||||
let file =
|
||||
std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?;
|
||||
let reader = std::io::BufReader::new(file);
|
||||
|
||||
// --- Parse all entries ---
|
||||
let mut entries: Vec<AuditLog> = Vec::new();
|
||||
let mut parse_errors = 0u64;
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_str::<AuditLog>(&line) {
|
||||
Ok(e) => entries.push(e),
|
||||
Err(_) => parse_errors += 1,
|
||||
}
|
||||
}
|
||||
let total = entries.len() as u64;
|
||||
eprintln!("parsed {} entries ({} parse errors)\n", total, parse_errors);
|
||||
|
||||
// --- Scanner replay ---
|
||||
eprintln!("═══ Scanner Ensemble ═════════════════════════════════════");
|
||||
replay_scanner(&entries);
|
||||
|
||||
// --- DDoS replay ---
|
||||
eprintln!("\n═══ DDoS Ensemble ═══════════════════════════════════════");
|
||||
replay_ddos(&entries, args.window_secs as f64, args.min_events);
|
||||
|
||||
eprintln!("\n══════════════════════════════════════════════════════════");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn replay_scanner(entries: &[AuditLog]) {
|
||||
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
||||
.iter()
|
||||
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
||||
.collect();
|
||||
let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
|
||||
.iter()
|
||||
.map(|e| fx_hash_bytes(e.as_bytes()))
|
||||
.collect();
|
||||
let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
|
||||
for e in entries {
|
||||
let prefix = e.fields.host.split('.').next().unwrap_or("");
|
||||
log_hosts.insert(fx_hash_bytes(prefix.as_bytes()));
|
||||
}
|
||||
|
||||
let mut total = 0u64;
|
||||
let mut blocked = 0u64;
|
||||
let mut allowed = 0u64;
|
||||
let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
|
||||
let mut blocked_examples: Vec<(String, String, f64)> = Vec::new(); // (path, reason, score)
|
||||
let mut fp_candidates: Vec<(String, u16, f64)> = Vec::new(); // blocked but had 2xx status
|
||||
|
||||
for e in entries {
|
||||
let f = &e.fields;
|
||||
let host_prefix = f.host.split('.').next().unwrap_or("");
|
||||
let has_cookies = f.has_cookies.unwrap_or(false);
|
||||
let has_referer = f
|
||||
.referer
|
||||
.as_ref()
|
||||
.map(|r| r != "-" && !r.is_empty())
|
||||
.unwrap_or(false);
|
||||
let has_accept_language = f
|
||||
.accept_language
|
||||
.as_ref()
|
||||
.map(|a| a != "-" && !a.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
let feats = features::extract_features_f32(
|
||||
&f.method,
|
||||
&f.path,
|
||||
host_prefix,
|
||||
has_cookies,
|
||||
has_referer,
|
||||
has_accept_language,
|
||||
"-",
|
||||
&f.user_agent,
|
||||
f.content_length,
|
||||
&fragment_hashes,
|
||||
&extension_hashes,
|
||||
&log_hosts,
|
||||
);
|
||||
|
||||
let verdict = scanner_ensemble_predict(&feats);
|
||||
total += 1;
|
||||
|
||||
match verdict.path {
|
||||
EnsemblePath::TreeBlock => path_counts[0] += 1,
|
||||
EnsemblePath::TreeAllow => path_counts[1] += 1,
|
||||
EnsemblePath::Mlp => path_counts[2] += 1,
|
||||
}
|
||||
|
||||
match verdict.action {
|
||||
ScannerAction::Block => {
|
||||
blocked += 1;
|
||||
if blocked_examples.len() < 20 {
|
||||
blocked_examples.push((
|
||||
f.path.clone(),
|
||||
verdict.reason.to_string(),
|
||||
verdict.score,
|
||||
));
|
||||
}
|
||||
if (200..400).contains(&f.status) {
|
||||
fp_candidates.push((f.path.clone(), f.status, verdict.score));
|
||||
}
|
||||
}
|
||||
ScannerAction::Allow => allowed += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let pct = |n: u64| {
|
||||
if total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
n as f64 / total as f64 * 100.0
|
||||
}
|
||||
};
|
||||
|
||||
eprintln!(" total: {total}");
|
||||
eprintln!(
|
||||
" blocked: {} ({:.1}%)",
|
||||
blocked,
|
||||
pct(blocked)
|
||||
);
|
||||
eprintln!(
|
||||
" allowed: {} ({:.1}%)",
|
||||
allowed,
|
||||
pct(allowed)
|
||||
);
|
||||
eprintln!(
|
||||
" paths: tree_block={} tree_allow={} mlp={}",
|
||||
path_counts[0], path_counts[1], path_counts[2]
|
||||
);
|
||||
|
||||
if !blocked_examples.is_empty() {
|
||||
eprintln!("\n blocked examples (first 20):");
|
||||
for (path, reason, score) in &blocked_examples {
|
||||
eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50));
|
||||
}
|
||||
}
|
||||
|
||||
let fp_count = fp_candidates.len();
|
||||
if fp_count > 0 {
|
||||
eprintln!(
|
||||
"\n potential false positives (blocked but had 2xx/3xx): {}",
|
||||
fp_count
|
||||
);
|
||||
for (path, status, score) in fp_candidates.iter().take(10) {
|
||||
eprintln!(
|
||||
" {:<50} status={status} score={score:.3}",
|
||||
truncate(path, 50)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) {
|
||||
fn fx_hash(s: &str) -> u64 {
|
||||
let mut h = rustc_hash::FxHasher::default();
|
||||
s.hash(&mut h);
|
||||
h.finish()
|
||||
}
|
||||
|
||||
// Aggregate per-IP state.
|
||||
let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
|
||||
|
||||
for e in entries {
|
||||
let f = &e.fields;
|
||||
let ip = audit_log::strip_port(&f.client_ip).to_string();
|
||||
let state = ip_states.entry(ip).or_default();
|
||||
let ts = state.timestamps.len() as f64;
|
||||
state.timestamps.push(ts);
|
||||
state.methods.push(method_to_u8(&f.method));
|
||||
state.path_hashes.push(fx_hash(&f.path));
|
||||
state.host_hashes.push(fx_hash(&f.host));
|
||||
state.user_agent_hashes.push(fx_hash(&f.user_agent));
|
||||
state.statuses.push(f.status);
|
||||
state
|
||||
.durations
|
||||
.push(f.duration_ms.min(u32::MAX as u64) as u32);
|
||||
state
|
||||
.content_lengths
|
||||
.push(f.content_length.min(u32::MAX as u64) as u32);
|
||||
state
|
||||
.has_cookies
|
||||
.push(f.has_cookies.unwrap_or(false));
|
||||
state.has_referer.push(
|
||||
f.referer
|
||||
.as_deref()
|
||||
.map(|r| r != "-")
|
||||
.unwrap_or(false),
|
||||
);
|
||||
state.has_accept_language.push(
|
||||
f.accept_language
|
||||
.as_deref()
|
||||
.map(|a| a != "-")
|
||||
.unwrap_or(false),
|
||||
);
|
||||
state
|
||||
.suspicious_paths
|
||||
.push(crate::ddos::features::is_suspicious_path(&f.path));
|
||||
}
|
||||
|
||||
let mut total_ips = 0u64;
|
||||
let mut blocked_ips = 0u64;
|
||||
let mut allowed_ips = 0u64;
|
||||
let mut skipped_ips = 0u64;
|
||||
let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
|
||||
let mut blocked_details: Vec<(String, usize, f64, &'static str)> = Vec::new();
|
||||
|
||||
for (ip, state) in &ip_states {
|
||||
let n = state.timestamps.len();
|
||||
if n < min_events {
|
||||
skipped_ips += 1;
|
||||
continue;
|
||||
}
|
||||
total_ips += 1;
|
||||
|
||||
let fv = state.extract_features_for_window(0, n, window_secs);
|
||||
let fv_f32: [f32; 14] = {
|
||||
let mut arr = [0.0f32; 14];
|
||||
for i in 0..14 {
|
||||
arr[i] = fv[i] as f32;
|
||||
}
|
||||
arr
|
||||
};
|
||||
|
||||
let verdict = ddos_ensemble_predict(&fv_f32);
|
||||
|
||||
match verdict.path {
|
||||
DDoSEnsemblePath::TreeBlock => path_counts[0] += 1,
|
||||
DDoSEnsemblePath::TreeAllow => path_counts[1] += 1,
|
||||
DDoSEnsemblePath::Mlp => path_counts[2] += 1,
|
||||
}
|
||||
|
||||
match verdict.action {
|
||||
DDoSAction::Block => {
|
||||
blocked_ips += 1;
|
||||
if blocked_details.len() < 30 {
|
||||
blocked_details.push((ip.clone(), n, verdict.score, verdict.reason));
|
||||
}
|
||||
}
|
||||
DDoSAction::Allow => allowed_ips += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let pct = |n: u64, d: u64| {
|
||||
if d == 0 {
|
||||
0.0
|
||||
} else {
|
||||
n as f64 / d as f64 * 100.0
|
||||
}
|
||||
};
|
||||
|
||||
eprintln!(" unique IPs: {} ({} skipped < {} events)", ip_states.len(), skipped_ips, min_events);
|
||||
eprintln!(" evaluated: {total_ips}");
|
||||
eprintln!(
|
||||
" blocked: {} ({:.1}%)",
|
||||
blocked_ips,
|
||||
pct(blocked_ips, total_ips)
|
||||
);
|
||||
eprintln!(
|
||||
" allowed: {} ({:.1}%)",
|
||||
allowed_ips,
|
||||
pct(allowed_ips, total_ips)
|
||||
);
|
||||
eprintln!(
|
||||
" paths: tree_block={} tree_allow={} mlp={}",
|
||||
path_counts[0], path_counts[1], path_counts[2]
|
||||
);
|
||||
|
||||
if !blocked_details.is_empty() {
|
||||
eprintln!("\n blocked IPs (up to 30):");
|
||||
let mut sorted = blocked_details;
|
||||
sorted.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
for (ip, reqs, score, reason) in &sorted {
|
||||
eprintln!(
|
||||
" {:<40} {} reqs score={:.3} {reason}",
|
||||
ip, reqs, score
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorten `s` for display so that it occupies at most `max` bytes, ending in `...`
/// when anything was cut.
///
/// Strings of `max` bytes or fewer are returned unchanged. Longer strings keep
/// their first `max - 3` bytes, backed off to the nearest earlier UTF-8 character
/// boundary so that slicing never panics on multi-byte input such as
/// non-ASCII request paths. For `max < 3` the result is just `"..."`: it
/// exceeds `max`, but the subtraction no longer underflows.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }

    // `cut < s.len()` here, and offset 0 is always a boundary, so the loop terminates.
    let mut cut = max.saturating_sub(3);
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
|
||||
323
src/main.rs
323
src/main.rs
@@ -2,7 +2,7 @@ mod cert;
|
||||
mod telemetry;
|
||||
mod watcher;
|
||||
|
||||
use sunbeam_proxy::{acme, config};
|
||||
use sunbeam_proxy::{acme, autotune, config};
|
||||
use sunbeam_proxy::proxy::SunbeamProxy;
|
||||
use sunbeam_proxy::ddos;
|
||||
use sunbeam_proxy::rate_limit;
|
||||
@@ -32,32 +32,10 @@ enum Commands {
|
||||
#[arg(long)]
|
||||
upgrade: bool,
|
||||
},
|
||||
/// Replay audit logs through the DDoS detector and rate limiter
|
||||
ReplayDdos {
|
||||
/// Path to audit log JSONL file
|
||||
#[arg(short, long)]
|
||||
input: String,
|
||||
/// Path to trained model file
|
||||
#[arg(short, long, default_value = "ddos_model.bin")]
|
||||
model: String,
|
||||
/// Optional config file (for rate limit settings)
|
||||
#[arg(short, long)]
|
||||
config: Option<String>,
|
||||
/// KNN k parameter
|
||||
#[arg(long, default_value = "5")]
|
||||
k: usize,
|
||||
/// Attack threshold
|
||||
#[arg(long, default_value = "0.6")]
|
||||
threshold: f64,
|
||||
/// Sliding window size in seconds
|
||||
#[arg(long, default_value = "60")]
|
||||
window_secs: u64,
|
||||
/// Minimum events per IP before classification
|
||||
#[arg(long, default_value = "10")]
|
||||
min_events: usize,
|
||||
/// Also run rate limiter during replay
|
||||
#[arg(long)]
|
||||
rate_limit: bool,
|
||||
/// Replay audit logs through detection models
|
||||
Replay {
|
||||
#[command(subcommand)]
|
||||
mode: ReplayMode,
|
||||
},
|
||||
/// Train a DDoS detection model from audit logs
|
||||
TrainDdos {
|
||||
@@ -107,31 +85,180 @@ enum Commands {
|
||||
#[arg(long)]
|
||||
csic: bool,
|
||||
},
|
||||
/// Bayesian hyperparameter optimization for DDoS model
|
||||
AutotuneDdos {
|
||||
/// Path to audit log JSONL file
|
||||
#[arg(short, long)]
|
||||
input: String,
|
||||
/// Output best model file path
|
||||
#[arg(short, long, default_value = "ddos_model_best.bin")]
|
||||
output: String,
|
||||
/// Number of optimization trials
|
||||
#[arg(long, default_value = "200")]
|
||||
trials: usize,
|
||||
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
|
||||
#[arg(long, default_value = "1.0")]
|
||||
beta: f64,
|
||||
/// JSONL file to log each trial's parameters and results
|
||||
#[arg(long)]
|
||||
trial_log: Option<String>,
|
||||
},
|
||||
/// Download and cache upstream datasets (CIC-IDS2017)
|
||||
DownloadDatasets,
|
||||
/// Prepare a unified training dataset from multiple sources
|
||||
PrepareDataset {
|
||||
/// Path to audit log JSONL file
|
||||
#[arg(short, long)]
|
||||
input: String,
|
||||
/// Path to OWASP ModSecurity audit log file (optional extra data)
|
||||
#[arg(long)]
|
||||
owasp: Option<String>,
|
||||
/// Directory containing .txt wordlists (optional, enhances synthetic scanner)
|
||||
#[arg(long)]
|
||||
wordlists: Option<String>,
|
||||
/// Output dataset file path
|
||||
#[arg(short, long, default_value = "dataset.bin")]
|
||||
output: String,
|
||||
/// Random seed
|
||||
#[arg(long, default_value = "42")]
|
||||
seed: u64,
|
||||
/// Path to heuristics.toml for auto-labeling production logs
|
||||
#[arg(long)]
|
||||
heuristics: Option<String>,
|
||||
},
|
||||
#[cfg(feature = "training")]
|
||||
/// Train scanner ensemble (decision tree + MLP) from prepared dataset
|
||||
TrainMlpScanner {
|
||||
/// Path to prepared dataset (.bin)
|
||||
#[arg(short = 'd', long)]
|
||||
dataset: String,
|
||||
/// Output directory for generated weight files
|
||||
#[arg(short, long, default_value = "src/ensemble/gen")]
|
||||
output_dir: String,
|
||||
/// Hidden layer dimension
|
||||
#[arg(long, default_value = "32")]
|
||||
hidden_dim: usize,
|
||||
/// Training epochs
|
||||
#[arg(long, default_value = "100")]
|
||||
epochs: usize,
|
||||
/// Learning rate
|
||||
#[arg(long, default_value = "0.001")]
|
||||
learning_rate: f64,
|
||||
/// Batch size
|
||||
#[arg(long, default_value = "64")]
|
||||
batch_size: usize,
|
||||
/// Max tree depth
|
||||
#[arg(long, default_value = "6")]
|
||||
tree_max_depth: usize,
|
||||
/// Min purity for tree leaves (below -> Defer)
|
||||
#[arg(long, default_value = "0.90")]
|
||||
tree_min_purity: f32,
|
||||
},
|
||||
#[cfg(feature = "training")]
|
||||
/// Train DDoS ensemble (decision tree + MLP) from prepared dataset
|
||||
TrainMlpDdos {
|
||||
#[arg(short = 'd', long)]
|
||||
dataset: String,
|
||||
#[arg(short, long, default_value = "src/ensemble/gen")]
|
||||
output_dir: String,
|
||||
#[arg(long, default_value = "32")]
|
||||
hidden_dim: usize,
|
||||
#[arg(long, default_value = "100")]
|
||||
epochs: usize,
|
||||
#[arg(long, default_value = "0.001")]
|
||||
learning_rate: f64,
|
||||
#[arg(long, default_value = "64")]
|
||||
batch_size: usize,
|
||||
#[arg(long, default_value = "6")]
|
||||
tree_max_depth: usize,
|
||||
#[arg(long, default_value = "0.90")]
|
||||
tree_min_purity: f32,
|
||||
},
|
||||
/// Bayesian hyperparameter optimization for scanner model
|
||||
AutotuneScanner {
|
||||
/// Path to audit log JSONL file
|
||||
#[arg(short, long)]
|
||||
input: String,
|
||||
/// Output best model file path
|
||||
#[arg(short, long, default_value = "scanner_model_best.bin")]
|
||||
output: String,
|
||||
/// Directory (or file) containing .txt wordlists of scanner paths
|
||||
#[arg(long)]
|
||||
wordlists: Option<String>,
|
||||
/// Include CSIC 2010 dataset as base training data
|
||||
#[arg(long)]
|
||||
csic: bool,
|
||||
/// Number of optimization trials
|
||||
#[arg(long, default_value = "200")]
|
||||
trials: usize,
|
||||
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
|
||||
#[arg(long, default_value = "1.0")]
|
||||
beta: f64,
|
||||
/// JSONL file to log each trial's parameters and results
|
||||
#[arg(long)]
|
||||
trial_log: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
// Modes of the `replay` command. `main` dispatches `Ensemble` to
// `sunbeam_proxy::ensemble::replay::run` and `Ddos` to `ddos::replay::run`,
// copying the fields below into the corresponding args struct.
// This note is a plain `//` comment so the `///` help text on the variants and
// fields stays exactly as written.
#[derive(Subcommand)]
|
||||
enum ReplayMode {
|
||||
/// Replay through ensemble models (scanner + DDoS)
|
||||
Ensemble {
|
||||
/// Path to audit log JSONL file
|
||||
#[arg(short, long)]
|
||||
input: String,
|
||||
/// Sliding window size in seconds
|
||||
#[arg(long, default_value = "60")]
|
||||
window_secs: u64,
|
||||
/// Minimum events per IP before DDoS classification
|
||||
#[arg(long, default_value = "5")]
|
||||
min_events: usize,
|
||||
},
|
||||
/// Replay through legacy KNN DDoS detector
|
||||
Ddos {
|
||||
/// Path to audit log JSONL file
|
||||
#[arg(short, long)]
|
||||
input: String,
|
||||
/// Path to trained model file
|
||||
#[arg(short, long, default_value = "ddos_model.bin")]
|
||||
model: String,
|
||||
/// Optional config file (for rate limit settings)
|
||||
#[arg(short, long)]
|
||||
config: Option<String>,
|
||||
/// KNN k parameter
|
||||
#[arg(long, default_value = "5")]
|
||||
k: usize,
|
||||
/// Attack threshold
|
||||
#[arg(long, default_value = "0.6")]
|
||||
threshold: f64,
|
||||
/// Sliding window size in seconds
|
||||
#[arg(long, default_value = "60")]
|
||||
window_secs: u64,
|
||||
/// Minimum events per IP before classification
|
||||
#[arg(long, default_value = "10")]
|
||||
min_events: usize,
|
||||
/// Also run rate limiter during replay
|
||||
#[arg(long)]
|
||||
rate_limit: bool,
|
||||
},
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
|
||||
Commands::Serve { upgrade } => run_serve(upgrade),
|
||||
Commands::ReplayDdos {
|
||||
input,
|
||||
model,
|
||||
config,
|
||||
k,
|
||||
threshold,
|
||||
window_secs,
|
||||
min_events,
|
||||
rate_limit,
|
||||
} => ddos::replay::run(ddos::replay::ReplayArgs {
|
||||
input,
|
||||
model_path: model,
|
||||
config_path: config,
|
||||
k,
|
||||
threshold,
|
||||
window_secs,
|
||||
min_events,
|
||||
rate_limit,
|
||||
}),
|
||||
Commands::Replay { mode } => match mode {
|
||||
ReplayMode::Ensemble { input, window_secs, min_events } => {
|
||||
sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs {
|
||||
input, window_secs, min_events,
|
||||
})
|
||||
}
|
||||
ReplayMode::Ddos { input, model, config, k, threshold, window_secs, min_events, rate_limit } => {
|
||||
ddos::replay::run(ddos::replay::ReplayArgs {
|
||||
input, model_path: model, config_path: config, k, threshold, window_secs, min_events, rate_limit,
|
||||
})
|
||||
}
|
||||
},
|
||||
Commands::TrainDdos {
|
||||
input,
|
||||
output,
|
||||
@@ -166,6 +293,56 @@ fn main() -> Result<()> {
|
||||
threshold,
|
||||
csic,
|
||||
}),
|
||||
Commands::DownloadDatasets => {
|
||||
sunbeam_proxy::dataset::download::download_all()
|
||||
},
|
||||
Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics } => {
|
||||
sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs {
|
||||
input, owasp, wordlists, output, seed, heuristics,
|
||||
})
|
||||
},
|
||||
#[cfg(feature = "training")]
|
||||
Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
|
||||
sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs {
|
||||
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
|
||||
})
|
||||
},
|
||||
#[cfg(feature = "training")]
|
||||
Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
|
||||
sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs {
|
||||
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
|
||||
})
|
||||
},
|
||||
Commands::AutotuneDdos {
|
||||
input,
|
||||
output,
|
||||
trials,
|
||||
beta,
|
||||
trial_log,
|
||||
} => autotune::ddos::run_autotune(autotune::ddos::AutotuneDdosArgs {
|
||||
input,
|
||||
output,
|
||||
trials,
|
||||
beta,
|
||||
trial_log,
|
||||
}),
|
||||
Commands::AutotuneScanner {
|
||||
input,
|
||||
output,
|
||||
wordlists,
|
||||
csic,
|
||||
trials,
|
||||
beta,
|
||||
trial_log,
|
||||
} => autotune::scanner::run_autotune(autotune::scanner::AutotuneScannerArgs {
|
||||
input,
|
||||
output,
|
||||
wordlists,
|
||||
csic,
|
||||
trials,
|
||||
beta,
|
||||
trial_log,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,8 +366,20 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
||||
// 2. Load DDoS detection model if configured.
|
||||
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
|
||||
if ddos_cfg.enabled {
|
||||
if ddos_cfg.use_ensemble {
|
||||
// Ensemble path: compiled-in weights, no model file needed.
|
||||
// We still need a TrainedModel for the struct, but it won't be used.
|
||||
let dummy_model = ddos::model::TrainedModel::empty(ddos_cfg.k, ddos_cfg.threshold);
|
||||
let detector = Arc::new(ddos::detector::DDoSDetector::new_ensemble(dummy_model, ddos_cfg));
|
||||
tracing::info!(
|
||||
k = ddos_cfg.k,
|
||||
threshold = ddos_cfg.threshold,
|
||||
"DDoS ensemble detector enabled"
|
||||
);
|
||||
Some(detector)
|
||||
} else if let Some(ref model_path) = ddos_cfg.model_path {
|
||||
match ddos::model::TrainedModel::load(
|
||||
std::path::Path::new(&ddos_cfg.model_path),
|
||||
std::path::Path::new(model_path),
|
||||
Some(ddos_cfg.k),
|
||||
Some(ddos_cfg.threshold),
|
||||
) {
|
||||
@@ -210,6 +399,10 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("DDoS enabled but no model_path and use_ensemble=false; detection disabled");
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -245,7 +438,33 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
||||
// 2c. Load scanner model if configured.
|
||||
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
|
||||
if scanner_cfg.enabled {
|
||||
match scanner::model::ScannerModel::load(std::path::Path::new(&scanner_cfg.model_path)) {
|
||||
if scanner_cfg.use_ensemble {
|
||||
// Ensemble path: compiled-in weights, no model file needed.
|
||||
let detector = scanner::detector::ScannerDetector::new_ensemble(&cfg.routes);
|
||||
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
|
||||
|
||||
// Start bot allowlist if rules are configured.
|
||||
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
|
||||
let al = scanner::allowlist::BotAllowlist::spawn(
|
||||
&scanner_cfg.allowlist,
|
||||
scanner_cfg.bot_cache_ttl_secs,
|
||||
);
|
||||
tracing::info!(
|
||||
rules = scanner_cfg.allowlist.len(),
|
||||
"bot allowlist enabled"
|
||||
);
|
||||
Some(al)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
threshold = scanner_cfg.threshold,
|
||||
"scanner ensemble detector enabled"
|
||||
);
|
||||
(Some(handle), bot_allowlist)
|
||||
} else if let Some(ref model_path) = scanner_cfg.model_path {
|
||||
match scanner::model::ScannerModel::load(std::path::Path::new(model_path)) {
|
||||
Ok(mut model) => {
|
||||
let fragment_count = model.fragments.len();
|
||||
model.threshold = scanner_cfg.threshold;
|
||||
@@ -270,13 +489,13 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
||||
// Start background file watcher for hot-reload.
|
||||
if scanner_cfg.poll_interval_secs > 0 {
|
||||
let watcher_handle = handle.clone();
|
||||
let model_path = std::path::PathBuf::from(&scanner_cfg.model_path);
|
||||
let watcher_model_path = std::path::PathBuf::from(model_path);
|
||||
let threshold = scanner_cfg.threshold;
|
||||
let routes = cfg.routes.clone();
|
||||
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
|
||||
std::thread::spawn(move || {
|
||||
scanner::watcher::watch_scanner_model(
|
||||
watcher_handle, model_path, threshold, routes, interval,
|
||||
watcher_handle, watcher_model_path, threshold, routes, interval,
|
||||
);
|
||||
});
|
||||
}
|
||||
@@ -294,6 +513,10 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
||||
(None, None)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("scanner enabled but no model_path and use_ensemble=false; scanner detection disabled");
|
||||
(None, None)
|
||||
}
|
||||
} else {
|
||||
(None, None)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user