feat(cli): restructure replay as subcommand with ensemble and ddos modes
Add `replay ensemble` (runs logs through compiled-in tree+MLP for both scanner and DDoS) and `replay ddos` (legacy KNN). Also adds CLI commands for download-datasets, prepare-dataset, train-mlp-scanner, train-mlp-ddos, autotune-ddos, and autotune-scanner.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
@@ -21,38 +21,39 @@ pub struct ReplayArgs {
|
|||||||
pub rate_limit: bool,
|
pub rate_limit: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ReplayStats {
|
pub struct ReplayResult {
|
||||||
total: u64,
|
pub total: u64,
|
||||||
skipped: u64,
|
pub skipped: u64,
|
||||||
ddos_blocked: u64,
|
pub ddos_blocked: u64,
|
||||||
rate_limited: u64,
|
pub rate_limited: u64,
|
||||||
allowed: u64,
|
pub allowed: u64,
|
||||||
ddos_blocked_ips: FxHashMap<String, u64>,
|
pub ddos_blocked_ips: FxHashMap<String, u64>,
|
||||||
rate_limited_ips: FxHashMap<String, u64>,
|
pub rate_limited_ips: FxHashMap<String, u64>,
|
||||||
|
pub false_positive_ips: usize,
|
||||||
|
pub true_positive_ips: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run(args: ReplayArgs) -> Result<()> {
|
/// Core replay pipeline: load model, replay logs, compute stats including false positive analysis.
|
||||||
eprintln!("Loading model from {}...", args.model_path);
|
pub fn replay_and_evaluate(args: &ReplayArgs) -> Result<ReplayResult> {
|
||||||
let model = TrainedModel::load(
|
let model = TrainedModel::load(
|
||||||
std::path::Path::new(&args.model_path),
|
std::path::Path::new(&args.model_path),
|
||||||
Some(args.k),
|
Some(args.k),
|
||||||
Some(args.threshold),
|
Some(args.threshold),
|
||||||
)
|
)
|
||||||
.with_context(|| format!("loading model from {}", args.model_path))?;
|
.with_context(|| format!("loading model from {}", args.model_path))?;
|
||||||
eprintln!(" {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold);
|
|
||||||
|
|
||||||
let ddos_cfg = DDoSConfig {
|
let ddos_cfg = DDoSConfig {
|
||||||
model_path: args.model_path.clone(),
|
model_path: Some(args.model_path.clone()),
|
||||||
k: args.k,
|
k: args.k,
|
||||||
threshold: args.threshold,
|
threshold: args.threshold,
|
||||||
window_secs: args.window_secs,
|
window_secs: args.window_secs,
|
||||||
window_capacity: 1000,
|
window_capacity: 1000,
|
||||||
min_events: args.min_events,
|
min_events: args.min_events,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
|
use_ensemble: false,
|
||||||
};
|
};
|
||||||
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
|
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
|
||||||
|
|
||||||
// Optionally set up rate limiter
|
|
||||||
let rate_limiter = if args.rate_limit {
|
let rate_limiter = if args.rate_limit {
|
||||||
let rl_cfg = if let Some(cfg_path) = &args.config_path {
|
let rl_cfg = if let Some(cfg_path) = &args.config_path {
|
||||||
let cfg = crate::config::Config::load(cfg_path)?;
|
let cfg = crate::config::Config::load(cfg_path)?;
|
||||||
@@ -60,61 +61,49 @@ pub fn run(args: ReplayArgs) -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
default_rate_limit_config()
|
default_rate_limit_config()
|
||||||
};
|
};
|
||||||
eprintln!(
|
|
||||||
" Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s",
|
|
||||||
rl_cfg.authenticated.burst,
|
|
||||||
rl_cfg.authenticated.rate,
|
|
||||||
rl_cfg.unauthenticated.burst,
|
|
||||||
rl_cfg.unauthenticated.rate,
|
|
||||||
);
|
|
||||||
Some(RateLimiter::new(&rl_cfg))
|
Some(RateLimiter::new(&rl_cfg))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
eprintln!("Replaying {}...\n", args.input);
|
|
||||||
|
|
||||||
let file = std::fs::File::open(&args.input)
|
let file = std::fs::File::open(&args.input)
|
||||||
.with_context(|| format!("opening {}", args.input))?;
|
.with_context(|| format!("opening {}", args.input))?;
|
||||||
let reader = std::io::BufReader::new(file);
|
let reader = std::io::BufReader::new(file);
|
||||||
|
|
||||||
let mut stats = ReplayStats {
|
let mut total = 0u64;
|
||||||
total: 0,
|
let mut skipped = 0u64;
|
||||||
skipped: 0,
|
let mut ddos_blocked = 0u64;
|
||||||
ddos_blocked: 0,
|
let mut rate_limited = 0u64;
|
||||||
rate_limited: 0,
|
let mut allowed = 0u64;
|
||||||
allowed: 0,
|
let mut ddos_blocked_ips: FxHashMap<String, u64> = FxHashMap::default();
|
||||||
ddos_blocked_ips: FxHashMap::default(),
|
let mut rate_limited_ips: FxHashMap<String, u64> = FxHashMap::default();
|
||||||
rate_limited_ips: FxHashMap::default(),
|
|
||||||
};
|
|
||||||
|
|
||||||
for line in reader.lines() {
|
for line in reader.lines() {
|
||||||
let line = line?;
|
let line = line?;
|
||||||
let entry: AuditLog = match serde_json::from_str(&line) {
|
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||||
Ok(e) => e,
|
Ok(e) => e,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
stats.skipped += 1;
|
skipped += 1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if entry.fields.method.is_empty() {
|
if entry.fields.method.is_empty() {
|
||||||
stats.skipped += 1;
|
skipped += 1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
stats.total += 1;
|
total += 1;
|
||||||
|
|
||||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||||
let ip: IpAddr = match ip_str.parse() {
|
let ip: IpAddr = match ip_str.parse() {
|
||||||
Ok(ip) => ip,
|
Ok(ip) => ip,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
stats.skipped += 1;
|
skipped += 1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// DDoS check
|
|
||||||
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
|
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
|
||||||
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
|
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
|
||||||
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
|
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
|
||||||
@@ -131,78 +120,57 @@ pub fn run(args: ReplayArgs) -> Result<()> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
if ddos_action == DDoSAction::Block {
|
if ddos_action == DDoSAction::Block {
|
||||||
stats.ddos_blocked += 1;
|
ddos_blocked += 1;
|
||||||
*stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
*ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rate limit check
|
|
||||||
if let Some(limiter) = &rate_limiter {
|
if let Some(limiter) = &rate_limiter {
|
||||||
// Audit logs don't have auth headers, so all traffic is keyed by IP
|
|
||||||
let rl_key = RateLimitKey::Ip(ip);
|
let rl_key = RateLimitKey::Ip(ip);
|
||||||
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
|
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
|
||||||
stats.rate_limited += 1;
|
rate_limited += 1;
|
||||||
*stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
*rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stats.allowed += 1;
|
allowed += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Report
|
// Compute false positive / true positive counts
|
||||||
let total = stats.total;
|
let (false_positive_ips, true_positive_ips) =
|
||||||
eprintln!("═══ Replay Results ═══════════════════════════════════════");
|
count_fp_tp(&args.input, &ddos_blocked_ips, &rate_limited_ips)?;
|
||||||
eprintln!(" Total requests: {total}");
|
|
||||||
eprintln!(" Skipped (parse): {}", stats.skipped);
|
|
||||||
eprintln!(" Allowed: {} ({:.1}%)", stats.allowed, pct(stats.allowed, total));
|
|
||||||
eprintln!(" DDoS blocked: {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total));
|
|
||||||
if rate_limiter.is_some() {
|
|
||||||
eprintln!(" Rate limited: {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total));
|
|
||||||
}
|
|
||||||
|
|
||||||
if !stats.ddos_blocked_ips.is_empty() {
|
Ok(ReplayResult {
|
||||||
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
|
total,
|
||||||
let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect();
|
skipped,
|
||||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
ddos_blocked,
|
||||||
for (ip, count) in sorted.iter().take(20) {
|
rate_limited,
|
||||||
eprintln!(" {:<40} {} reqs blocked", ip, count);
|
allowed,
|
||||||
}
|
ddos_blocked_ips,
|
||||||
}
|
rate_limited_ips,
|
||||||
|
false_positive_ips,
|
||||||
if !stats.rate_limited_ips.is_empty() {
|
true_positive_ips,
|
||||||
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
|
})
|
||||||
let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect();
|
|
||||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
|
||||||
for (ip, count) in sorted.iter().take(20) {
|
|
||||||
eprintln!(" {:<40} {} reqs limited", ip, count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for false positives: IPs that were blocked but had 2xx statuses in the original logs
|
|
||||||
eprintln!("\n── False positive check ──────────────────────────────────");
|
|
||||||
check_false_positives(&args.input, &stats)?;
|
|
||||||
|
|
||||||
eprintln!("══════════════════════════════════════════════════════════");
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally
|
/// Count false-positive and true-positive IPs from blocked set.
|
||||||
/// (i.e. they were legitimate traffic that the model would incorrectly block).
|
/// FP = blocked IP where >60% of original responses were 2xx/3xx.
|
||||||
fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
|
fn count_fp_tp(
|
||||||
let blocked_ips: rustc_hash::FxHashSet<&str> = stats
|
input: &str,
|
||||||
.ddos_blocked_ips
|
ddos_blocked_ips: &FxHashMap<String, u64>,
|
||||||
|
rate_limited_ips: &FxHashMap<String, u64>,
|
||||||
|
) -> Result<(usize, usize)> {
|
||||||
|
let blocked_ips: rustc_hash::FxHashSet<&str> = ddos_blocked_ips
|
||||||
.keys()
|
.keys()
|
||||||
.chain(stats.rate_limited_ips.keys())
|
.chain(rate_limited_ips.keys())
|
||||||
.map(|s| s.as_str())
|
.map(|s| s.as_str())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if blocked_ips.is_empty() {
|
if blocked_ips.is_empty() {
|
||||||
eprintln!(" No blocked IPs — nothing to check.");
|
return Ok((0, 0));
|
||||||
return Ok(());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect original status codes for blocked IPs
|
|
||||||
let file = std::fs::File::open(input)?;
|
let file = std::fs::File::open(input)?;
|
||||||
let reader = std::io::BufReader::new(file);
|
let reader = std::io::BufReader::new(file);
|
||||||
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
|
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
|
||||||
@@ -215,47 +183,69 @@ fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
|
|||||||
};
|
};
|
||||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||||
if blocked_ips.contains(ip_str.as_str()) {
|
if blocked_ips.contains(ip_str.as_str()) {
|
||||||
ip_statuses
|
ip_statuses.entry(ip_str).or_default().push(entry.fields.status);
|
||||||
.entry(ip_str)
|
|
||||||
.or_default()
|
|
||||||
.push(entry.fields.status);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut suspects = Vec::new();
|
let mut fp = 0usize;
|
||||||
for (ip, statuses) in &ip_statuses {
|
let mut tp = 0usize;
|
||||||
|
for (_ip, statuses) in &ip_statuses {
|
||||||
let total = statuses.len();
|
let total = statuses.len();
|
||||||
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
|
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
|
||||||
let ok_pct = (ok_count as f64 / total as f64) * 100.0;
|
let ok_pct = ok_count as f64 / total as f64 * 100.0;
|
||||||
// If >60% of original responses were 2xx/3xx, this might be a false positive
|
|
||||||
if ok_pct > 60.0 {
|
if ok_pct > 60.0 {
|
||||||
let blocked = stats
|
fp += 1;
|
||||||
.ddos_blocked_ips
|
} else {
|
||||||
.get(ip)
|
tp += 1;
|
||||||
.copied()
|
|
||||||
.unwrap_or(0)
|
|
||||||
+ stats
|
|
||||||
.rate_limited_ips
|
|
||||||
.get(ip)
|
|
||||||
.copied()
|
|
||||||
.unwrap_or(0);
|
|
||||||
suspects.push((ip.clone(), total, ok_pct, blocked));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if suspects.is_empty() {
|
Ok((fp, tp))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run(args: ReplayArgs) -> Result<()> {
|
||||||
|
eprintln!("Loading model from {}...", args.model_path);
|
||||||
|
eprintln!("Replaying {}...\n", args.input);
|
||||||
|
|
||||||
|
let result = replay_and_evaluate(&args)?;
|
||||||
|
|
||||||
|
let total = result.total;
|
||||||
|
eprintln!("═══ Replay Results ═══════════════════════════════════════");
|
||||||
|
eprintln!(" Total requests: {total}");
|
||||||
|
eprintln!(" Skipped (parse): {}", result.skipped);
|
||||||
|
eprintln!(" Allowed: {} ({:.1}%)", result.allowed, pct(result.allowed, total));
|
||||||
|
eprintln!(" DDoS blocked: {} ({:.1}%)", result.ddos_blocked, pct(result.ddos_blocked, total));
|
||||||
|
if args.rate_limit {
|
||||||
|
eprintln!(" Rate limited: {} ({:.1}%)", result.rate_limited, pct(result.rate_limited, total));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.ddos_blocked_ips.is_empty() {
|
||||||
|
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
|
||||||
|
let mut sorted: Vec<_> = result.ddos_blocked_ips.iter().collect();
|
||||||
|
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||||
|
for (ip, count) in sorted.iter().take(20) {
|
||||||
|
eprintln!(" {:<40} {} reqs blocked", ip, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.rate_limited_ips.is_empty() {
|
||||||
|
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
|
||||||
|
let mut sorted: Vec<_> = result.rate_limited_ips.iter().collect();
|
||||||
|
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||||
|
for (ip, count) in sorted.iter().take(20) {
|
||||||
|
eprintln!(" {:<40} {} reqs limited", ip, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
eprintln!("\n── False positive check ──────────────────────────────────");
|
||||||
|
if result.false_positive_ips == 0 {
|
||||||
eprintln!(" No likely false positives found.");
|
eprintln!(" No likely false positives found.");
|
||||||
} else {
|
} else {
|
||||||
suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
|
eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses", result.false_positive_ips);
|
||||||
eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses:", suspects.len());
|
|
||||||
for (ip, total, ok_pct, blocked) in suspects.iter().take(15) {
|
|
||||||
eprintln!(
|
|
||||||
" {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked",
|
|
||||||
ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
eprintln!(" True positive IPs: {}", result.true_positive_ips);
|
||||||
|
|
||||||
|
eprintln!("══════════════════════════════════════════════════════════");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
316
src/ensemble/replay.rs
Normal file
316
src/ensemble/replay.rs
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
//! Replay audit logs through the ensemble models (scanner + DDoS).
|
||||||
|
|
||||||
|
use crate::ddos::audit_log::{self, AuditLog};
|
||||||
|
use crate::ddos::features::{method_to_u8, LogIpState};
|
||||||
|
use crate::ddos::model::DDoSAction;
|
||||||
|
use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath};
|
||||||
|
use crate::ensemble::scanner::{scanner_ensemble_predict, EnsemblePath};
|
||||||
|
use crate::scanner::features::{self, fx_hash_bytes};
|
||||||
|
use crate::scanner::model::ScannerAction;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use rustc_hash::{FxHashMap, FxHashSet};
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::io::BufRead;
|
||||||
|
|
||||||
|
/// Arguments for replaying an audit log through the ensemble models.
#[derive(Debug, Clone)]
pub struct ReplayEnsembleArgs {
    /// Path of the audit log file to replay; it is read line by line and each
    /// non-blank line is parsed as one JSON `AuditLog` entry.
    pub input: String,
    /// Window length in seconds handed to the per-IP DDoS feature extraction.
    pub window_secs: u64,
    /// Minimum number of log entries an IP must have before the DDoS ensemble
    /// evaluates it; IPs with fewer entries are counted as skipped.
    pub min_events: usize,
}
|
||||||
|
|
||||||
|
pub fn run(args: ReplayEnsembleArgs) -> Result<()> {
|
||||||
|
eprintln!("replaying {} through ensemble models...\n", args.input);
|
||||||
|
|
||||||
|
let file =
|
||||||
|
std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?;
|
||||||
|
let reader = std::io::BufReader::new(file);
|
||||||
|
|
||||||
|
// --- Parse all entries ---
|
||||||
|
let mut entries: Vec<AuditLog> = Vec::new();
|
||||||
|
let mut parse_errors = 0u64;
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line?;
|
||||||
|
if line.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match serde_json::from_str::<AuditLog>(&line) {
|
||||||
|
Ok(e) => entries.push(e),
|
||||||
|
Err(_) => parse_errors += 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let total = entries.len() as u64;
|
||||||
|
eprintln!("parsed {} entries ({} parse errors)\n", total, parse_errors);
|
||||||
|
|
||||||
|
// --- Scanner replay ---
|
||||||
|
eprintln!("═══ Scanner Ensemble ═════════════════════════════════════");
|
||||||
|
replay_scanner(&entries);
|
||||||
|
|
||||||
|
// --- DDoS replay ---
|
||||||
|
eprintln!("\n═══ DDoS Ensemble ═══════════════════════════════════════");
|
||||||
|
replay_ddos(&entries, args.window_secs as f64, args.min_events);
|
||||||
|
|
||||||
|
eprintln!("\n══════════════════════════════════════════════════════════");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replay_scanner(entries: &[AuditLog]) {
|
||||||
|
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
||||||
|
.iter()
|
||||||
|
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
||||||
|
.collect();
|
||||||
|
let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
|
||||||
|
.iter()
|
||||||
|
.map(|e| fx_hash_bytes(e.as_bytes()))
|
||||||
|
.collect();
|
||||||
|
let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
|
||||||
|
for e in entries {
|
||||||
|
let prefix = e.fields.host.split('.').next().unwrap_or("");
|
||||||
|
log_hosts.insert(fx_hash_bytes(prefix.as_bytes()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut total = 0u64;
|
||||||
|
let mut blocked = 0u64;
|
||||||
|
let mut allowed = 0u64;
|
||||||
|
let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
|
||||||
|
let mut blocked_examples: Vec<(String, String, f64)> = Vec::new(); // (path, reason, score)
|
||||||
|
let mut fp_candidates: Vec<(String, u16, f64)> = Vec::new(); // blocked but had 2xx status
|
||||||
|
|
||||||
|
for e in entries {
|
||||||
|
let f = &e.fields;
|
||||||
|
let host_prefix = f.host.split('.').next().unwrap_or("");
|
||||||
|
let has_cookies = f.has_cookies.unwrap_or(false);
|
||||||
|
let has_referer = f
|
||||||
|
.referer
|
||||||
|
.as_ref()
|
||||||
|
.map(|r| r != "-" && !r.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let has_accept_language = f
|
||||||
|
.accept_language
|
||||||
|
.as_ref()
|
||||||
|
.map(|a| a != "-" && !a.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let feats = features::extract_features_f32(
|
||||||
|
&f.method,
|
||||||
|
&f.path,
|
||||||
|
host_prefix,
|
||||||
|
has_cookies,
|
||||||
|
has_referer,
|
||||||
|
has_accept_language,
|
||||||
|
"-",
|
||||||
|
&f.user_agent,
|
||||||
|
f.content_length,
|
||||||
|
&fragment_hashes,
|
||||||
|
&extension_hashes,
|
||||||
|
&log_hosts,
|
||||||
|
);
|
||||||
|
|
||||||
|
let verdict = scanner_ensemble_predict(&feats);
|
||||||
|
total += 1;
|
||||||
|
|
||||||
|
match verdict.path {
|
||||||
|
EnsemblePath::TreeBlock => path_counts[0] += 1,
|
||||||
|
EnsemblePath::TreeAllow => path_counts[1] += 1,
|
||||||
|
EnsemblePath::Mlp => path_counts[2] += 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
match verdict.action {
|
||||||
|
ScannerAction::Block => {
|
||||||
|
blocked += 1;
|
||||||
|
if blocked_examples.len() < 20 {
|
||||||
|
blocked_examples.push((
|
||||||
|
f.path.clone(),
|
||||||
|
verdict.reason.to_string(),
|
||||||
|
verdict.score,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if (200..400).contains(&f.status) {
|
||||||
|
fp_candidates.push((f.path.clone(), f.status, verdict.score));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ScannerAction::Allow => allowed += 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let pct = |n: u64| {
|
||||||
|
if total == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
n as f64 / total as f64 * 100.0
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
eprintln!(" total: {total}");
|
||||||
|
eprintln!(
|
||||||
|
" blocked: {} ({:.1}%)",
|
||||||
|
blocked,
|
||||||
|
pct(blocked)
|
||||||
|
);
|
||||||
|
eprintln!(
|
||||||
|
" allowed: {} ({:.1}%)",
|
||||||
|
allowed,
|
||||||
|
pct(allowed)
|
||||||
|
);
|
||||||
|
eprintln!(
|
||||||
|
" paths: tree_block={} tree_allow={} mlp={}",
|
||||||
|
path_counts[0], path_counts[1], path_counts[2]
|
||||||
|
);
|
||||||
|
|
||||||
|
if !blocked_examples.is_empty() {
|
||||||
|
eprintln!("\n blocked examples (first 20):");
|
||||||
|
for (path, reason, score) in &blocked_examples {
|
||||||
|
eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let fp_count = fp_candidates.len();
|
||||||
|
if fp_count > 0 {
|
||||||
|
eprintln!(
|
||||||
|
"\n potential false positives (blocked but had 2xx/3xx): {}",
|
||||||
|
fp_count
|
||||||
|
);
|
||||||
|
for (path, status, score) in fp_candidates.iter().take(10) {
|
||||||
|
eprintln!(
|
||||||
|
" {:<50} status={status} score={score:.3}",
|
||||||
|
truncate(path, 50)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) {
|
||||||
|
fn fx_hash(s: &str) -> u64 {
|
||||||
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
|
s.hash(&mut h);
|
||||||
|
h.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate per-IP state.
|
||||||
|
let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
|
||||||
|
|
||||||
|
for e in entries {
|
||||||
|
let f = &e.fields;
|
||||||
|
let ip = audit_log::strip_port(&f.client_ip).to_string();
|
||||||
|
let state = ip_states.entry(ip).or_default();
|
||||||
|
let ts = state.timestamps.len() as f64;
|
||||||
|
state.timestamps.push(ts);
|
||||||
|
state.methods.push(method_to_u8(&f.method));
|
||||||
|
state.path_hashes.push(fx_hash(&f.path));
|
||||||
|
state.host_hashes.push(fx_hash(&f.host));
|
||||||
|
state.user_agent_hashes.push(fx_hash(&f.user_agent));
|
||||||
|
state.statuses.push(f.status);
|
||||||
|
state
|
||||||
|
.durations
|
||||||
|
.push(f.duration_ms.min(u32::MAX as u64) as u32);
|
||||||
|
state
|
||||||
|
.content_lengths
|
||||||
|
.push(f.content_length.min(u32::MAX as u64) as u32);
|
||||||
|
state
|
||||||
|
.has_cookies
|
||||||
|
.push(f.has_cookies.unwrap_or(false));
|
||||||
|
state.has_referer.push(
|
||||||
|
f.referer
|
||||||
|
.as_deref()
|
||||||
|
.map(|r| r != "-")
|
||||||
|
.unwrap_or(false),
|
||||||
|
);
|
||||||
|
state.has_accept_language.push(
|
||||||
|
f.accept_language
|
||||||
|
.as_deref()
|
||||||
|
.map(|a| a != "-")
|
||||||
|
.unwrap_or(false),
|
||||||
|
);
|
||||||
|
state
|
||||||
|
.suspicious_paths
|
||||||
|
.push(crate::ddos::features::is_suspicious_path(&f.path));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut total_ips = 0u64;
|
||||||
|
let mut blocked_ips = 0u64;
|
||||||
|
let mut allowed_ips = 0u64;
|
||||||
|
let mut skipped_ips = 0u64;
|
||||||
|
let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
|
||||||
|
let mut blocked_details: Vec<(String, usize, f64, &'static str)> = Vec::new();
|
||||||
|
|
||||||
|
for (ip, state) in &ip_states {
|
||||||
|
let n = state.timestamps.len();
|
||||||
|
if n < min_events {
|
||||||
|
skipped_ips += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
total_ips += 1;
|
||||||
|
|
||||||
|
let fv = state.extract_features_for_window(0, n, window_secs);
|
||||||
|
let fv_f32: [f32; 14] = {
|
||||||
|
let mut arr = [0.0f32; 14];
|
||||||
|
for i in 0..14 {
|
||||||
|
arr[i] = fv[i] as f32;
|
||||||
|
}
|
||||||
|
arr
|
||||||
|
};
|
||||||
|
|
||||||
|
let verdict = ddos_ensemble_predict(&fv_f32);
|
||||||
|
|
||||||
|
match verdict.path {
|
||||||
|
DDoSEnsemblePath::TreeBlock => path_counts[0] += 1,
|
||||||
|
DDoSEnsemblePath::TreeAllow => path_counts[1] += 1,
|
||||||
|
DDoSEnsemblePath::Mlp => path_counts[2] += 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
match verdict.action {
|
||||||
|
DDoSAction::Block => {
|
||||||
|
blocked_ips += 1;
|
||||||
|
if blocked_details.len() < 30 {
|
||||||
|
blocked_details.push((ip.clone(), n, verdict.score, verdict.reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DDoSAction::Allow => allowed_ips += 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let pct = |n: u64, d: u64| {
|
||||||
|
if d == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
n as f64 / d as f64 * 100.0
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
eprintln!(" unique IPs: {} ({} skipped < {} events)", ip_states.len(), skipped_ips, min_events);
|
||||||
|
eprintln!(" evaluated: {total_ips}");
|
||||||
|
eprintln!(
|
||||||
|
" blocked: {} ({:.1}%)",
|
||||||
|
blocked_ips,
|
||||||
|
pct(blocked_ips, total_ips)
|
||||||
|
);
|
||||||
|
eprintln!(
|
||||||
|
" allowed: {} ({:.1}%)",
|
||||||
|
allowed_ips,
|
||||||
|
pct(allowed_ips, total_ips)
|
||||||
|
);
|
||||||
|
eprintln!(
|
||||||
|
" paths: tree_block={} tree_allow={} mlp={}",
|
||||||
|
path_counts[0], path_counts[1], path_counts[2]
|
||||||
|
);
|
||||||
|
|
||||||
|
if !blocked_details.is_empty() {
|
||||||
|
eprintln!("\n blocked IPs (up to 30):");
|
||||||
|
let mut sorted = blocked_details;
|
||||||
|
sorted.sort_by(|a, b| b.1.cmp(&a.1));
|
||||||
|
for (ip, reqs, score, reason) in &sorted {
|
||||||
|
eprintln!(
|
||||||
|
" {:<40} {} reqs score={:.3} {reason}",
|
||||||
|
ip, reqs, score
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shorten `s` for display so that it occupies at most `max` bytes.
///
/// Strings of `max` bytes or fewer are returned unchanged. Longer strings are
/// cut to at most `max - 3` bytes and suffixed with `"..."`. The cut is moved
/// back to the nearest UTF-8 character boundary, so a multi-byte character is
/// never split and the function cannot panic on non-ASCII input.
///
/// When `max` is smaller than 3 there is no room for any prefix and the result
/// is just `"..."`, which is longer than `max`.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }

    // saturating_sub guards against underflow when max < 3.
    let mut cut = max.saturating_sub(3);
    // Here cut <= max < s.len(), and offset 0 is always a boundary, so this
    // loop terminates and the slice below is in range.
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
|
||||||
443
src/main.rs
443
src/main.rs
@@ -2,7 +2,7 @@ mod cert;
|
|||||||
mod telemetry;
|
mod telemetry;
|
||||||
mod watcher;
|
mod watcher;
|
||||||
|
|
||||||
use sunbeam_proxy::{acme, config};
|
use sunbeam_proxy::{acme, autotune, config};
|
||||||
use sunbeam_proxy::proxy::SunbeamProxy;
|
use sunbeam_proxy::proxy::SunbeamProxy;
|
||||||
use sunbeam_proxy::ddos;
|
use sunbeam_proxy::ddos;
|
||||||
use sunbeam_proxy::rate_limit;
|
use sunbeam_proxy::rate_limit;
|
||||||
@@ -32,32 +32,10 @@ enum Commands {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
upgrade: bool,
|
upgrade: bool,
|
||||||
},
|
},
|
||||||
/// Replay audit logs through the DDoS detector and rate limiter
|
/// Replay audit logs through detection models
|
||||||
ReplayDdos {
|
Replay {
|
||||||
/// Path to audit log JSONL file
|
#[command(subcommand)]
|
||||||
#[arg(short, long)]
|
mode: ReplayMode,
|
||||||
input: String,
|
|
||||||
/// Path to trained model file
|
|
||||||
#[arg(short, long, default_value = "ddos_model.bin")]
|
|
||||||
model: String,
|
|
||||||
/// Optional config file (for rate limit settings)
|
|
||||||
#[arg(short, long)]
|
|
||||||
config: Option<String>,
|
|
||||||
/// KNN k parameter
|
|
||||||
#[arg(long, default_value = "5")]
|
|
||||||
k: usize,
|
|
||||||
/// Attack threshold
|
|
||||||
#[arg(long, default_value = "0.6")]
|
|
||||||
threshold: f64,
|
|
||||||
/// Sliding window size in seconds
|
|
||||||
#[arg(long, default_value = "60")]
|
|
||||||
window_secs: u64,
|
|
||||||
/// Minimum events per IP before classification
|
|
||||||
#[arg(long, default_value = "10")]
|
|
||||||
min_events: usize,
|
|
||||||
/// Also run rate limiter during replay
|
|
||||||
#[arg(long)]
|
|
||||||
rate_limit: bool,
|
|
||||||
},
|
},
|
||||||
/// Train a DDoS detection model from audit logs
|
/// Train a DDoS detection model from audit logs
|
||||||
TrainDdos {
|
TrainDdos {
|
||||||
@@ -107,31 +85,180 @@ enum Commands {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
csic: bool,
|
csic: bool,
|
||||||
},
|
},
|
||||||
|
/// Bayesian hyperparameter optimization for DDoS model
|
||||||
|
AutotuneDdos {
|
||||||
|
/// Path to audit log JSONL file
|
||||||
|
#[arg(short, long)]
|
||||||
|
input: String,
|
||||||
|
/// Output best model file path
|
||||||
|
#[arg(short, long, default_value = "ddos_model_best.bin")]
|
||||||
|
output: String,
|
||||||
|
/// Number of optimization trials
|
||||||
|
#[arg(long, default_value = "200")]
|
||||||
|
trials: usize,
|
||||||
|
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
|
||||||
|
#[arg(long, default_value = "1.0")]
|
||||||
|
beta: f64,
|
||||||
|
/// JSONL file to log each trial's parameters and results
|
||||||
|
#[arg(long)]
|
||||||
|
trial_log: Option<String>,
|
||||||
|
},
|
||||||
|
/// Download and cache upstream datasets (CIC-IDS2017)
|
||||||
|
DownloadDatasets,
|
||||||
|
/// Prepare a unified training dataset from multiple sources
|
||||||
|
PrepareDataset {
|
||||||
|
/// Path to audit log JSONL file
|
||||||
|
#[arg(short, long)]
|
||||||
|
input: String,
|
||||||
|
/// Path to OWASP ModSecurity audit log file (optional extra data)
|
||||||
|
#[arg(long)]
|
||||||
|
owasp: Option<String>,
|
||||||
|
/// Directory containing .txt wordlists (optional, enhances synthetic scanner)
|
||||||
|
#[arg(long)]
|
||||||
|
wordlists: Option<String>,
|
||||||
|
/// Output dataset file path
|
||||||
|
#[arg(short, long, default_value = "dataset.bin")]
|
||||||
|
output: String,
|
||||||
|
/// Random seed
|
||||||
|
#[arg(long, default_value = "42")]
|
||||||
|
seed: u64,
|
||||||
|
/// Path to heuristics.toml for auto-labeling production logs
|
||||||
|
#[arg(long)]
|
||||||
|
heuristics: Option<String>,
|
||||||
|
},
|
||||||
|
#[cfg(feature = "training")]
|
||||||
|
/// Train scanner ensemble (decision tree + MLP) from prepared dataset
|
||||||
|
TrainMlpScanner {
|
||||||
|
/// Path to prepared dataset (.bin)
|
||||||
|
#[arg(short = 'd', long)]
|
||||||
|
dataset: String,
|
||||||
|
/// Output directory for generated weight files
|
||||||
|
#[arg(short, long, default_value = "src/ensemble/gen")]
|
||||||
|
output_dir: String,
|
||||||
|
/// Hidden layer dimension
|
||||||
|
#[arg(long, default_value = "32")]
|
||||||
|
hidden_dim: usize,
|
||||||
|
/// Training epochs
|
||||||
|
#[arg(long, default_value = "100")]
|
||||||
|
epochs: usize,
|
||||||
|
/// Learning rate
|
||||||
|
#[arg(long, default_value = "0.001")]
|
||||||
|
learning_rate: f64,
|
||||||
|
/// Batch size
|
||||||
|
#[arg(long, default_value = "64")]
|
||||||
|
batch_size: usize,
|
||||||
|
/// Max tree depth
|
||||||
|
#[arg(long, default_value = "6")]
|
||||||
|
tree_max_depth: usize,
|
||||||
|
/// Min purity for tree leaves (below -> Defer)
|
||||||
|
#[arg(long, default_value = "0.90")]
|
||||||
|
tree_min_purity: f32,
|
||||||
|
},
|
||||||
|
#[cfg(feature = "training")]
|
||||||
|
/// Train DDoS ensemble (decision tree + MLP) from prepared dataset
|
||||||
|
TrainMlpDdos {
|
||||||
|
#[arg(short = 'd', long)]
|
||||||
|
dataset: String,
|
||||||
|
#[arg(short, long, default_value = "src/ensemble/gen")]
|
||||||
|
output_dir: String,
|
||||||
|
#[arg(long, default_value = "32")]
|
||||||
|
hidden_dim: usize,
|
||||||
|
#[arg(long, default_value = "100")]
|
||||||
|
epochs: usize,
|
||||||
|
#[arg(long, default_value = "0.001")]
|
||||||
|
learning_rate: f64,
|
||||||
|
#[arg(long, default_value = "64")]
|
||||||
|
batch_size: usize,
|
||||||
|
#[arg(long, default_value = "6")]
|
||||||
|
tree_max_depth: usize,
|
||||||
|
#[arg(long, default_value = "0.90")]
|
||||||
|
tree_min_purity: f32,
|
||||||
|
},
|
||||||
|
/// Bayesian hyperparameter optimization for scanner model
|
||||||
|
AutotuneScanner {
|
||||||
|
/// Path to audit log JSONL file
|
||||||
|
#[arg(short, long)]
|
||||||
|
input: String,
|
||||||
|
/// Output best model file path
|
||||||
|
#[arg(short, long, default_value = "scanner_model_best.bin")]
|
||||||
|
output: String,
|
||||||
|
/// Directory (or file) containing .txt wordlists of scanner paths
|
||||||
|
#[arg(long)]
|
||||||
|
wordlists: Option<String>,
|
||||||
|
/// Include CSIC 2010 dataset as base training data
|
||||||
|
#[arg(long)]
|
||||||
|
csic: bool,
|
||||||
|
/// Number of optimization trials
|
||||||
|
#[arg(long, default_value = "200")]
|
||||||
|
trials: usize,
|
||||||
|
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
|
||||||
|
#[arg(long, default_value = "1.0")]
|
||||||
|
beta: f64,
|
||||||
|
/// JSONL file to log each trial's parameters and results
|
||||||
|
#[arg(long)]
|
||||||
|
trial_log: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand)]
|
||||||
|
enum ReplayMode {
|
||||||
|
/// Replay through ensemble models (scanner + DDoS)
|
||||||
|
Ensemble {
|
||||||
|
/// Path to audit log JSONL file
|
||||||
|
#[arg(short, long)]
|
||||||
|
input: String,
|
||||||
|
/// Sliding window size in seconds
|
||||||
|
#[arg(long, default_value = "60")]
|
||||||
|
window_secs: u64,
|
||||||
|
/// Minimum events per IP before DDoS classification
|
||||||
|
#[arg(long, default_value = "5")]
|
||||||
|
min_events: usize,
|
||||||
|
},
|
||||||
|
/// Replay through legacy KNN DDoS detector
|
||||||
|
Ddos {
|
||||||
|
/// Path to audit log JSONL file
|
||||||
|
#[arg(short, long)]
|
||||||
|
input: String,
|
||||||
|
/// Path to trained model file
|
||||||
|
#[arg(short, long, default_value = "ddos_model.bin")]
|
||||||
|
model: String,
|
||||||
|
/// Optional config file (for rate limit settings)
|
||||||
|
#[arg(short, long)]
|
||||||
|
config: Option<String>,
|
||||||
|
/// KNN k parameter
|
||||||
|
#[arg(long, default_value = "5")]
|
||||||
|
k: usize,
|
||||||
|
/// Attack threshold
|
||||||
|
#[arg(long, default_value = "0.6")]
|
||||||
|
threshold: f64,
|
||||||
|
/// Sliding window size in seconds
|
||||||
|
#[arg(long, default_value = "60")]
|
||||||
|
window_secs: u64,
|
||||||
|
/// Minimum events per IP before classification
|
||||||
|
#[arg(long, default_value = "10")]
|
||||||
|
min_events: usize,
|
||||||
|
/// Also run rate limiter during replay
|
||||||
|
#[arg(long)]
|
||||||
|
rate_limit: bool,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
|
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
|
||||||
Commands::Serve { upgrade } => run_serve(upgrade),
|
Commands::Serve { upgrade } => run_serve(upgrade),
|
||||||
Commands::ReplayDdos {
|
Commands::Replay { mode } => match mode {
|
||||||
input,
|
ReplayMode::Ensemble { input, window_secs, min_events } => {
|
||||||
model,
|
sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs {
|
||||||
config,
|
input, window_secs, min_events,
|
||||||
k,
|
})
|
||||||
threshold,
|
}
|
||||||
window_secs,
|
ReplayMode::Ddos { input, model, config, k, threshold, window_secs, min_events, rate_limit } => {
|
||||||
min_events,
|
ddos::replay::run(ddos::replay::ReplayArgs {
|
||||||
rate_limit,
|
input, model_path: model, config_path: config, k, threshold, window_secs, min_events, rate_limit,
|
||||||
} => ddos::replay::run(ddos::replay::ReplayArgs {
|
})
|
||||||
input,
|
}
|
||||||
model_path: model,
|
},
|
||||||
config_path: config,
|
|
||||||
k,
|
|
||||||
threshold,
|
|
||||||
window_secs,
|
|
||||||
min_events,
|
|
||||||
rate_limit,
|
|
||||||
}),
|
|
||||||
Commands::TrainDdos {
|
Commands::TrainDdos {
|
||||||
input,
|
input,
|
||||||
output,
|
output,
|
||||||
@@ -166,6 +293,56 @@ fn main() -> Result<()> {
|
|||||||
threshold,
|
threshold,
|
||||||
csic,
|
csic,
|
||||||
}),
|
}),
|
||||||
|
Commands::DownloadDatasets => {
|
||||||
|
sunbeam_proxy::dataset::download::download_all()
|
||||||
|
},
|
||||||
|
Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics } => {
|
||||||
|
sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs {
|
||||||
|
input, owasp, wordlists, output, seed, heuristics,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
#[cfg(feature = "training")]
|
||||||
|
Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
|
||||||
|
sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs {
|
||||||
|
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
#[cfg(feature = "training")]
|
||||||
|
Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
|
||||||
|
sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs {
|
||||||
|
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
Commands::AutotuneDdos {
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
trials,
|
||||||
|
beta,
|
||||||
|
trial_log,
|
||||||
|
} => autotune::ddos::run_autotune(autotune::ddos::AutotuneDdosArgs {
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
trials,
|
||||||
|
beta,
|
||||||
|
trial_log,
|
||||||
|
}),
|
||||||
|
Commands::AutotuneScanner {
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
wordlists,
|
||||||
|
csic,
|
||||||
|
trials,
|
||||||
|
beta,
|
||||||
|
trial_log,
|
||||||
|
} => autotune::scanner::run_autotune(autotune::scanner::AutotuneScannerArgs {
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
wordlists,
|
||||||
|
csic,
|
||||||
|
trials,
|
||||||
|
beta,
|
||||||
|
trial_log,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,26 +366,42 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
|||||||
// 2. Load DDoS detection model if configured.
|
// 2. Load DDoS detection model if configured.
|
||||||
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
|
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
|
||||||
if ddos_cfg.enabled {
|
if ddos_cfg.enabled {
|
||||||
match ddos::model::TrainedModel::load(
|
if ddos_cfg.use_ensemble {
|
||||||
std::path::Path::new(&ddos_cfg.model_path),
|
// Ensemble path: compiled-in weights, no model file needed.
|
||||||
Some(ddos_cfg.k),
|
// We still need a TrainedModel for the struct, but it won't be used.
|
||||||
Some(ddos_cfg.threshold),
|
let dummy_model = ddos::model::TrainedModel::empty(ddos_cfg.k, ddos_cfg.threshold);
|
||||||
) {
|
let detector = Arc::new(ddos::detector::DDoSDetector::new_ensemble(dummy_model, ddos_cfg));
|
||||||
Ok(model) => {
|
tracing::info!(
|
||||||
let point_count = model.point_count();
|
k = ddos_cfg.k,
|
||||||
let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg));
|
threshold = ddos_cfg.threshold,
|
||||||
tracing::info!(
|
"DDoS ensemble detector enabled"
|
||||||
points = point_count,
|
);
|
||||||
k = ddos_cfg.k,
|
Some(detector)
|
||||||
threshold = ddos_cfg.threshold,
|
} else if let Some(ref model_path) = ddos_cfg.model_path {
|
||||||
"DDoS detector loaded"
|
match ddos::model::TrainedModel::load(
|
||||||
);
|
std::path::Path::new(model_path),
|
||||||
Some(detector)
|
Some(ddos_cfg.k),
|
||||||
}
|
Some(ddos_cfg.threshold),
|
||||||
Err(e) => {
|
) {
|
||||||
tracing::warn!(error = %e, "failed to load DDoS model; detection disabled");
|
Ok(model) => {
|
||||||
None
|
let point_count = model.point_count();
|
||||||
|
let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg));
|
||||||
|
tracing::info!(
|
||||||
|
points = point_count,
|
||||||
|
k = ddos_cfg.k,
|
||||||
|
threshold = ddos_cfg.threshold,
|
||||||
|
"DDoS detector loaded"
|
||||||
|
);
|
||||||
|
Some(detector)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "failed to load DDoS model; detection disabled");
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
tracing::warn!("DDoS enabled but no model_path and use_ensemble=false; detection disabled");
|
||||||
|
None
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@@ -245,54 +438,84 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
|||||||
// 2c. Load scanner model if configured.
|
// 2c. Load scanner model if configured.
|
||||||
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
|
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
|
||||||
if scanner_cfg.enabled {
|
if scanner_cfg.enabled {
|
||||||
match scanner::model::ScannerModel::load(std::path::Path::new(&scanner_cfg.model_path)) {
|
if scanner_cfg.use_ensemble {
|
||||||
Ok(mut model) => {
|
// Ensemble path: compiled-in weights, no model file needed.
|
||||||
let fragment_count = model.fragments.len();
|
let detector = scanner::detector::ScannerDetector::new_ensemble(&cfg.routes);
|
||||||
model.threshold = scanner_cfg.threshold;
|
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
|
||||||
let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes);
|
|
||||||
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
|
|
||||||
|
|
||||||
// Start bot allowlist if rules are configured.
|
// Start bot allowlist if rules are configured.
|
||||||
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
|
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
|
||||||
let al = scanner::allowlist::BotAllowlist::spawn(
|
let al = scanner::allowlist::BotAllowlist::spawn(
|
||||||
&scanner_cfg.allowlist,
|
&scanner_cfg.allowlist,
|
||||||
scanner_cfg.bot_cache_ttl_secs,
|
scanner_cfg.bot_cache_ttl_secs,
|
||||||
);
|
|
||||||
tracing::info!(
|
|
||||||
rules = scanner_cfg.allowlist.len(),
|
|
||||||
"bot allowlist enabled"
|
|
||||||
);
|
|
||||||
Some(al)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start background file watcher for hot-reload.
|
|
||||||
if scanner_cfg.poll_interval_secs > 0 {
|
|
||||||
let watcher_handle = handle.clone();
|
|
||||||
let model_path = std::path::PathBuf::from(&scanner_cfg.model_path);
|
|
||||||
let threshold = scanner_cfg.threshold;
|
|
||||||
let routes = cfg.routes.clone();
|
|
||||||
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
|
|
||||||
std::thread::spawn(move || {
|
|
||||||
scanner::watcher::watch_scanner_model(
|
|
||||||
watcher_handle, model_path, threshold, routes, interval,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::info!(
|
|
||||||
fragments = fragment_count,
|
|
||||||
threshold = scanner_cfg.threshold,
|
|
||||||
poll_interval_secs = scanner_cfg.poll_interval_secs,
|
|
||||||
"scanner detector loaded"
|
|
||||||
);
|
);
|
||||||
(Some(handle), bot_allowlist)
|
tracing::info!(
|
||||||
}
|
rules = scanner_cfg.allowlist.len(),
|
||||||
Err(e) => {
|
"bot allowlist enabled"
|
||||||
tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled");
|
);
|
||||||
(None, None)
|
Some(al)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
threshold = scanner_cfg.threshold,
|
||||||
|
"scanner ensemble detector enabled"
|
||||||
|
);
|
||||||
|
(Some(handle), bot_allowlist)
|
||||||
|
} else if let Some(ref model_path) = scanner_cfg.model_path {
|
||||||
|
match scanner::model::ScannerModel::load(std::path::Path::new(model_path)) {
|
||||||
|
Ok(mut model) => {
|
||||||
|
let fragment_count = model.fragments.len();
|
||||||
|
model.threshold = scanner_cfg.threshold;
|
||||||
|
let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes);
|
||||||
|
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
|
||||||
|
|
||||||
|
// Start bot allowlist if rules are configured.
|
||||||
|
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
|
||||||
|
let al = scanner::allowlist::BotAllowlist::spawn(
|
||||||
|
&scanner_cfg.allowlist,
|
||||||
|
scanner_cfg.bot_cache_ttl_secs,
|
||||||
|
);
|
||||||
|
tracing::info!(
|
||||||
|
rules = scanner_cfg.allowlist.len(),
|
||||||
|
"bot allowlist enabled"
|
||||||
|
);
|
||||||
|
Some(al)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start background file watcher for hot-reload.
|
||||||
|
if scanner_cfg.poll_interval_secs > 0 {
|
||||||
|
let watcher_handle = handle.clone();
|
||||||
|
let watcher_model_path = std::path::PathBuf::from(model_path);
|
||||||
|
let threshold = scanner_cfg.threshold;
|
||||||
|
let routes = cfg.routes.clone();
|
||||||
|
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
|
||||||
|
std::thread::spawn(move || {
|
||||||
|
scanner::watcher::watch_scanner_model(
|
||||||
|
watcher_handle, watcher_model_path, threshold, routes, interval,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
fragments = fragment_count,
|
||||||
|
threshold = scanner_cfg.threshold,
|
||||||
|
poll_interval_secs = scanner_cfg.poll_interval_secs,
|
||||||
|
"scanner detector loaded"
|
||||||
|
);
|
||||||
|
(Some(handle), bot_allowlist)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled");
|
||||||
|
(None, None)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
tracing::warn!("scanner enabled but no model_path and use_ensemble=false; scanner detection disabled");
|
||||||
|
(None, None)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
(None, None)
|
(None, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user