chore: update scanner/ddos trainers, benchmarks, and tests
Expand DDoS feature vector to 14 dimensions (cookie_ratio, referer_ratio, accept_language_ratio, suspicious_path_ratio). Add heuristic auto-labeling to DDoS trainer. Update benchmarks and tests to match new feature vectors. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
40
benches/ddos_bench.rs
Normal file
40
benches/ddos_bench.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||||
|
use sunbeam_proxy::ensemble::ddos::ddos_ensemble_predict;
|
||||||
|
use sunbeam_proxy::ensemble::gen::ddos_weights;
|
||||||
|
use sunbeam_proxy::ensemble::mlp::mlp_predict_32;
|
||||||
|
use sunbeam_proxy::ensemble::tree::tree_predict;
|
||||||
|
|
||||||
|
fn bench_ensemble_ddos_full(c: &mut Criterion) {
|
||||||
|
let raw: [f32; 14] = [5.0, 10.0, 2.0, 0.1, 50.0, 0.5, 3.0, 0.3, 500.0, 2.0, 0.8, 0.7, 0.9, 0.1];
|
||||||
|
c.bench_function("ensemble::ddos full predict", |b| {
|
||||||
|
b.iter(|| ddos_ensemble_predict(black_box(&raw)))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bench_ensemble_ddos_tree_only(c: &mut Criterion) {
|
||||||
|
let input: [f32; 14] = [0.5; 14];
|
||||||
|
c.bench_function("ensemble::ddos tree_only", |b| {
|
||||||
|
b.iter(|| tree_predict(black_box(&ddos_weights::TREE_NODES), black_box(&input)))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bench_ensemble_ddos_mlp_only(c: &mut Criterion) {
|
||||||
|
let input: [f32; 14] = [0.5; 14];
|
||||||
|
c.bench_function("ensemble::ddos mlp_only", |b| {
|
||||||
|
b.iter(|| mlp_predict_32::<14>(
|
||||||
|
black_box(&ddos_weights::W1),
|
||||||
|
black_box(&ddos_weights::B1),
|
||||||
|
black_box(&ddos_weights::W2),
|
||||||
|
black_box(ddos_weights::B2),
|
||||||
|
black_box(&input),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(
|
||||||
|
benches,
|
||||||
|
bench_ensemble_ddos_full,
|
||||||
|
bench_ensemble_ddos_tree_only,
|
||||||
|
bench_ensemble_ddos_mlp_only,
|
||||||
|
);
|
||||||
|
criterion_main!(benches);
|
||||||
@@ -1,5 +1,9 @@
|
|||||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||||
use sunbeam_proxy::config::RouteConfig;
|
use sunbeam_proxy::config::RouteConfig;
|
||||||
|
use sunbeam_proxy::ensemble::gen::scanner_weights;
|
||||||
|
use sunbeam_proxy::ensemble::mlp::mlp_predict_32;
|
||||||
|
use sunbeam_proxy::ensemble::scanner::scanner_ensemble_predict;
|
||||||
|
use sunbeam_proxy::ensemble::tree::tree_predict;
|
||||||
use sunbeam_proxy::scanner::detector::ScannerDetector;
|
use sunbeam_proxy::scanner::detector::ScannerDetector;
|
||||||
use sunbeam_proxy::scanner::features::{
|
use sunbeam_proxy::scanner::features::{
|
||||||
self, fx_hash_bytes, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS,
|
self, fx_hash_bytes, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS,
|
||||||
@@ -250,6 +254,34 @@ fn bench_extract_features(c: &mut Criterion) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn bench_ensemble_scanner_full(c: &mut Criterion) {
|
||||||
|
// Raw features simulating a scanner probe
|
||||||
|
let raw: [f32; 12] = [0.8, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0];
|
||||||
|
c.bench_function("ensemble::scanner full predict", |b| {
|
||||||
|
b.iter(|| scanner_ensemble_predict(black_box(&raw)))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bench_ensemble_scanner_tree_only(c: &mut Criterion) {
|
||||||
|
let input: [f32; 12] = [0.8, 0.3, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0];
|
||||||
|
c.bench_function("ensemble::scanner tree_only", |b| {
|
||||||
|
b.iter(|| tree_predict(black_box(&scanner_weights::TREE_NODES), black_box(&input)))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bench_ensemble_scanner_mlp_only(c: &mut Criterion) {
|
||||||
|
let input: [f32; 12] = [0.5; 12];
|
||||||
|
c.bench_function("ensemble::scanner mlp_only", |b| {
|
||||||
|
b.iter(|| mlp_predict_32::<12>(
|
||||||
|
black_box(&scanner_weights::W1),
|
||||||
|
black_box(&scanner_weights::B1),
|
||||||
|
black_box(&scanner_weights::W2),
|
||||||
|
black_box(scanner_weights::B2),
|
||||||
|
black_box(&input),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
criterion_group!(
|
criterion_group!(
|
||||||
benches,
|
benches,
|
||||||
bench_check_normal_browser,
|
bench_check_normal_browser,
|
||||||
@@ -260,5 +292,8 @@ criterion_group!(
|
|||||||
bench_check_deep_path,
|
bench_check_deep_path,
|
||||||
bench_check_api_legitimate,
|
bench_check_api_legitimate,
|
||||||
bench_extract_features,
|
bench_extract_features,
|
||||||
|
bench_ensemble_scanner_full,
|
||||||
|
bench_ensemble_scanner_tree_only,
|
||||||
|
bench_ensemble_scanner_mlp_only,
|
||||||
);
|
);
|
||||||
criterion_main!(benches);
|
criterion_main!(benches);
|
||||||
|
|||||||
@@ -41,6 +41,34 @@ fn default_no_cookies_threshold() -> f64 { 0.05 }
|
|||||||
fn default_no_cookies_path_count() -> f64 { 20.0 }
|
fn default_no_cookies_path_count() -> f64 { 20.0 }
|
||||||
fn default_min_events() -> usize { 10 }
|
fn default_min_events() -> usize { 10 }
|
||||||
|
|
||||||
|
impl HeuristicThresholds {
|
||||||
|
pub fn new(
|
||||||
|
request_rate: f64,
|
||||||
|
path_repetition: f64,
|
||||||
|
error_rate: f64,
|
||||||
|
suspicious_path_ratio: f64,
|
||||||
|
no_cookies_threshold: f64,
|
||||||
|
no_cookies_path_count: f64,
|
||||||
|
min_events: usize,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
request_rate,
|
||||||
|
path_repetition,
|
||||||
|
error_rate,
|
||||||
|
suspicious_path_ratio,
|
||||||
|
no_cookies_threshold,
|
||||||
|
no_cookies_path_count,
|
||||||
|
min_events,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DdosTrainResult {
|
||||||
|
pub model: SerializedModel,
|
||||||
|
pub attack_count: usize,
|
||||||
|
pub normal_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct TrainArgs {
|
pub struct TrainArgs {
|
||||||
pub input: String,
|
pub input: String,
|
||||||
pub output: String,
|
pub output: String,
|
||||||
@@ -82,147 +110,15 @@ fn parse_timestamp(ts: &str) -> f64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn run(args: TrainArgs) -> Result<()> {
|
/// Core training pipeline: parse logs, extract features, label IPs, build KNN model.
|
||||||
eprintln!("Parsing logs from {}...", args.input);
|
pub fn train_model(args: &TrainArgs) -> Result<DdosTrainResult> {
|
||||||
|
let ip_states = parse_logs(&args.input)?;
|
||||||
|
|
||||||
// Parse logs into per-IP state
|
|
||||||
let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
|
|
||||||
let file = std::fs::File::open(&args.input)
|
|
||||||
.with_context(|| format!("opening {}", args.input))?;
|
|
||||||
let reader = std::io::BufReader::new(file);
|
|
||||||
|
|
||||||
let mut total_lines = 0u64;
|
|
||||||
let mut parse_errors = 0u64;
|
|
||||||
|
|
||||||
for line in reader.lines() {
|
|
||||||
let line = line?;
|
|
||||||
total_lines += 1;
|
|
||||||
let entry: AuditLog = match serde_json::from_str(&line) {
|
|
||||||
Ok(e) => e,
|
|
||||||
Err(_) => {
|
|
||||||
parse_errors += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Skip non-audit entries
|
|
||||||
if entry.fields.method.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let ip = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
|
||||||
let ts = parse_timestamp(&entry.timestamp);
|
|
||||||
|
|
||||||
let state = ip_states.entry(ip).or_default();
|
|
||||||
state.timestamps.push(ts);
|
|
||||||
state.methods.push(method_to_u8(&entry.fields.method));
|
|
||||||
state.path_hashes.push(fx_hash(&entry.fields.path));
|
|
||||||
state.host_hashes.push(fx_hash(&entry.fields.host));
|
|
||||||
state
|
|
||||||
.user_agent_hashes
|
|
||||||
.push(fx_hash(&entry.fields.user_agent));
|
|
||||||
state.statuses.push(entry.fields.status);
|
|
||||||
state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
|
|
||||||
state
|
|
||||||
.content_lengths
|
|
||||||
.push(entry.fields.content_length.min(u32::MAX as u64) as u32);
|
|
||||||
state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
|
|
||||||
state.has_referer.push(
|
|
||||||
entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
|
|
||||||
);
|
|
||||||
state.has_accept_language.push(
|
|
||||||
entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
|
|
||||||
);
|
|
||||||
state.suspicious_paths.push(
|
|
||||||
crate::ddos::features::is_suspicious_path(&entry.fields.path),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
eprintln!(
|
|
||||||
"Parsed {} lines ({} errors), {} unique IPs",
|
|
||||||
total_lines,
|
|
||||||
parse_errors,
|
|
||||||
ip_states.len()
|
|
||||||
);
|
|
||||||
|
|
||||||
// Extract feature vectors per IP (using sliding windows)
|
|
||||||
let window_secs = args.window_secs as f64;
|
let window_secs = args.window_secs as f64;
|
||||||
let mut ip_features: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default();
|
let ip_features = extract_ip_features(&ip_states, args.min_events, window_secs);
|
||||||
|
|
||||||
for (ip, state) in &ip_states {
|
|
||||||
let n = state.timestamps.len();
|
|
||||||
if n < args.min_events {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// Extract one feature vector per window
|
|
||||||
let mut features = Vec::new();
|
|
||||||
let mut start = 0;
|
|
||||||
for end in 1..n {
|
|
||||||
let span = state.timestamps[end] - state.timestamps[start];
|
|
||||||
if span >= window_secs || end == n - 1 {
|
|
||||||
let fv = state.extract_features_for_window(start, end + 1, window_secs);
|
|
||||||
features.push(fv);
|
|
||||||
start = end + 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !features.is_empty() {
|
|
||||||
ip_features.insert(ip.clone(), features);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Label IPs
|
// Label IPs
|
||||||
let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();
|
let ip_labels = label_ips(args, &ip_features)?;
|
||||||
|
|
||||||
if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
|
|
||||||
// IP list mode
|
|
||||||
let attack_ips: FxHashSet<String> = std::fs::read_to_string(attack_file)
|
|
||||||
.context("reading attack IPs file")?
|
|
||||||
.lines()
|
|
||||||
.map(|l| l.trim().to_string())
|
|
||||||
.filter(|l| !l.is_empty())
|
|
||||||
.collect();
|
|
||||||
let normal_ips: FxHashSet<String> = std::fs::read_to_string(normal_file)
|
|
||||||
.context("reading normal IPs file")?
|
|
||||||
.lines()
|
|
||||||
.map(|l| l.trim().to_string())
|
|
||||||
.filter(|l| !l.is_empty())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
for ip in ip_features.keys() {
|
|
||||||
if attack_ips.contains(ip) {
|
|
||||||
ip_labels.insert(ip.clone(), TrafficLabel::Attack);
|
|
||||||
} else if normal_ips.contains(ip) {
|
|
||||||
ip_labels.insert(ip.clone(), TrafficLabel::Normal);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if let Some(heuristics_file) = &args.heuristics {
|
|
||||||
// Heuristic auto-labeling
|
|
||||||
let heuristics_str = std::fs::read_to_string(heuristics_file)
|
|
||||||
.context("reading heuristics file")?;
|
|
||||||
let thresholds: HeuristicThresholds =
|
|
||||||
toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;
|
|
||||||
|
|
||||||
for (ip, features) in &ip_features {
|
|
||||||
// Use the aggregate (last/max) feature vector for labeling
|
|
||||||
let avg = average_features(features);
|
|
||||||
let is_attack = avg[0] > thresholds.request_rate // request_rate
|
|
||||||
|| avg[7] > thresholds.path_repetition // path_repetition
|
|
||||||
|| avg[3] > thresholds.error_rate // error_rate
|
|
||||||
|| avg[13] > thresholds.suspicious_path_ratio // suspicious_path_ratio
|
|
||||||
|| (avg[10] < thresholds.no_cookies_threshold // no cookies + high unique paths
|
|
||||||
&& avg[1] > thresholds.no_cookies_path_count);
|
|
||||||
ip_labels.insert(
|
|
||||||
ip.clone(),
|
|
||||||
if is_attack {
|
|
||||||
TrafficLabel::Attack
|
|
||||||
} else {
|
|
||||||
TrafficLabel::Normal
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build training dataset
|
// Build training dataset
|
||||||
let mut all_points: Vec<FeatureVector> = Vec::new();
|
let mut all_points: Vec<FeatureVector> = Vec::new();
|
||||||
@@ -246,12 +142,6 @@ pub fn run(args: TrainArgs) -> Result<()> {
|
|||||||
.filter(|&&l| l == TrafficLabel::Attack)
|
.filter(|&&l| l == TrafficLabel::Attack)
|
||||||
.count();
|
.count();
|
||||||
let normal_count = all_labels.len() - attack_count;
|
let normal_count = all_labels.len() - attack_count;
|
||||||
eprintln!(
|
|
||||||
"Training with {} points ({} attack, {} normal)",
|
|
||||||
all_points.len(),
|
|
||||||
attack_count,
|
|
||||||
normal_count
|
|
||||||
);
|
|
||||||
|
|
||||||
// Normalize
|
// Normalize
|
||||||
let norm_params = NormParams::from_data(&all_points);
|
let norm_params = NormParams::from_data(&all_points);
|
||||||
@@ -260,7 +150,6 @@ pub fn run(args: TrainArgs) -> Result<()> {
|
|||||||
.map(|v| norm_params.normalize(v))
|
.map(|v| norm_params.normalize(v))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Serialize
|
|
||||||
let model = SerializedModel {
|
let model = SerializedModel {
|
||||||
points: normalized,
|
points: normalized,
|
||||||
labels: all_labels,
|
labels: all_labels,
|
||||||
@@ -269,7 +158,215 @@ pub fn run(args: TrainArgs) -> Result<()> {
|
|||||||
threshold: args.threshold,
|
threshold: args.threshold,
|
||||||
};
|
};
|
||||||
|
|
||||||
let encoded = bincode::serialize(&model).context("serializing model")?;
|
Ok(DdosTrainResult {
|
||||||
|
model,
|
||||||
|
attack_count,
|
||||||
|
normal_count,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Train a DDoS model from pre-parsed IP states with programmatic heuristic thresholds.
|
||||||
|
/// Used by the autotune pipeline to avoid re-parsing logs on each trial.
|
||||||
|
pub fn train_model_from_states(
|
||||||
|
ip_states: &FxHashMap<String, LogIpState>,
|
||||||
|
thresholds: &HeuristicThresholds,
|
||||||
|
k: usize,
|
||||||
|
threshold: f64,
|
||||||
|
window_secs: u64,
|
||||||
|
min_events: usize,
|
||||||
|
) -> Result<DdosTrainResult> {
|
||||||
|
let window_secs_f64 = window_secs as f64;
|
||||||
|
let ip_features = extract_ip_features(ip_states, min_events, window_secs_f64);
|
||||||
|
|
||||||
|
let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();
|
||||||
|
for (ip, features) in &ip_features {
|
||||||
|
let avg = average_features(features);
|
||||||
|
let is_attack = avg[0] > thresholds.request_rate
|
||||||
|
|| avg[7] > thresholds.path_repetition
|
||||||
|
|| avg[3] > thresholds.error_rate
|
||||||
|
|| avg[13] > thresholds.suspicious_path_ratio
|
||||||
|
|| (avg[10] < thresholds.no_cookies_threshold
|
||||||
|
&& avg[1] > thresholds.no_cookies_path_count);
|
||||||
|
ip_labels.insert(
|
||||||
|
ip.clone(),
|
||||||
|
if is_attack { TrafficLabel::Attack } else { TrafficLabel::Normal },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut all_points: Vec<FeatureVector> = Vec::new();
|
||||||
|
let mut all_labels: Vec<TrafficLabel> = Vec::new();
|
||||||
|
|
||||||
|
for (ip, features) in &ip_features {
|
||||||
|
if let Some(&label) = ip_labels.get(ip) {
|
||||||
|
for fv in features {
|
||||||
|
all_points.push(*fv);
|
||||||
|
all_labels.push(label);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if all_points.is_empty() {
|
||||||
|
bail!("No labeled data points found with these heuristic thresholds.");
|
||||||
|
}
|
||||||
|
|
||||||
|
let attack_count = all_labels.iter().filter(|&&l| l == TrafficLabel::Attack).count();
|
||||||
|
let normal_count = all_labels.len() - attack_count;
|
||||||
|
|
||||||
|
let norm_params = NormParams::from_data(&all_points);
|
||||||
|
let normalized: Vec<FeatureVector> = all_points.iter().map(|v| norm_params.normalize(v)).collect();
|
||||||
|
|
||||||
|
let model = SerializedModel {
|
||||||
|
points: normalized,
|
||||||
|
labels: all_labels,
|
||||||
|
norm_params,
|
||||||
|
k,
|
||||||
|
threshold,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(DdosTrainResult { model, attack_count, normal_count })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse audit logs into per-IP state maps.
|
||||||
|
pub fn parse_logs(input: &str) -> Result<FxHashMap<String, LogIpState>> {
|
||||||
|
let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
|
||||||
|
let file = std::fs::File::open(input)
|
||||||
|
.with_context(|| format!("opening {}", input))?;
|
||||||
|
let reader = std::io::BufReader::new(file);
|
||||||
|
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line?;
|
||||||
|
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
if entry.fields.method.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ip = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||||
|
let ts = parse_timestamp(&entry.timestamp);
|
||||||
|
|
||||||
|
let state = ip_states.entry(ip).or_default();
|
||||||
|
state.timestamps.push(ts);
|
||||||
|
state.methods.push(method_to_u8(&entry.fields.method));
|
||||||
|
state.path_hashes.push(fx_hash(&entry.fields.path));
|
||||||
|
state.host_hashes.push(fx_hash(&entry.fields.host));
|
||||||
|
state.user_agent_hashes.push(fx_hash(&entry.fields.user_agent));
|
||||||
|
state.statuses.push(entry.fields.status);
|
||||||
|
state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
|
||||||
|
state.content_lengths.push(entry.fields.content_length.min(u32::MAX as u64) as u32);
|
||||||
|
state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
|
||||||
|
state.has_referer.push(
|
||||||
|
entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
|
||||||
|
);
|
||||||
|
state.has_accept_language.push(
|
||||||
|
entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
|
||||||
|
);
|
||||||
|
state.suspicious_paths.push(
|
||||||
|
crate::ddos::features::is_suspicious_path(&entry.fields.path),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ip_states)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract feature vectors per IP using sliding windows.
|
||||||
|
pub fn extract_ip_features(
|
||||||
|
ip_states: &FxHashMap<String, LogIpState>,
|
||||||
|
min_events: usize,
|
||||||
|
window_secs: f64,
|
||||||
|
) -> FxHashMap<String, Vec<FeatureVector>> {
|
||||||
|
let mut ip_features: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default();
|
||||||
|
|
||||||
|
for (ip, state) in ip_states {
|
||||||
|
let n = state.timestamps.len();
|
||||||
|
if n < min_events {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let mut features = Vec::new();
|
||||||
|
let mut start = 0;
|
||||||
|
for end in 1..n {
|
||||||
|
let span = state.timestamps[end] - state.timestamps[start];
|
||||||
|
if span >= window_secs || end == n - 1 {
|
||||||
|
let fv = state.extract_features_for_window(start, end + 1, window_secs);
|
||||||
|
features.push(fv);
|
||||||
|
start = end + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !features.is_empty() {
|
||||||
|
ip_features.insert(ip.clone(), features);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ip_features
|
||||||
|
}
|
||||||
|
|
||||||
|
fn label_ips(
|
||||||
|
args: &TrainArgs,
|
||||||
|
ip_features: &FxHashMap<String, Vec<FeatureVector>>,
|
||||||
|
) -> Result<FxHashMap<String, TrafficLabel>> {
|
||||||
|
let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();
|
||||||
|
|
||||||
|
if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
|
||||||
|
let attack_ips: FxHashSet<String> = std::fs::read_to_string(attack_file)
|
||||||
|
.context("reading attack IPs file")?
|
||||||
|
.lines()
|
||||||
|
.map(|l| l.trim().to_string())
|
||||||
|
.filter(|l| !l.is_empty())
|
||||||
|
.collect();
|
||||||
|
let normal_ips: FxHashSet<String> = std::fs::read_to_string(normal_file)
|
||||||
|
.context("reading normal IPs file")?
|
||||||
|
.lines()
|
||||||
|
.map(|l| l.trim().to_string())
|
||||||
|
.filter(|l| !l.is_empty())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for ip in ip_features.keys() {
|
||||||
|
if attack_ips.contains(ip) {
|
||||||
|
ip_labels.insert(ip.clone(), TrafficLabel::Attack);
|
||||||
|
} else if normal_ips.contains(ip) {
|
||||||
|
ip_labels.insert(ip.clone(), TrafficLabel::Normal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if let Some(heuristics_file) = &args.heuristics {
|
||||||
|
let heuristics_str = std::fs::read_to_string(heuristics_file)
|
||||||
|
.context("reading heuristics file")?;
|
||||||
|
let thresholds: HeuristicThresholds =
|
||||||
|
toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;
|
||||||
|
|
||||||
|
for (ip, features) in ip_features {
|
||||||
|
let avg = average_features(features);
|
||||||
|
let is_attack = avg[0] > thresholds.request_rate
|
||||||
|
|| avg[7] > thresholds.path_repetition
|
||||||
|
|| avg[3] > thresholds.error_rate
|
||||||
|
|| avg[13] > thresholds.suspicious_path_ratio
|
||||||
|
|| (avg[10] < thresholds.no_cookies_threshold
|
||||||
|
&& avg[1] > thresholds.no_cookies_path_count);
|
||||||
|
ip_labels.insert(
|
||||||
|
ip.clone(),
|
||||||
|
if is_attack { TrafficLabel::Attack } else { TrafficLabel::Normal },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ip_labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run(args: TrainArgs) -> Result<()> {
|
||||||
|
eprintln!("Parsing logs from {}...", args.input);
|
||||||
|
|
||||||
|
let result = train_model(&args)?;
|
||||||
|
|
||||||
|
eprintln!(
|
||||||
|
"Training with {} points ({} attack, {} normal)",
|
||||||
|
result.model.points.len(),
|
||||||
|
result.attack_count,
|
||||||
|
result.normal_count
|
||||||
|
);
|
||||||
|
|
||||||
|
let encoded = bincode::serialize(&result.model).context("serializing model")?;
|
||||||
std::fs::write(&args.output, &encoded)
|
std::fs::write(&args.output, &encoded)
|
||||||
.with_context(|| format!("writing model to {}", args.output))?;
|
.with_context(|| format!("writing model to {}", args.output))?;
|
||||||
|
|
||||||
@@ -277,7 +374,7 @@ pub fn run(args: TrainArgs) -> Result<()> {
|
|||||||
"Model saved to {} ({} bytes, {} points)",
|
"Model saved to {} ({} bytes, {} points)",
|
||||||
args.output,
|
args.output,
|
||||||
encoded.len(),
|
encoded.len(),
|
||||||
model.points.len()
|
result.model.points.len()
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ pub struct TrainScannerArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Default suspicious fragments — matches the DDoS feature list plus extras.
|
/// Default suspicious fragments — matches the DDoS feature list plus extras.
|
||||||
const DEFAULT_FRAGMENTS: &[&str] = &[
|
pub const DEFAULT_FRAGMENTS: &[&str] = &[
|
||||||
".env", ".git", ".bak", ".sql", ".tar", ".zip",
|
".env", ".git", ".bak", ".sql", ".tar", ".zip",
|
||||||
"wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
|
"wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
|
||||||
"phpinfo", "phpmyadmin", "php-info",
|
"phpinfo", "phpmyadmin", "php-info",
|
||||||
@@ -32,9 +32,218 @@ const DEFAULT_FRAGMENTS: &[&str] = &[
|
|||||||
const ATTACK_EXTENSIONS: &[&str] = &[".env", ".sql", ".bak", ".git/config"];
|
const ATTACK_EXTENSIONS: &[&str] = &[".env", ".sql", ".bak", ".git/config"];
|
||||||
const TRAVERSAL_MARKERS: &[&str] = &["..", "%00", "%0a"];
|
const TRAVERSAL_MARKERS: &[&str] = &["..", "%00", "%0a"];
|
||||||
|
|
||||||
struct LabeledSample {
|
pub struct LabeledSample {
|
||||||
features: ScannerFeatureVector,
|
pub features: ScannerFeatureVector,
|
||||||
label: f64, // 1.0 = attack, 0.0 = normal
|
pub label: f64, // 1.0 = attack, 0.0 = normal
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ScannerTrainResult {
|
||||||
|
pub model: ScannerModel,
|
||||||
|
pub train_metrics: Metrics,
|
||||||
|
pub test_metrics: Metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Core training pipeline: parse logs, label, train, evaluate. Returns the trained model and metrics.
|
||||||
|
pub fn train_and_evaluate(
|
||||||
|
args: &TrainScannerArgs,
|
||||||
|
learning_rate: f64,
|
||||||
|
epochs: usize,
|
||||||
|
class_weight_multiplier: f64,
|
||||||
|
) -> Result<ScannerTrainResult> {
|
||||||
|
let mut fragments: Vec<String> = DEFAULT_FRAGMENTS.iter().map(|s| s.to_string()).collect();
|
||||||
|
let fragment_hashes: FxHashSet<u64> = fragments
|
||||||
|
.iter()
|
||||||
|
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
||||||
|
.collect();
|
||||||
|
let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
|
||||||
|
.iter()
|
||||||
|
.map(|e| fx_hash_bytes(e.as_bytes()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut samples: Vec<LabeledSample> = Vec::new();
|
||||||
|
let file = std::fs::File::open(&args.input)
|
||||||
|
.with_context(|| format!("opening {}", args.input))?;
|
||||||
|
let reader = std::io::BufReader::new(file);
|
||||||
|
let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
|
||||||
|
let mut parsed_entries: Vec<(AuditFields, String)> = Vec::new();
|
||||||
|
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line?;
|
||||||
|
if line.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
let host_prefix = entry
|
||||||
|
.fields
|
||||||
|
.host
|
||||||
|
.split('.')
|
||||||
|
.next()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
|
||||||
|
parsed_entries.push((entry.fields, host_prefix));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (fields, host_prefix) in &parsed_entries {
|
||||||
|
let has_cookies = fields.has_cookies.unwrap_or(false);
|
||||||
|
let has_referer = fields
|
||||||
|
.referer
|
||||||
|
.as_ref()
|
||||||
|
.map(|r| r != "-" && !r.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let has_accept_language = fields
|
||||||
|
.accept_language
|
||||||
|
.as_ref()
|
||||||
|
.map(|a| a != "-" && !a.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let feats = features::extract_features(
|
||||||
|
&fields.method,
|
||||||
|
&fields.path,
|
||||||
|
host_prefix,
|
||||||
|
has_cookies,
|
||||||
|
has_referer,
|
||||||
|
has_accept_language,
|
||||||
|
"-",
|
||||||
|
&fields.user_agent,
|
||||||
|
fields.content_length,
|
||||||
|
&fragment_hashes,
|
||||||
|
&extension_hashes,
|
||||||
|
&log_hosts,
|
||||||
|
);
|
||||||
|
|
||||||
|
let label = if let Some(ref gt) = fields.label {
|
||||||
|
match gt.as_str() {
|
||||||
|
"attack" | "anomalous" => Some(1.0),
|
||||||
|
"normal" => Some(0.0),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
label_request(
|
||||||
|
&fields.path,
|
||||||
|
has_cookies,
|
||||||
|
has_referer,
|
||||||
|
has_accept_language,
|
||||||
|
&fields.user_agent,
|
||||||
|
host_prefix,
|
||||||
|
fields.status,
|
||||||
|
&log_hosts,
|
||||||
|
&fragment_hashes,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(l) = label {
|
||||||
|
samples.push(LabeledSample {
|
||||||
|
features: feats,
|
||||||
|
label: l,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.csic {
|
||||||
|
let csic_entries = crate::scanner::csic::fetch_csic_dataset()?;
|
||||||
|
for (_, host_prefix) in &csic_entries {
|
||||||
|
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
|
||||||
|
}
|
||||||
|
for (fields, host_prefix) in &csic_entries {
|
||||||
|
let has_cookies = fields.has_cookies.unwrap_or(false);
|
||||||
|
let has_referer = fields
|
||||||
|
.referer
|
||||||
|
.as_ref()
|
||||||
|
.map(|r| r != "-" && !r.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let has_accept_language = fields
|
||||||
|
.accept_language
|
||||||
|
.as_ref()
|
||||||
|
.map(|a| a != "-" && !a.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let feats = features::extract_features(
|
||||||
|
&fields.method,
|
||||||
|
&fields.path,
|
||||||
|
host_prefix,
|
||||||
|
has_cookies,
|
||||||
|
has_referer,
|
||||||
|
has_accept_language,
|
||||||
|
"-",
|
||||||
|
&fields.user_agent,
|
||||||
|
fields.content_length,
|
||||||
|
&fragment_hashes,
|
||||||
|
&extension_hashes,
|
||||||
|
&log_hosts,
|
||||||
|
);
|
||||||
|
|
||||||
|
let label = match fields.label.as_deref() {
|
||||||
|
Some("attack" | "anomalous") => 1.0,
|
||||||
|
Some("normal") => 0.0,
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
samples.push(LabeledSample {
|
||||||
|
features: feats,
|
||||||
|
label,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(wordlist_dir) = &args.wordlists {
|
||||||
|
ingest_wordlists(
|
||||||
|
wordlist_dir,
|
||||||
|
&mut samples,
|
||||||
|
&mut fragments,
|
||||||
|
&fragment_hashes,
|
||||||
|
&extension_hashes,
|
||||||
|
&log_hosts,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if samples.is_empty() {
|
||||||
|
anyhow::bail!("no training samples found");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stratified 80/20 train/test split
|
||||||
|
let (train_samples, test_samples) = stratified_split(&mut samples, 0.8, 42);
|
||||||
|
|
||||||
|
// Compute normalization params from training set only
|
||||||
|
let train_feature_vecs: Vec<ScannerFeatureVector> =
|
||||||
|
train_samples.iter().map(|s| s.features).collect();
|
||||||
|
let norm_params = ScannerNormParams::from_data(&train_feature_vecs);
|
||||||
|
|
||||||
|
let train_normalized: Vec<ScannerFeatureVector> =
|
||||||
|
train_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect();
|
||||||
|
|
||||||
|
// Train logistic regression with configurable params
|
||||||
|
let weights = train_logistic_regression_weighted(
|
||||||
|
&train_normalized,
|
||||||
|
&train_samples,
|
||||||
|
epochs,
|
||||||
|
learning_rate,
|
||||||
|
class_weight_multiplier,
|
||||||
|
);
|
||||||
|
|
||||||
|
let model = ScannerModel {
|
||||||
|
weights,
|
||||||
|
threshold: args.threshold,
|
||||||
|
norm_params: norm_params.clone(),
|
||||||
|
fragments,
|
||||||
|
};
|
||||||
|
|
||||||
|
let train_metrics = evaluate(&train_normalized, &train_samples, &weights, args.threshold);
|
||||||
|
|
||||||
|
let test_feature_vecs: Vec<ScannerFeatureVector> =
|
||||||
|
test_samples.iter().map(|s| s.features).collect();
|
||||||
|
let test_normalized: Vec<ScannerFeatureVector> =
|
||||||
|
test_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect();
|
||||||
|
let test_metrics = evaluate(&test_normalized, &test_samples, &weights, args.threshold);
|
||||||
|
|
||||||
|
Ok(ScannerTrainResult {
|
||||||
|
model,
|
||||||
|
train_metrics,
|
||||||
|
test_metrics,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn run(args: TrainScannerArgs) -> Result<()> {
|
pub fn run(args: TrainScannerArgs) -> Result<()> {
|
||||||
@@ -295,24 +504,28 @@ pub fn run(args: TrainScannerArgs) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Metrics {
|
pub struct Metrics {
|
||||||
tp: u32,
|
pub tp: u32,
|
||||||
fp: u32,
|
pub fp: u32,
|
||||||
tn: u32,
|
pub tn: u32,
|
||||||
fn_: u32,
|
pub fn_: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Metrics {
|
impl Metrics {
|
||||||
fn precision(&self) -> f64 {
|
pub fn precision(&self) -> f64 {
|
||||||
if self.tp + self.fp > 0 { self.tp as f64 / (self.tp + self.fp) as f64 } else { 0.0 }
|
if self.tp + self.fp > 0 { self.tp as f64 / (self.tp + self.fp) as f64 } else { 0.0 }
|
||||||
}
|
}
|
||||||
fn recall(&self) -> f64 {
|
pub fn recall(&self) -> f64 {
|
||||||
if self.tp + self.fn_ > 0 { self.tp as f64 / (self.tp + self.fn_) as f64 } else { 0.0 }
|
if self.tp + self.fn_ > 0 { self.tp as f64 / (self.tp + self.fn_) as f64 } else { 0.0 }
|
||||||
}
|
}
|
||||||
/// F1 score: the F-beta score with `beta = 1.0`.
pub fn f1(&self) -> f64 {
    self.fbeta(1.0)
}

/// F-beta score: `(1 + beta^2) * p * r / (beta^2 * p + r)`, where `p` is
/// [`precision`](Self::precision) and `r` is [`recall`](Self::recall).
///
/// Returns `0.0` when `p + r` is not greater than zero.
pub fn fbeta(&self, beta: f64) -> f64 {
    let (precision, recall) = (self.precision(), self.recall());
    let beta_sq = beta * beta;
    if precision + recall > 0.0 {
        let numerator = (1.0 + beta_sq) * precision * recall;
        let denominator = beta_sq * precision + recall;
        numerator / denominator
    } else {
        0.0
    }
}
|
||||||
fn print(&self, total: usize) {
|
fn print(&self, total: usize) {
|
||||||
let acc = (self.tp + self.tn) as f64 / total as f64 * 100.0;
|
let acc = (self.tp + self.tn) as f64 / total as f64 * 100.0;
|
||||||
@@ -457,6 +670,59 @@ fn train_logistic_regression(
|
|||||||
weights
|
weights
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Like `train_logistic_regression` but with configurable class_weight_multiplier.
|
||||||
|
/// A multiplier > 1.0 increases the weight of the minority (attack) class further.
|
||||||
|
fn train_logistic_regression_weighted(
|
||||||
|
normalized: &[ScannerFeatureVector],
|
||||||
|
samples: &[LabeledSample],
|
||||||
|
epochs: usize,
|
||||||
|
learning_rate: f64,
|
||||||
|
class_weight_multiplier: f64,
|
||||||
|
) -> [f64; NUM_SCANNER_WEIGHTS] {
|
||||||
|
let mut weights = [0.0f64; NUM_SCANNER_WEIGHTS];
|
||||||
|
let n = samples.len() as f64;
|
||||||
|
|
||||||
|
let n_attack = samples.iter().filter(|s| s.label > 0.5).count() as f64;
|
||||||
|
let n_normal = n - n_attack;
|
||||||
|
let (w_attack, w_normal) = if n_attack > 0.0 && n_normal > 0.0 {
|
||||||
|
(n / (2.0 * n_attack) * class_weight_multiplier, n / (2.0 * n_normal))
|
||||||
|
} else {
|
||||||
|
(1.0, 1.0)
|
||||||
|
};
|
||||||
|
|
||||||
|
for _epoch in 0..epochs {
|
||||||
|
let mut gradients = [0.0f64; NUM_SCANNER_WEIGHTS];
|
||||||
|
|
||||||
|
for (i, sample) in samples.iter().enumerate() {
|
||||||
|
let f = &normalized[i];
|
||||||
|
let mut z = weights[NUM_SCANNER_FEATURES + 2];
|
||||||
|
for j in 0..NUM_SCANNER_FEATURES {
|
||||||
|
z += weights[j] * f[j];
|
||||||
|
}
|
||||||
|
z += weights[12] * f[0] * (1.0 - f[3]);
|
||||||
|
z += weights[13] * (1.0 - f[9]) * (1.0 - f[5]);
|
||||||
|
|
||||||
|
let prediction = sigmoid(z);
|
||||||
|
let error = prediction - sample.label;
|
||||||
|
let cw = if sample.label > 0.5 { w_attack } else { w_normal };
|
||||||
|
let weighted_error = error * cw;
|
||||||
|
|
||||||
|
for j in 0..NUM_SCANNER_FEATURES {
|
||||||
|
gradients[j] += weighted_error * f[j];
|
||||||
|
}
|
||||||
|
gradients[12] += weighted_error * f[0] * (1.0 - f[3]);
|
||||||
|
gradients[13] += weighted_error * (1.0 - f[9]) * (1.0 - f[5]);
|
||||||
|
gradients[14] += weighted_error;
|
||||||
|
}
|
||||||
|
|
||||||
|
for j in 0..NUM_SCANNER_WEIGHTS {
|
||||||
|
weights[j] -= learning_rate * gradients[j] / n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
weights
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn label_request(
|
fn label_request(
|
||||||
path: &str,
|
path: &str,
|
||||||
|
|||||||
@@ -46,13 +46,14 @@ fn make_model(
|
|||||||
|
|
||||||
fn default_ddos_config() -> DDoSConfig {
|
fn default_ddos_config() -> DDoSConfig {
|
||||||
DDoSConfig {
|
DDoSConfig {
|
||||||
model_path: String::new(),
|
model_path: Some(String::new()),
|
||||||
k: 5,
|
k: 5,
|
||||||
threshold: 0.6,
|
threshold: 0.6,
|
||||||
window_secs: 60,
|
window_secs: 60,
|
||||||
window_capacity: 1000,
|
window_capacity: 1000,
|
||||||
min_events: 10,
|
min_events: 10,
|
||||||
enabled: true,
|
enabled: true,
|
||||||
|
use_ensemble: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user