chore: update scanner/ddos trainers, benchmarks, and tests

Expand DDoS feature vector to 14 dimensions (cookie_ratio, referer_ratio,
accept_language_ratio, suspicious_path_ratio). Add heuristic auto-labeling
to DDoS trainer. Update benchmarks and tests to match new feature vectors.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:21 +00:00
parent 905fd78299
commit 5daed3ecb0
5 changed files with 599 additions and 160 deletions

40
benches/ddos_bench.rs Normal file
View File

@@ -0,0 +1,40 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use sunbeam_proxy::ensemble::ddos::ddos_ensemble_predict;
use sunbeam_proxy::ensemble::gen::ddos_weights;
use sunbeam_proxy::ensemble::mlp::mlp_predict_32;
use sunbeam_proxy::ensemble::tree::tree_predict;
/// Times one complete DDoS ensemble prediction (tree + MLP + combination).
fn bench_ensemble_ddos_full(c: &mut Criterion) {
    // Representative raw (un-normalized) 14-dimensional feature vector.
    let sample: [f32; 14] =
        [5.0, 10.0, 2.0, 0.1, 50.0, 0.5, 3.0, 0.3, 500.0, 2.0, 0.8, 0.7, 0.9, 0.1];
    c.bench_function("ensemble::ddos full predict", |b| {
        b.iter(|| ddos_ensemble_predict(black_box(&sample)));
    });
}
/// Times only the decision-tree stage of the DDoS ensemble.
fn bench_ensemble_ddos_tree_only(c: &mut Criterion) {
    // Mid-range input: every feature at 0.5.
    let mid: [f32; 14] = [0.5; 14];
    c.bench_function("ensemble::ddos tree_only", |b| {
        b.iter(|| tree_predict(black_box(&ddos_weights::TREE_NODES), black_box(&mid)));
    });
}
/// Times only the MLP stage of the DDoS ensemble.
fn bench_ensemble_ddos_mlp_only(c: &mut Criterion) {
    // Mid-range input: every feature at 0.5.
    let mid: [f32; 14] = [0.5; 14];
    c.bench_function("ensemble::ddos mlp_only", |b| {
        b.iter(|| {
            mlp_predict_32::<14>(
                black_box(&ddos_weights::W1),
                black_box(&ddos_weights::B1),
                black_box(&ddos_weights::W2),
                black_box(ddos_weights::B2),
                black_box(&mid),
            )
        });
    });
}
// Register every DDoS ensemble benchmark with criterion's harness and
// generate the benchmark binary's entry point.
criterion_group!(
    benches,
    bench_ensemble_ddos_full,
    bench_ensemble_ddos_tree_only,
    bench_ensemble_ddos_mlp_only,
);
criterion_main!(benches);

View File

@@ -1,5 +1,9 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion}; use criterion::{black_box, criterion_group, criterion_main, Criterion};
use sunbeam_proxy::config::RouteConfig; use sunbeam_proxy::config::RouteConfig;
use sunbeam_proxy::ensemble::gen::scanner_weights;
use sunbeam_proxy::ensemble::mlp::mlp_predict_32;
use sunbeam_proxy::ensemble::scanner::scanner_ensemble_predict;
use sunbeam_proxy::ensemble::tree::tree_predict;
use sunbeam_proxy::scanner::detector::ScannerDetector; use sunbeam_proxy::scanner::detector::ScannerDetector;
use sunbeam_proxy::scanner::features::{ use sunbeam_proxy::scanner::features::{
self, fx_hash_bytes, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS, self, fx_hash_bytes, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS,
@@ -250,6 +254,34 @@ fn bench_extract_features(c: &mut Criterion) {
}); });
} }
/// Times one complete scanner ensemble prediction.
fn bench_ensemble_scanner_full(c: &mut Criterion) {
    // Raw features simulating a scanner probe
    let probe: [f32; 12] = [0.8, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0];
    c.bench_function("ensemble::scanner full predict", |b| {
        b.iter(|| scanner_ensemble_predict(black_box(&probe)));
    });
}
/// Times only the decision-tree stage of the scanner ensemble.
fn bench_ensemble_scanner_tree_only(c: &mut Criterion) {
    // NOTE(review): element 1 is 0.3 here but 3.0 in the full-predict bench —
    // confirm whether that difference is intentional or a transposed decimal.
    let probe: [f32; 12] = [0.8, 0.3, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0];
    c.bench_function("ensemble::scanner tree_only", |b| {
        b.iter(|| tree_predict(black_box(&scanner_weights::TREE_NODES), black_box(&probe)));
    });
}
/// Times only the MLP stage of the scanner ensemble.
fn bench_ensemble_scanner_mlp_only(c: &mut Criterion) {
    // Mid-range input: every feature at 0.5.
    let mid: [f32; 12] = [0.5; 12];
    c.bench_function("ensemble::scanner mlp_only", |b| {
        b.iter(|| {
            mlp_predict_32::<12>(
                black_box(&scanner_weights::W1),
                black_box(&scanner_weights::B1),
                black_box(&scanner_weights::W2),
                black_box(scanner_weights::B2),
                black_box(&mid),
            )
        });
    });
}
criterion_group!( criterion_group!(
benches, benches,
bench_check_normal_browser, bench_check_normal_browser,
@@ -260,5 +292,8 @@ criterion_group!(
bench_check_deep_path, bench_check_deep_path,
bench_check_api_legitimate, bench_check_api_legitimate,
bench_extract_features, bench_extract_features,
bench_ensemble_scanner_full,
bench_ensemble_scanner_tree_only,
bench_ensemble_scanner_mlp_only,
); );
criterion_main!(benches); criterion_main!(benches);

View File

@@ -41,6 +41,34 @@ fn default_no_cookies_threshold() -> f64 { 0.05 }
fn default_no_cookies_path_count() -> f64 { 20.0 } fn default_no_cookies_path_count() -> f64 { 20.0 }
fn default_min_events() -> usize { 10 } fn default_min_events() -> usize { 10 }
impl HeuristicThresholds {
    /// Build thresholds programmatically (e.g. from the autotune pipeline)
    /// instead of deserializing them from a heuristics TOML file.
    pub fn new(
        request_rate: f64,
        path_repetition: f64,
        error_rate: f64,
        suspicious_path_ratio: f64,
        no_cookies_threshold: f64,
        no_cookies_path_count: f64,
        min_events: usize,
    ) -> Self {
        Self {
            request_rate,
            path_repetition,
            error_rate,
            suspicious_path_ratio,
            no_cookies_threshold,
            no_cookies_path_count,
            min_events,
        }
    }
}
/// Outcome of a DDoS training run: the serialized KNN model plus the
/// attack/normal split of the labeled training points.
pub struct DdosTrainResult {
    pub model: SerializedModel,
    /// Number of training points labeled `TrafficLabel::Attack`.
    pub attack_count: usize,
    /// Number of training points labeled `TrafficLabel::Normal`.
    pub normal_count: usize,
}
pub struct TrainArgs { pub struct TrainArgs {
pub input: String, pub input: String,
pub output: String, pub output: String,
@@ -82,147 +110,15 @@ fn parse_timestamp(ts: &str) -> f64 {
} }
pub fn run(args: TrainArgs) -> Result<()> { /// Core training pipeline: parse logs, extract features, label IPs, build KNN model.
eprintln!("Parsing logs from {}...", args.input); pub fn train_model(args: &TrainArgs) -> Result<DdosTrainResult> {
let ip_states = parse_logs(&args.input)?;
// Parse logs into per-IP state
let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
let file = std::fs::File::open(&args.input)
.with_context(|| format!("opening {}", args.input))?;
let reader = std::io::BufReader::new(file);
let mut total_lines = 0u64;
let mut parse_errors = 0u64;
for line in reader.lines() {
let line = line?;
total_lines += 1;
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => {
parse_errors += 1;
continue;
}
};
// Skip non-audit entries
if entry.fields.method.is_empty() {
continue;
}
let ip = audit_log::strip_port(&entry.fields.client_ip).to_string();
let ts = parse_timestamp(&entry.timestamp);
let state = ip_states.entry(ip).or_default();
state.timestamps.push(ts);
state.methods.push(method_to_u8(&entry.fields.method));
state.path_hashes.push(fx_hash(&entry.fields.path));
state.host_hashes.push(fx_hash(&entry.fields.host));
state
.user_agent_hashes
.push(fx_hash(&entry.fields.user_agent));
state.statuses.push(entry.fields.status);
state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
state
.content_lengths
.push(entry.fields.content_length.min(u32::MAX as u64) as u32);
state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
state.has_referer.push(
entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
);
state.has_accept_language.push(
entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
);
state.suspicious_paths.push(
crate::ddos::features::is_suspicious_path(&entry.fields.path),
);
}
eprintln!(
"Parsed {} lines ({} errors), {} unique IPs",
total_lines,
parse_errors,
ip_states.len()
);
// Extract feature vectors per IP (using sliding windows)
let window_secs = args.window_secs as f64; let window_secs = args.window_secs as f64;
let mut ip_features: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default(); let ip_features = extract_ip_features(&ip_states, args.min_events, window_secs);
for (ip, state) in &ip_states {
let n = state.timestamps.len();
if n < args.min_events {
continue;
}
// Extract one feature vector per window
let mut features = Vec::new();
let mut start = 0;
for end in 1..n {
let span = state.timestamps[end] - state.timestamps[start];
if span >= window_secs || end == n - 1 {
let fv = state.extract_features_for_window(start, end + 1, window_secs);
features.push(fv);
start = end + 1;
}
}
if !features.is_empty() {
ip_features.insert(ip.clone(), features);
}
}
// Label IPs // Label IPs
let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default(); let ip_labels = label_ips(args, &ip_features)?;
if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
// IP list mode
let attack_ips: FxHashSet<String> = std::fs::read_to_string(attack_file)
.context("reading attack IPs file")?
.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect();
let normal_ips: FxHashSet<String> = std::fs::read_to_string(normal_file)
.context("reading normal IPs file")?
.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect();
for ip in ip_features.keys() {
if attack_ips.contains(ip) {
ip_labels.insert(ip.clone(), TrafficLabel::Attack);
} else if normal_ips.contains(ip) {
ip_labels.insert(ip.clone(), TrafficLabel::Normal);
}
}
} else if let Some(heuristics_file) = &args.heuristics {
// Heuristic auto-labeling
let heuristics_str = std::fs::read_to_string(heuristics_file)
.context("reading heuristics file")?;
let thresholds: HeuristicThresholds =
toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;
for (ip, features) in &ip_features {
// Use the aggregate (last/max) feature vector for labeling
let avg = average_features(features);
let is_attack = avg[0] > thresholds.request_rate // request_rate
|| avg[7] > thresholds.path_repetition // path_repetition
|| avg[3] > thresholds.error_rate // error_rate
|| avg[13] > thresholds.suspicious_path_ratio // suspicious_path_ratio
|| (avg[10] < thresholds.no_cookies_threshold // no cookies + high unique paths
&& avg[1] > thresholds.no_cookies_path_count);
ip_labels.insert(
ip.clone(),
if is_attack {
TrafficLabel::Attack
} else {
TrafficLabel::Normal
},
);
}
} else {
bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
}
// Build training dataset // Build training dataset
let mut all_points: Vec<FeatureVector> = Vec::new(); let mut all_points: Vec<FeatureVector> = Vec::new();
@@ -246,12 +142,6 @@ pub fn run(args: TrainArgs) -> Result<()> {
.filter(|&&l| l == TrafficLabel::Attack) .filter(|&&l| l == TrafficLabel::Attack)
.count(); .count();
let normal_count = all_labels.len() - attack_count; let normal_count = all_labels.len() - attack_count;
eprintln!(
"Training with {} points ({} attack, {} normal)",
all_points.len(),
attack_count,
normal_count
);
// Normalize // Normalize
let norm_params = NormParams::from_data(&all_points); let norm_params = NormParams::from_data(&all_points);
@@ -260,7 +150,6 @@ pub fn run(args: TrainArgs) -> Result<()> {
.map(|v| norm_params.normalize(v)) .map(|v| norm_params.normalize(v))
.collect(); .collect();
// Serialize
let model = SerializedModel { let model = SerializedModel {
points: normalized, points: normalized,
labels: all_labels, labels: all_labels,
@@ -269,7 +158,215 @@ pub fn run(args: TrainArgs) -> Result<()> {
threshold: args.threshold, threshold: args.threshold,
}; };
let encoded = bincode::serialize(&model).context("serializing model")?; Ok(DdosTrainResult {
model,
attack_count,
normal_count,
})
}
/// Train a DDoS model from pre-parsed IP states with programmatic heuristic thresholds.
/// Used by the autotune pipeline to avoid re-parsing logs on each trial.
pub fn train_model_from_states(
    ip_states: &FxHashMap<String, LogIpState>,
    thresholds: &HeuristicThresholds,
    k: usize,
    threshold: f64,
    window_secs: u64,
    min_events: usize,
) -> Result<DdosTrainResult> {
    let ip_features = extract_ip_features(ip_states, min_events, window_secs as f64);

    // Auto-label each IP from its window-averaged feature vector; tripping
    // any single heuristic marks the IP as an attacker.
    let ip_labels: FxHashMap<String, TrafficLabel> = ip_features
        .iter()
        .map(|(ip, features)| {
            let avg = average_features(features);
            let attack = avg[0] > thresholds.request_rate
                || avg[7] > thresholds.path_repetition
                || avg[3] > thresholds.error_rate
                || avg[13] > thresholds.suspicious_path_ratio
                || (avg[10] < thresholds.no_cookies_threshold
                    && avg[1] > thresholds.no_cookies_path_count);
            let label = if attack { TrafficLabel::Attack } else { TrafficLabel::Normal };
            (ip.clone(), label)
        })
        .collect();

    // Flatten every window vector into the training set, tagged with its
    // owning IP's label.
    let mut all_points: Vec<FeatureVector> = Vec::new();
    let mut all_labels: Vec<TrafficLabel> = Vec::new();
    for (ip, features) in &ip_features {
        if let Some(&label) = ip_labels.get(ip) {
            for fv in features {
                all_points.push(*fv);
                all_labels.push(label);
            }
        }
    }
    if all_points.is_empty() {
        bail!("No labeled data points found with these heuristic thresholds.");
    }

    let attack_count = all_labels.iter().filter(|&&l| l == TrafficLabel::Attack).count();
    let normal_count = all_labels.len() - attack_count;

    // Fit normalization parameters on the full dataset, then normalize.
    let norm_params = NormParams::from_data(&all_points);
    let normalized: Vec<FeatureVector> =
        all_points.iter().map(|v| norm_params.normalize(v)).collect();

    Ok(DdosTrainResult {
        model: SerializedModel {
            points: normalized,
            labels: all_labels,
            norm_params,
            k,
            threshold,
        },
        attack_count,
        normal_count,
    })
}
/// Parse audit logs into per-IP state maps.
///
/// Lines that are not valid audit JSON are skipped silently; entries with an
/// empty method are treated as non-audit records and skipped as well.
pub fn parse_logs(input: &str) -> Result<FxHashMap<String, LogIpState>> {
    let mut states: FxHashMap<String, LogIpState> = FxHashMap::default();
    let file = std::fs::File::open(input)
        .with_context(|| format!("opening {}", input))?;
    let reader = std::io::BufReader::new(file);
    for line in reader.lines() {
        let line = line?;
        let entry: AuditLog = match serde_json::from_str(&line) {
            Ok(e) => e,
            Err(_) => continue,
        };
        let fields = &entry.fields;
        if fields.method.is_empty() {
            continue;
        }
        let ip = audit_log::strip_port(&fields.client_ip).to_string();
        let state = states.entry(ip).or_default();
        state.timestamps.push(parse_timestamp(&entry.timestamp));
        state.methods.push(method_to_u8(&fields.method));
        state.path_hashes.push(fx_hash(&fields.path));
        state.host_hashes.push(fx_hash(&fields.host));
        state.user_agent_hashes.push(fx_hash(&fields.user_agent));
        state.statuses.push(fields.status);
        // Clamp the 64-bit counters into the u32 storage used by LogIpState.
        state.durations.push(fields.duration_ms.min(u32::MAX as u64) as u32);
        state.content_lengths.push(fields.content_length.min(u32::MAX as u64) as u32);
        state.has_cookies.push(fields.has_cookies.unwrap_or(false));
        // A literal "-" value means the header was absent in the log line.
        state.has_referer.push(matches!(fields.referer.as_deref(), Some(r) if r != "-"));
        state
            .has_accept_language
            .push(matches!(fields.accept_language.as_deref(), Some(a) if a != "-"));
        state
            .suspicious_paths
            .push(crate::ddos::features::is_suspicious_path(&fields.path));
    }
    Ok(states)
}
/// Extract feature vectors per IP using sliding windows.
///
/// IPs with fewer than `min_events` events are dropped; a window closes once
/// it spans `window_secs` (or at the IP's final event).
pub fn extract_ip_features(
    ip_states: &FxHashMap<String, LogIpState>,
    min_events: usize,
    window_secs: f64,
) -> FxHashMap<String, Vec<FeatureVector>> {
    let mut result: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default();
    for (ip, state) in ip_states {
        let total = state.timestamps.len();
        if total < min_events {
            continue;
        }
        let mut windows = Vec::new();
        let mut window_start = 0;
        for idx in 1..total {
            let elapsed = state.timestamps[idx] - state.timestamps[window_start];
            // Emit a feature vector when the window is full or input is exhausted.
            if elapsed >= window_secs || idx == total - 1 {
                windows.push(state.extract_features_for_window(window_start, idx + 1, window_secs));
                window_start = idx + 1;
            }
        }
        if !windows.is_empty() {
            result.insert(ip.clone(), windows);
        }
    }
    result
}
/// Assign an Attack/Normal label to each IP, either from explicit attack and
/// normal IP list files or by applying heuristic thresholds (loaded from
/// TOML) to the window-averaged feature vectors.
fn label_ips(
    args: &TrainArgs,
    ip_features: &FxHashMap<String, Vec<FeatureVector>>,
) -> Result<FxHashMap<String, TrafficLabel>> {
    // Read a newline-separated IP list file into a set, skipping blank lines.
    fn read_ip_set(path: &str, what: &'static str) -> Result<FxHashSet<String>> {
        Ok(std::fs::read_to_string(path)
            .context(what)?
            .lines()
            .map(str::trim)
            .filter(|l| !l.is_empty())
            .map(str::to_string)
            .collect())
    }

    let mut labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();
    if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
        // Ground-truth list mode; IPs in neither list stay unlabeled.
        let attack_ips = read_ip_set(attack_file, "reading attack IPs file")?;
        let normal_ips = read_ip_set(normal_file, "reading normal IPs file")?;
        for ip in ip_features.keys() {
            if attack_ips.contains(ip) {
                labels.insert(ip.clone(), TrafficLabel::Attack);
            } else if normal_ips.contains(ip) {
                labels.insert(ip.clone(), TrafficLabel::Normal);
            }
        }
    } else if let Some(heuristics_file) = &args.heuristics {
        // Heuristic auto-labeling mode: any tripped threshold marks an attacker.
        let heuristics_str = std::fs::read_to_string(heuristics_file)
            .context("reading heuristics file")?;
        let thresholds: HeuristicThresholds =
            toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;
        for (ip, features) in ip_features {
            let avg = average_features(features);
            let is_attack = avg[0] > thresholds.request_rate
                || avg[7] > thresholds.path_repetition
                || avg[3] > thresholds.error_rate
                || avg[13] > thresholds.suspicious_path_ratio
                || (avg[10] < thresholds.no_cookies_threshold
                    && avg[1] > thresholds.no_cookies_path_count);
            let label = if is_attack { TrafficLabel::Attack } else { TrafficLabel::Normal };
            labels.insert(ip.clone(), label);
        }
    } else {
        bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
    }
    Ok(labels)
}
pub fn run(args: TrainArgs) -> Result<()> {
eprintln!("Parsing logs from {}...", args.input);
let result = train_model(&args)?;
eprintln!(
"Training with {} points ({} attack, {} normal)",
result.model.points.len(),
result.attack_count,
result.normal_count
);
let encoded = bincode::serialize(&result.model).context("serializing model")?;
std::fs::write(&args.output, &encoded) std::fs::write(&args.output, &encoded)
.with_context(|| format!("writing model to {}", args.output))?; .with_context(|| format!("writing model to {}", args.output))?;
@@ -277,7 +374,7 @@ pub fn run(args: TrainArgs) -> Result<()> {
"Model saved to {} ({} bytes, {} points)", "Model saved to {} ({} bytes, {} points)",
args.output, args.output,
encoded.len(), encoded.len(),
model.points.len() result.model.points.len()
); );
Ok(()) Ok(())

View File

@@ -18,7 +18,7 @@ pub struct TrainScannerArgs {
} }
/// Default suspicious fragments — matches the DDoS feature list plus extras. /// Default suspicious fragments — matches the DDoS feature list plus extras.
const DEFAULT_FRAGMENTS: &[&str] = &[ pub const DEFAULT_FRAGMENTS: &[&str] = &[
".env", ".git", ".bak", ".sql", ".tar", ".zip", ".env", ".git", ".bak", ".sql", ".tar", ".zip",
"wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc", "wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
"phpinfo", "phpmyadmin", "php-info", "phpinfo", "phpmyadmin", "php-info",
@@ -32,9 +32,218 @@ const DEFAULT_FRAGMENTS: &[&str] = &[
const ATTACK_EXTENSIONS: &[&str] = &[".env", ".sql", ".bak", ".git/config"]; const ATTACK_EXTENSIONS: &[&str] = &[".env", ".sql", ".bak", ".git/config"];
const TRAVERSAL_MARKERS: &[&str] = &["..", "%00", "%0a"]; const TRAVERSAL_MARKERS: &[&str] = &["..", "%00", "%0a"];
struct LabeledSample { pub struct LabeledSample {
features: ScannerFeatureVector, pub features: ScannerFeatureVector,
label: f64, // 1.0 = attack, 0.0 = normal pub label: f64, // 1.0 = attack, 0.0 = normal
}
/// Output of the scanner training pipeline: the fitted model plus
/// classification metrics on both data splits.
pub struct ScannerTrainResult {
    pub model: ScannerModel,
    /// Metrics computed on the training split.
    pub train_metrics: Metrics,
    /// Metrics computed on the held-out test split.
    pub test_metrics: Metrics,
}
/// Core training pipeline: parse logs, label, train, evaluate. Returns the trained model and metrics.
///
/// Labels come from an explicit `label` field when present, otherwise from
/// `label_request` heuristics; optional CSIC and wordlist corpora can be
/// mixed in. Normalization parameters are fit on the training split only.
pub fn train_and_evaluate(
    args: &TrainScannerArgs,
    learning_rate: f64,
    epochs: usize,
    class_weight_multiplier: f64,
) -> Result<ScannerTrainResult> {
    // Hash sets of suspicious path fragments / extensions for cheap lookups
    // during feature extraction.
    let mut fragments: Vec<String> = DEFAULT_FRAGMENTS.iter().map(|s| s.to_string()).collect();
    let fragment_hashes: FxHashSet<u64> = fragments
        .iter()
        .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
        .collect();
    let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
        .iter()
        .map(|e| fx_hash_bytes(e.as_bytes()))
        .collect();
    let mut samples: Vec<LabeledSample> = Vec::new();
    let file = std::fs::File::open(&args.input)
        .with_context(|| format!("opening {}", args.input))?;
    let reader = std::io::BufReader::new(file);
    // First pass: parse every line and collect the set of host prefixes seen
    // in the logs, so feature extraction below can consult the full set.
    let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
    let mut parsed_entries: Vec<(AuditFields, String)> = Vec::new();
    for line in reader.lines() {
        let line = line?;
        if line.trim().is_empty() {
            continue;
        }
        let entry: AuditLog = match serde_json::from_str(&line) {
            Ok(e) => e,
            Err(_) => continue,
        };
        // Host prefix = everything before the first '.'.
        let host_prefix = entry
            .fields
            .host
            .split('.')
            .next()
            .unwrap_or("")
            .to_string();
        log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
        parsed_entries.push((entry.fields, host_prefix));
    }
    // Second pass: featurize and label each parsed entry.
    for (fields, host_prefix) in &parsed_entries {
        let has_cookies = fields.has_cookies.unwrap_or(false);
        // "-" or empty means the header was absent.
        let has_referer = fields
            .referer
            .as_ref()
            .map(|r| r != "-" && !r.is_empty())
            .unwrap_or(false);
        let has_accept_language = fields
            .accept_language
            .as_ref()
            .map(|a| a != "-" && !a.is_empty())
            .unwrap_or(false);
        let feats = features::extract_features(
            &fields.method,
            &fields.path,
            host_prefix,
            has_cookies,
            has_referer,
            has_accept_language,
            "-",
            &fields.user_agent,
            fields.content_length,
            &fragment_hashes,
            &extension_hashes,
            &log_hosts,
        );
        // Prefer explicit ground-truth labels; fall back to heuristics.
        // Entries that neither source can label are dropped.
        let label = if let Some(ref gt) = fields.label {
            match gt.as_str() {
                "attack" | "anomalous" => Some(1.0),
                "normal" => Some(0.0),
                _ => None,
            }
        } else {
            label_request(
                &fields.path,
                has_cookies,
                has_referer,
                has_accept_language,
                &fields.user_agent,
                host_prefix,
                fields.status,
                &log_hosts,
                &fragment_hashes,
            )
        };
        if let Some(l) = label {
            samples.push(LabeledSample {
                features: feats,
                label: l,
            });
        }
    }
    // Optionally mix in the CSIC dataset (ground-truth labels only).
    // NOTE(review): CSIC host prefixes are added to log_hosts only after the
    // main log entries were featurized above — confirm that asymmetry is
    // intentional.
    if args.csic {
        let csic_entries = crate::scanner::csic::fetch_csic_dataset()?;
        for (_, host_prefix) in &csic_entries {
            log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
        }
        for (fields, host_prefix) in &csic_entries {
            let has_cookies = fields.has_cookies.unwrap_or(false);
            let has_referer = fields
                .referer
                .as_ref()
                .map(|r| r != "-" && !r.is_empty())
                .unwrap_or(false);
            let has_accept_language = fields
                .accept_language
                .as_ref()
                .map(|a| a != "-" && !a.is_empty())
                .unwrap_or(false);
            let feats = features::extract_features(
                &fields.method,
                &fields.path,
                host_prefix,
                has_cookies,
                has_referer,
                has_accept_language,
                "-",
                &fields.user_agent,
                fields.content_length,
                &fragment_hashes,
                &extension_hashes,
                &log_hosts,
            );
            // CSIC entries without a recognized label are skipped outright.
            let label = match fields.label.as_deref() {
                Some("attack" | "anomalous") => 1.0,
                Some("normal") => 0.0,
                _ => continue,
            };
            samples.push(LabeledSample {
                features: feats,
                label,
            });
        }
    }
    // Optionally synthesize attack samples from wordlist files.
    if let Some(wordlist_dir) = &args.wordlists {
        ingest_wordlists(
            wordlist_dir,
            &mut samples,
            &mut fragments,
            &fragment_hashes,
            &extension_hashes,
            &log_hosts,
        )?;
    }
    if samples.is_empty() {
        anyhow::bail!("no training samples found");
    }
    // Stratified 80/20 train/test split
    let (train_samples, test_samples) = stratified_split(&mut samples, 0.8, 42);
    // Compute normalization params from training set only
    let train_feature_vecs: Vec<ScannerFeatureVector> =
        train_samples.iter().map(|s| s.features).collect();
    let norm_params = ScannerNormParams::from_data(&train_feature_vecs);
    let train_normalized: Vec<ScannerFeatureVector> =
        train_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect();
    // Train logistic regression with configurable params
    let weights = train_logistic_regression_weighted(
        &train_normalized,
        &train_samples,
        epochs,
        learning_rate,
        class_weight_multiplier,
    );
    let model = ScannerModel {
        weights,
        threshold: args.threshold,
        norm_params: norm_params.clone(),
        fragments,
    };
    // Evaluate on both splits; test features are normalized with the
    // training-set parameters.
    let train_metrics = evaluate(&train_normalized, &train_samples, &weights, args.threshold);
    let test_feature_vecs: Vec<ScannerFeatureVector> =
        test_samples.iter().map(|s| s.features).collect();
    let test_normalized: Vec<ScannerFeatureVector> =
        test_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect();
    let test_metrics = evaluate(&test_normalized, &test_samples, &weights, args.threshold);
    Ok(ScannerTrainResult {
        model,
        train_metrics,
        test_metrics,
    })
} }
pub fn run(args: TrainScannerArgs) -> Result<()> { pub fn run(args: TrainScannerArgs) -> Result<()> {
@@ -295,24 +504,28 @@ pub fn run(args: TrainScannerArgs) -> Result<()> {
Ok(()) Ok(())
} }
struct Metrics { pub struct Metrics {
tp: u32, pub tp: u32,
fp: u32, pub fp: u32,
tn: u32, pub tn: u32,
fn_: u32, pub fn_: u32,
} }
impl Metrics { impl Metrics {
fn precision(&self) -> f64 { pub fn precision(&self) -> f64 {
if self.tp + self.fp > 0 { self.tp as f64 / (self.tp + self.fp) as f64 } else { 0.0 } if self.tp + self.fp > 0 { self.tp as f64 / (self.tp + self.fp) as f64 } else { 0.0 }
} }
fn recall(&self) -> f64 { pub fn recall(&self) -> f64 {
if self.tp + self.fn_ > 0 { self.tp as f64 / (self.tp + self.fn_) as f64 } else { 0.0 } if self.tp + self.fn_ > 0 { self.tp as f64 / (self.tp + self.fn_) as f64 } else { 0.0 }
} }
fn f1(&self) -> f64 { pub fn f1(&self) -> f64 {
self.fbeta(1.0)
}
pub fn fbeta(&self, beta: f64) -> f64 {
let p = self.precision(); let p = self.precision();
let r = self.recall(); let r = self.recall();
if p + r > 0.0 { 2.0 * p * r / (p + r) } else { 0.0 } let b2 = beta * beta;
if p + r > 0.0 { (1.0 + b2) * p * r / (b2 * p + r) } else { 0.0 }
} }
fn print(&self, total: usize) { fn print(&self, total: usize) {
let acc = (self.tp + self.tn) as f64 / total as f64 * 100.0; let acc = (self.tp + self.tn) as f64 / total as f64 * 100.0;
@@ -457,6 +670,59 @@ fn train_logistic_regression(
weights weights
} }
/// Like `train_logistic_regression` but with configurable class_weight_multiplier.
/// A multiplier > 1.0 increases the weight of the minority (attack) class further.
///
/// Runs full-batch gradient descent on a weighted logistic loss over the
/// normalized feature vectors, with two fixed interaction terms and a bias.
fn train_logistic_regression_weighted(
    normalized: &[ScannerFeatureVector],
    samples: &[LabeledSample],
    epochs: usize,
    learning_rate: f64,
    class_weight_multiplier: f64,
) -> [f64; NUM_SCANNER_WEIGHTS] {
    // Weight layout: [0..NUM_SCANNER_FEATURES) linear terms, then two
    // interaction terms, then the bias. The slots were previously hard-coded
    // as 12/13/14 while the bias read used NUM_SCANNER_FEATURES + 2; derive
    // all three from the constant so they cannot desynchronize if the
    // feature count ever changes.
    const INTER_A: usize = NUM_SCANNER_FEATURES; // f[0] * (1 - f[3])
    const INTER_B: usize = NUM_SCANNER_FEATURES + 1; // (1 - f[9]) * (1 - f[5])
    const BIAS: usize = NUM_SCANNER_FEATURES + 2;

    let mut weights = [0.0f64; NUM_SCANNER_WEIGHTS];
    let n = samples.len() as f64;
    let n_attack = samples.iter().filter(|s| s.label > 0.5).count() as f64;
    let n_normal = n - n_attack;
    // Balanced class weights (n / 2·class_count); the attack side is further
    // scaled by the caller-supplied multiplier. Degenerate one-class data
    // falls back to unweighted training.
    let (w_attack, w_normal) = if n_attack > 0.0 && n_normal > 0.0 {
        (n / (2.0 * n_attack) * class_weight_multiplier, n / (2.0 * n_normal))
    } else {
        (1.0, 1.0)
    };
    for _epoch in 0..epochs {
        let mut gradients = [0.0f64; NUM_SCANNER_WEIGHTS];
        for (i, sample) in samples.iter().enumerate() {
            let f = &normalized[i];
            // Linear score: bias + linear terms + interaction terms.
            let mut z = weights[BIAS];
            for j in 0..NUM_SCANNER_FEATURES {
                z += weights[j] * f[j];
            }
            z += weights[INTER_A] * f[0] * (1.0 - f[3]);
            z += weights[INTER_B] * (1.0 - f[9]) * (1.0 - f[5]);
            let prediction = sigmoid(z);
            // Per-sample residual scaled by its class weight.
            let error = prediction - sample.label;
            let cw = if sample.label > 0.5 { w_attack } else { w_normal };
            let weighted_error = error * cw;
            for j in 0..NUM_SCANNER_FEATURES {
                gradients[j] += weighted_error * f[j];
            }
            gradients[INTER_A] += weighted_error * f[0] * (1.0 - f[3]);
            gradients[INTER_B] += weighted_error * (1.0 - f[9]) * (1.0 - f[5]);
            gradients[BIAS] += weighted_error;
        }
        // Average-gradient step.
        for j in 0..NUM_SCANNER_WEIGHTS {
            weights[j] -= learning_rate * gradients[j] / n;
        }
    }
    weights
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn label_request( fn label_request(
path: &str, path: &str,

View File

@@ -46,13 +46,14 @@ fn make_model(
fn default_ddos_config() -> DDoSConfig { fn default_ddos_config() -> DDoSConfig {
DDoSConfig { DDoSConfig {
model_path: String::new(), model_path: Some(String::new()),
k: 5, k: 5,
threshold: 0.6, threshold: 0.6,
window_secs: 60, window_secs: 60,
window_capacity: 1000, window_capacity: 1000,
min_events: 10, min_events: 10,
enabled: true, enabled: true,
use_ensemble: false,
} }
} }