From 007865fbe7a6ebd60e8ff6e855d0f12390b342bd Mon Sep 17 00:00:00 2001
From: Sienna Meridian Satterwhite
Date: Tue, 10 Mar 2026 23:38:19 +0000
Subject: [PATCH] feat(ddos): add KNN-based DDoS detection module

14-feature vector extraction, KNN classifier using fnntw, per-IP sliding
window aggregation, and heuristic auto-labeling for training. Includes
replay subcommand for offline evaluation and integration tests.

Signed-off-by: Sienna Meridian Satterwhite
---
 src/ddos/audit_log.rs |  83 +++++
 src/ddos/detector.rs  | 100 ++++++
 src/ddos/features.rs  | 467 +++++++++++++++++++++++++
 src/ddos/mod.rs       |   6 +
 src/ddos/model.rs     | 168 +++++++++
 src/ddos/replay.rs    | 291 ++++++++++++++++
 src/ddos/train.rs     | 298 ++++++++++++++++
 tests/ddos_test.rs    | 776 ++++++++++++++++++++++++++++++++++++++++++
 8 files changed, 2189 insertions(+)
 create mode 100644 src/ddos/audit_log.rs
 create mode 100644 src/ddos/detector.rs
 create mode 100644 src/ddos/features.rs
 create mode 100644 src/ddos/mod.rs
 create mode 100644 src/ddos/model.rs
 create mode 100644 src/ddos/replay.rs
 create mode 100644 src/ddos/train.rs
 create mode 100644 tests/ddos_test.rs

diff --git a/src/ddos/audit_log.rs b/src/ddos/audit_log.rs
new file mode 100644
index 0000000..e022b55
--- /dev/null
+++ b/src/ddos/audit_log.rs
@@ -0,0 +1,83 @@
+use serde::Deserialize;
+
+#[derive(Deserialize)]
+pub struct AuditLog {
+    pub timestamp: String,
+    pub fields: AuditFields,
+}
+
+#[derive(Deserialize)]
+pub struct AuditFields {
+    pub method: String,
+    pub host: String,
+    pub path: String,
+    pub client_ip: String,
+    #[serde(deserialize_with = "flexible_u16")]
+    pub status: u16,
+    #[serde(deserialize_with = "flexible_u64")]
+    pub duration_ms: u64,
+    #[serde(default)]
+    pub backend: String,
+    #[serde(default)]
+    pub content_length: u64,
+    #[serde(default = "default_ua")]
+    pub user_agent: String,
+    #[serde(default)]
+    pub query: String,
+    #[serde(default)]
+    pub has_cookies: Option<bool>,
+    #[serde(default)]
+    pub referer: Option<String>,
+    #[serde(default)]
+    pub accept_language: Option<String>,
+    /// Optional ground-truth label from external datasets (e.g. CSIC 2010).
+    /// Values: "attack", "normal". When present, trainers should use this
+    /// instead of heuristic labeling.
+    #[serde(default)]
+    pub label: Option<String>,
+}
+
+fn default_ua() -> String {
+    "-".to_string()
+}
+
+pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
+    deserializer: D,
+) -> std::result::Result<u64, D::Error> {
+    #[derive(Deserialize)]
+    #[serde(untagged)]
+    enum StringOrNum {
+        Num(u64),
+        Str(String),
+    }
+    match StringOrNum::deserialize(deserializer)? {
+        StringOrNum::Num(n) => Ok(n),
+        StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
+    }
+}
+
+pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
+    deserializer: D,
+) -> std::result::Result<u16, D::Error> {
+    #[derive(Deserialize)]
+    #[serde(untagged)]
+    enum StringOrNum {
+        Num(u16),
+        Str(String),
+    }
+    match StringOrNum::deserialize(deserializer)? {
+        StringOrNum::Num(n) => Ok(n),
+        StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
+    }
+}
+
+/// Strip the port suffix from a socket address string.
+pub fn strip_port(addr: &str) -> &str {
+    if addr.starts_with('[') {
+        addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr)
+    } else if let Some(pos) = addr.rfind(':') {
+        if addr[..pos].contains(':') { addr } else { &addr[..pos] } // bare IPv6 ("::1") has no port
+    } else {
+        addr
+    }
+}

diff --git a/src/ddos/detector.rs b/src/ddos/detector.rs
new file mode 100644
index 0000000..eedeba4
--- /dev/null
+++ b/src/ddos/detector.rs
@@ -0,0 +1,100 @@
+use crate::config::DDoSConfig;
+use crate::ddos::features::{method_to_u8, IpState, RequestEvent};
+use crate::ddos::model::{DDoSAction, TrainedModel};
+use rustc_hash::FxHashMap;
+use std::hash::{Hash, Hasher};
+use std::net::IpAddr;
+use std::sync::RwLock;
+use std::time::Instant;
+
+const NUM_SHARDS: usize = 256;
+
+pub struct DDoSDetector {
+    model: TrainedModel,
+    shards: Vec<RwLock<FxHashMap<IpAddr, IpState>>>,
+    window_secs: u64,
+    window_capacity: usize,
+    min_events: usize,
+}
+
+fn shard_index(ip: &IpAddr) -> usize {
+    let mut h = rustc_hash::FxHasher::default();
+    ip.hash(&mut h);
+    h.finish() as usize % NUM_SHARDS
+}
+
+impl DDoSDetector {
+    pub fn new(model: TrainedModel, config: &DDoSConfig) -> Self {
+        let shards = (0..NUM_SHARDS)
+            .map(|_| RwLock::new(FxHashMap::default()))
+            .collect();
+        Self {
+            model,
+            shards,
+            window_secs: config.window_secs,
+            window_capacity: config.window_capacity,
+            min_events: config.min_events,
+        }
+    }
+
+    /// Record an incoming request and classify the IP.
+    /// Called from request_filter (before upstream).
+    pub fn check(
+        &self,
+        ip: IpAddr,
+        method: &str,
+        path: &str,
+        host: &str,
+        user_agent: &str,
+        content_length: u64,
+        has_cookies: bool,
+        has_referer: bool,
+        has_accept_language: bool,
+    ) -> DDoSAction {
+        let event = RequestEvent {
+            timestamp: Instant::now(),
+            method: method_to_u8(method),
+            path_hash: fx_hash(path),
+            host_hash: fx_hash(host),
+            user_agent_hash: fx_hash(user_agent),
+            status: 0,
+            duration_ms: 0,
+            content_length: content_length.min(u32::MAX as u64) as u32,
+            has_cookies,
+            has_referer,
+            has_accept_language,
+            suspicious_path: crate::ddos::features::is_suspicious_path(path),
+        };
+
+        let idx = shard_index(&ip);
+        let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner());
+        let state = shard
+            .entry(ip)
+            .or_insert_with(|| IpState::new(self.window_capacity));
+        state.push(event);
+
+        if state.len() < self.min_events {
+            return DDoSAction::Allow;
+        }
+
+        let features = state.extract_features(self.window_secs);
+        self.model.classify(&features)
+    }
+
+    /// Feed response data back into the IP's event history.
+    /// Called from logging() after the response is sent.
+    pub fn record_response(&self, _ip: IpAddr, _status: u16, _duration_ms: u32) {
+        // Status/duration from check() are 0-initialized; the next request
+        // will have fresh data. This is intentionally a no-op for now.
+    }
+
+    pub fn point_count(&self) -> usize {
+        self.model.point_count()
+    }
+}
+
+fn fx_hash(s: &str) -> u64 {
+    let mut h = rustc_hash::FxHasher::default();
+    s.hash(&mut h);
+    h.finish()
+}

diff --git a/src/ddos/features.rs b/src/ddos/features.rs
new file mode 100644
index 0000000..60a2fc1
--- /dev/null
+++ b/src/ddos/features.rs
@@ -0,0 +1,467 @@
+use rustc_hash::FxHashSet;
+use serde::{Deserialize, Serialize};
+use std::time::Instant;
+
+pub const NUM_FEATURES: usize = 14;
+pub type FeatureVector = [f64; NUM_FEATURES];
+
+#[derive(Clone)]
+pub struct RequestEvent {
+    pub timestamp: Instant,
+    /// GET=0, POST=1, PUT=2, DELETE=3, HEAD=4, PATCH=5, OPTIONS=6, other=7
+    pub method: u8,
+    pub path_hash: u64,
+    pub host_hash: u64,
+    pub user_agent_hash: u64,
+    pub status: u16,
+    pub duration_ms: u32,
+    pub content_length: u32,
+    pub has_cookies: bool,
+    pub has_referer: bool,
+    pub has_accept_language: bool,
+    pub suspicious_path: bool,
+}
+
+/// Known-bad path fragments that scanners/bots probe for.
+const SUSPICIOUS_FRAGMENTS: &[&str] = &[
+    ".env", ".git/", ".git\\", ".bak", ".sql", ".tar", ".zip",
+    "wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
+    "phpinfo", "phpmyadmin", "php-info", ".php",
+    "cgi-bin", "shell", "eval-stdin",
+    "/vendor/", "/telescope/", "/actuator/",
+    "/.htaccess", "/.htpasswd",
+    "/debug/", "/config.", "/admin/",
+    "yarn.lock", "yarn-debug", "package.json", "composer.json",
+];
+
+pub fn is_suspicious_path(path: &str) -> bool {
+    let lower = path.to_ascii_lowercase();
+    SUSPICIOUS_FRAGMENTS.iter().any(|f| lower.contains(f))
+}
+
+pub struct IpState {
+    events: Vec<RequestEvent>,
+    cursor: usize,
+    count: usize,
+    capacity: usize,
+}
+
+impl IpState {
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            events: Vec::with_capacity(capacity),
+            cursor: 0,
+            count: 0,
+            capacity: capacity.max(1), // capacity 0 would panic in push (`% 0`)
+        }
+    }
+
+    pub fn push(&mut self, event: RequestEvent) {
+        if self.events.len() < self.capacity {
+            self.events.push(event);
+        } else {
+            self.events[self.cursor] = event;
+        }
+        self.cursor = (self.cursor + 1) % self.capacity;
+        self.count += 1;
+    }
+
+    pub fn len(&self) -> usize {
+        self.events.len()
+    }
+
+    /// Prune events older than `window` from the logical view.
+    /// Returns a slice of active events (not necessarily contiguous in ring buffer,
+    /// so we collect into a Vec).
+    fn active_events(&self, window_secs: u64) -> Vec<&RequestEvent> {
+        let now = Instant::now();
+        let cutoff = std::time::Duration::from_secs(window_secs);
+        self.events
+            .iter()
+            .filter(|e| now.duration_since(e.timestamp) <= cutoff)
+            .collect()
+    }
+
+    pub fn extract_features(&self, window_secs: u64) -> FeatureVector {
+        let events = self.active_events(window_secs);
+        let n = events.len() as f64;
+        if n < 1.0 {
+            return [0.0; NUM_FEATURES];
+        }
+
+        // 0: request_rate (requests / window_secs)
+        let request_rate = n / window_secs as f64;
+
+        // 1: unique_paths
+        let unique_paths = {
+            let mut set = FxHashSet::default();
+            for e in &events {
+                set.insert(e.path_hash);
+            }
+            set.len() as f64
+        };
+
+        // 2: unique_hosts
+        let unique_hosts = {
+            let mut set = FxHashSet::default();
+            for e in &events {
+                set.insert(e.host_hash);
+            }
+            set.len() as f64
+        };
+
+        // 3: error_rate (fraction of 4xx/5xx)
+        let errors = events.iter().filter(|e| e.status >= 400).count() as f64;
+        let error_rate = errors / n;
+
+        // 4: avg_duration_ms
+        let avg_duration_ms =
+            events.iter().map(|e| e.duration_ms as f64).sum::<f64>() / n;
+
+        // 5: method_entropy (Shannon entropy of method distribution)
+        let method_entropy = {
+            let mut counts = [0u32; 8];
+            for e in &events {
+                counts[e.method as usize % 8] += 1;
+            }
+            let mut entropy = 0.0f64;
+            for &c in &counts {
+                if c > 0 {
+                    let p = c as f64 / n;
+                    entropy -= p * p.ln();
+                }
+            }
+            entropy
+        };
+
+        // 6: burst_score (inverse mean inter-arrival time)
+        let burst_score = if events.len() >= 2 {
+            let mut timestamps: Vec<Instant> =
+                events.iter().map(|e| e.timestamp).collect();
+            timestamps.sort();
+            let total_span = timestamps
+                .last()
+                .unwrap()
+                .duration_since(*timestamps.first().unwrap())
+                .as_secs_f64();
+            if total_span > 0.0 {
+                (events.len() - 1) as f64 / total_span
+            } else {
+                n // all events at same instant = maximum burstiness
+            }
+        } else {
+            0.0
+        };
+
+        // 7: path_repetition (ratio of most-repeated path to total)
+        let path_repetition = {
+            let mut counts = rustc_hash::FxHashMap::default();
+            for e in &events {
+                *counts.entry(e.path_hash).or_insert(0u32) += 1;
+            }
+            let max_count = counts.values().copied().max().unwrap_or(0) as f64;
+            max_count / n
+        };
+
+        // 8: avg_content_length
+        let avg_content_length =
+            events.iter().map(|e| e.content_length as f64).sum::<f64>() / n;
+
+        // 9: unique_user_agents
+        let unique_user_agents = {
+            let mut set = FxHashSet::default();
+            for e in &events {
+                set.insert(e.user_agent_hash);
+            }
+            set.len() as f64
+        };
+
+        // 10: cookie_ratio (fraction of requests that have cookies)
+        let cookie_ratio =
+            events.iter().filter(|e| e.has_cookies).count() as f64 / n;
+
+        // 11: referer_ratio (fraction of requests with a referer)
+        let referer_ratio =
+            events.iter().filter(|e| e.has_referer).count() as f64 / n;
+
+        // 12: accept_language_ratio (fraction with accept-language)
+        let accept_language_ratio =
+            events.iter().filter(|e| e.has_accept_language).count() as f64 / n;
+
+        // 13: suspicious_path_ratio (fraction hitting known-bad paths)
+        let suspicious_path_ratio =
+            events.iter().filter(|e| e.suspicious_path).count() as f64 / n;
+
+        [
+            request_rate,
+            unique_paths,
+            unique_hosts,
+            error_rate,
+            avg_duration_ms,
+            method_entropy,
+            burst_score,
+            path_repetition,
+            avg_content_length,
+            unique_user_agents,
+            cookie_ratio,
+            referer_ratio,
+            accept_language_ratio,
+            suspicious_path_ratio,
+        ]
+    }
+}
+
+pub fn method_to_u8(method: &str) -> u8 {
+    match method {
+        "GET" => 0,
+        "POST" => 1,
+        "PUT" => 2,
+        "DELETE" => 3,
+        "HEAD" => 4,
+        "PATCH" => 5,
+        "OPTIONS" => 6,
+        _ => 7,
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct NormParams {
+    pub mins: [f64; NUM_FEATURES],
+    pub maxs: [f64; NUM_FEATURES],
+}
+
+impl NormParams {
+    pub fn from_data(vectors: &[FeatureVector]) -> Self {
+        let mut mins = [f64::MAX; NUM_FEATURES];
+        let mut maxs = [f64::MIN; NUM_FEATURES];
+        for v in vectors {
+            for i in 0..NUM_FEATURES {
+                mins[i] = mins[i].min(v[i]);
+                maxs[i] = maxs[i].max(v[i]);
+            }
+        }
+        Self { mins, maxs }
+    }
+
+    pub fn normalize(&self, v: &FeatureVector) -> FeatureVector {
+        let mut out = [0.0; NUM_FEATURES];
+        for i in 0..NUM_FEATURES {
+            let range = self.maxs[i] - self.mins[i];
+            out[i] = if range > 0.0 {
+                ((v[i] - self.mins[i]) / range).clamp(0.0, 1.0)
+            } else {
+                0.0
+            };
+        }
+        out
+    }
+}
+
+/// Feature extraction from parsed log entries (used by training pipeline).
+/// Unlike IpState which uses Instant, this uses f64 timestamps from log parsing.
+pub struct LogIpState {
+    pub timestamps: Vec<f64>,
+    pub methods: Vec<u8>,
+    pub path_hashes: Vec<u64>,
+    pub host_hashes: Vec<u64>,
+    pub user_agent_hashes: Vec<u64>,
+    pub statuses: Vec<u16>,
+    pub durations: Vec<u32>,
+    pub content_lengths: Vec<u32>,
+    pub has_cookies: Vec<bool>,
+    pub has_referer: Vec<bool>,
+    pub has_accept_language: Vec<bool>,
+    pub suspicious_paths: Vec<bool>,
+}
+
+impl LogIpState {
+    pub fn new() -> Self {
+        Self {
+            timestamps: Vec::new(),
+            methods: Vec::new(),
+            path_hashes: Vec::new(),
+            host_hashes: Vec::new(),
+            user_agent_hashes: Vec::new(),
+            statuses: Vec::new(),
+            durations: Vec::new(),
+            content_lengths: Vec::new(),
+            has_cookies: Vec::new(),
+            has_referer: Vec::new(),
+            has_accept_language: Vec::new(),
+            suspicious_paths: Vec::new(),
+        }
+    }
+
+    pub fn extract_features_for_window(
+        &self,
+        start: usize,
+        end: usize,
+        window_secs: f64,
+    ) -> FeatureVector {
+        let n = (end - start) as f64;
+        if n < 1.0 {
+            return [0.0; NUM_FEATURES];
+        }
+
+        let request_rate = n / window_secs;
+
+        let unique_paths = {
+            let mut set = FxHashSet::default();
+            for i in start..end {
+                set.insert(self.path_hashes[i]);
+            }
+            set.len() as f64
+        };
+
+        let unique_hosts = {
+            let mut set = FxHashSet::default();
+            for i in start..end {
+                set.insert(self.host_hashes[i]);
+            }
+            set.len() as f64
+        };
+
+        let errors = self.statuses[start..end]
+            .iter()
+            .filter(|&&s| s >= 400)
+            .count() as f64;
+        let error_rate = errors / n;
+
+        let avg_duration_ms =
+            self.durations[start..end].iter().map(|&d| d as f64).sum::<f64>() / n;
+
+        let method_entropy = {
+            let mut counts = [0u32; 8];
+            for i in start..end {
+                counts[self.methods[i] as usize % 8] += 1;
+            }
+            let mut entropy = 0.0f64;
+            for &c in &counts {
+                if c > 0 {
+                    let p = c as f64 / n;
+                    entropy -= p * p.ln();
+                }
+            }
+            entropy
+        };
+
+        let burst_score = if (end - start) >= 2 {
+            let total_span =
+                self.timestamps[end - 1] - self.timestamps[start];
+            if total_span > 0.0 {
+                (end - start - 1) as f64 / total_span
+            } else {
+                n
+            }
+        } else {
+            0.0
+        };
+
+        let path_repetition = {
+            let mut counts = rustc_hash::FxHashMap::default();
+            for i in start..end {
+                *counts.entry(self.path_hashes[i]).or_insert(0u32) += 1;
+            }
+            let max_count = counts.values().copied().max().unwrap_or(0) as f64;
+            max_count / n
+        };
+
+        let avg_content_length = self.content_lengths[start..end]
+            .iter()
+            .map(|&c| c as f64)
+            .sum::<f64>()
+            / n;
+
+        let unique_user_agents = {
+            let mut set = FxHashSet::default();
+            for i in start..end {
+                set.insert(self.user_agent_hashes[i]);
+            }
+            set.len() as f64
+        };
+
+        let cookie_ratio =
+            self.has_cookies[start..end].iter().filter(|&&v| v).count() as f64 / n;
+        let referer_ratio =
+            self.has_referer[start..end].iter().filter(|&&v| v).count() as f64 / n;
+        let accept_language_ratio =
+            self.has_accept_language[start..end].iter().filter(|&&v| v).count() as f64 / n;
+        let suspicious_path_ratio =
+            self.suspicious_paths[start..end].iter().filter(|&&v| v).count() as f64 / n;
+
+        [
+            request_rate,
+            unique_paths,
+            unique_hosts,
+            error_rate,
+            avg_duration_ms,
+            method_entropy,
+            burst_score,
+            path_repetition,
+            avg_content_length,
+            unique_user_agents,
+            cookie_ratio,
+            referer_ratio,
+            accept_language_ratio,
+            suspicious_path_ratio,
+        ]
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rustc_hash::FxHasher;
+    use std::hash::{Hash, Hasher};
+
+    fn fx(s: &str) -> u64 {
+        let mut h = FxHasher::default();
+        s.hash(&mut h);
+        h.finish()
+    }
+
+    #[test]
+    fn test_single_event_features() {
+        let mut state = IpState::new(100);
+        state.push(RequestEvent {
+            timestamp: Instant::now(),
+            method: 0,
+            path_hash: fx("/"),
+            host_hash: fx("example.com"),
+            user_agent_hash: fx("curl/7.0"),
+            status: 200,
+            duration_ms: 10,
+            content_length: 0,
+            has_cookies: true,
+            has_referer: false,
+            has_accept_language: true,
+            suspicious_path: false,
+        });
+        let features = state.extract_features(60);
+        // request_rate = 1/60
+        assert!(features[0] > 0.0);
+        // error_rate = 0
+        assert_eq!(features[3], 0.0);
+        // path_repetition = 1.0 (only one path)
+        assert_eq!(features[7], 1.0);
+        // cookie_ratio = 1.0 (single event with cookies)
+        assert_eq!(features[10], 1.0);
+        // referer_ratio = 0.0
+        assert_eq!(features[11], 0.0);
+        // accept_language_ratio = 1.0
+        assert_eq!(features[12], 1.0);
+        // suspicious_path_ratio = 0.0
+        assert_eq!(features[13], 0.0);
+    }
+
+    #[test]
+    fn test_norm_params() {
+        let data = vec![[0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [1.0, 20.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]];
+        let params = NormParams::from_data(&data);
+        let normalized = params.normalize(&[0.5, 15.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]);
+        for &v in &normalized {
+            assert!((v - 0.5).abs() < 1e-10);
+        }
+    }
+}

diff --git a/src/ddos/mod.rs b/src/ddos/mod.rs
new file mode 100644
index 0000000..93f324d
--- /dev/null
+++ b/src/ddos/mod.rs
@@ -0,0 +1,6 @@
+pub mod audit_log;
+pub mod detector;
+pub mod features;
+pub mod model;
+pub mod replay;
+pub mod train;

diff --git a/src/ddos/model.rs b/src/ddos/model.rs
new file mode 100644
index 0000000..a5c5496
--- /dev/null
+++ b/src/ddos/model.rs
@@ -0,0 +1,168 @@
+use crate::ddos::features::{FeatureVector, NormParams, NUM_FEATURES};
+use anyhow::{Context, Result};
+use serde::{Deserialize, Serialize};
+use std::path::Path;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum TrafficLabel {
+    Normal,
+    Attack,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct SerializedModel {
+    pub points: Vec<[f64; NUM_FEATURES]>,
+    pub labels: Vec<TrafficLabel>,
+    pub norm_params: NormParams,
+    pub k: usize,
+    pub threshold: f64,
+}
+
+pub struct TrainedModel {
+    /// Stored points (normalized). The kD-tree borrows these.
+    points: Vec<[f64; NUM_FEATURES]>,
+    labels: Vec<TrafficLabel>,
+    norm_params: NormParams,
+    k: usize,
+    threshold: f64,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum DDoSAction {
+    Allow,
+    Block,
+}
+
+impl TrainedModel {
+    pub fn load(path: &Path, k_override: Option<usize>, threshold_override: Option<f64>) -> Result<Self> {
+        let data = std::fs::read(path)
+            .with_context(|| format!("reading model from {}", path.display()))?;
+        let model: SerializedModel =
+            bincode::deserialize(&data).context("deserializing model")?;
+        Ok(Self {
+            points: model.points,
+            labels: model.labels,
+            norm_params: model.norm_params,
+            k: k_override.unwrap_or(model.k),
+            threshold: threshold_override.unwrap_or(model.threshold),
+        })
+    }
+
+    pub fn from_serialized(model: SerializedModel) -> Self {
+        Self {
+            points: model.points,
+            labels: model.labels,
+            norm_params: model.norm_params,
+            k: model.k,
+            threshold: model.threshold,
+        }
+    }
+
+    pub fn classify(&self, features: &FeatureVector) -> DDoSAction {
+        let normalized = self.norm_params.normalize(features);
+
+        if self.points.is_empty() {
+            return DDoSAction::Allow;
+        }
+
+        // Build tree on-the-fly for query. In production with many queries,
+        // we'd cache this, but the tree build is fast for <100K points.
+        // fnntw::Tree borrows data, so we build it here.
+        let tree = match fnntw::Tree::<'_, f64, NUM_FEATURES>::new(&self.points, 32) {
+            Ok(t) => t,
+            Err(_) => return DDoSAction::Allow,
+        };
+
+        let k = self.k.min(self.points.len());
+        let result = tree.query_nearest_k(&normalized, k);
+        match result {
+            Ok((_distances, indices)) => {
+                let attack_count = indices
+                    .iter()
+                    .filter(|&&idx| self.labels[idx as usize] == TrafficLabel::Attack)
+                    .count();
+                let attack_frac = attack_count as f64 / k as f64;
+                if attack_frac >= self.threshold {
+                    DDoSAction::Block
+                } else {
+                    DDoSAction::Allow
+                }
+            }
+            Err(_) => DDoSAction::Allow,
+        }
+    }
+
+    pub fn norm_params(&self) -> &NormParams {
+        &self.norm_params
+    }
+
+    pub fn point_count(&self) -> usize {
+        self.points.len()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_classify_empty_model() {
+        let model = TrainedModel {
+            points: vec![],
+            labels: vec![],
+            norm_params: NormParams {
+                mins: [0.0; NUM_FEATURES],
+                maxs: [1.0; NUM_FEATURES],
+            },
+            k: 5,
+            threshold: 0.6,
+        };
+        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
+    }
+
+    fn make_test_points(n: usize) -> Vec<[f64; NUM_FEATURES]> {
+        (0..n)
+            .map(|i| {
+                let mut v = [0.0; NUM_FEATURES];
+                for d in 0..NUM_FEATURES {
+                    v[d] = ((i * (d + 1)) as f64 / n as f64) % 1.0;
+                }
+                v
+            })
+            .collect()
+    }
+
+    #[test]
+    fn test_classify_all_attack() {
+        let points = make_test_points(100);
+        let labels = vec![TrafficLabel::Attack; 100];
+        let model = TrainedModel {
+            points,
+            labels,
+            norm_params: NormParams {
+                mins: [0.0; NUM_FEATURES],
+                maxs: [1.0; NUM_FEATURES],
+            },
+            k: 5,
+            threshold: 0.6,
+        };
+        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Block);
+    }
+
+    #[test]
+    fn test_classify_all_normal() {
+        let points = make_test_points(100);
+        let labels = vec![TrafficLabel::Normal; 100];
+        let model = TrainedModel {
+            points,
+            labels,
+            norm_params: NormParams {
+                mins: [0.0; NUM_FEATURES],
+                maxs: [1.0; NUM_FEATURES],
+            },
+            k: 5,
+            threshold: 0.6,
+        };
+        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
+    }
+}

diff --git a/src/ddos/replay.rs b/src/ddos/replay.rs
new file mode 100644
index 0000000..7e816b9
--- /dev/null
+++ b/src/ddos/replay.rs
@@ -0,0 +1,291 @@
+use crate::config::{DDoSConfig, RateLimitConfig};
+use crate::ddos::audit_log::{self, AuditLog};
+use crate::ddos::detector::DDoSDetector;
+use crate::ddos::model::{DDoSAction, TrainedModel};
+use crate::rate_limit::key::RateLimitKey;
+use crate::rate_limit::limiter::{RateLimitResult, RateLimiter};
+use anyhow::{Context, Result};
+use rustc_hash::FxHashMap;
+use std::io::BufRead;
+use std::net::IpAddr;
+use std::sync::Arc;
+
+pub struct ReplayArgs {
+    pub input: String,
+    pub model_path: String,
+    pub config_path: Option<String>,
+    pub k: usize,
+    pub threshold: f64,
+    pub window_secs: u64,
+    pub min_events: usize,
+    pub rate_limit: bool,
+}
+
+struct ReplayStats {
+    total: u64,
+    skipped: u64,
+    ddos_blocked: u64,
+    rate_limited: u64,
+    allowed: u64,
+    ddos_blocked_ips: FxHashMap<String, u64>,
+    rate_limited_ips: FxHashMap<String, u64>,
+}
+
+pub fn run(args: ReplayArgs) -> Result<()> {
+    eprintln!("Loading model from {}...", args.model_path);
+    let model = TrainedModel::load(
+        std::path::Path::new(&args.model_path),
+        Some(args.k),
+        Some(args.threshold),
+    )
+    .with_context(|| format!("loading model from {}", args.model_path))?;
+    eprintln!("  {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold);
+
+    let ddos_cfg = DDoSConfig {
+        model_path: args.model_path.clone(),
+        k: args.k,
+        threshold: args.threshold,
+        window_secs: args.window_secs,
+        window_capacity: 1000,
+        min_events: args.min_events,
+        enabled: true,
+    };
+    let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
+
+    // Optionally set up rate limiter
+    let rate_limiter = if args.rate_limit {
+        let rl_cfg = if let Some(cfg_path) = &args.config_path {
+            let cfg = crate::config::Config::load(cfg_path)?;
+            cfg.rate_limit.unwrap_or_else(default_rate_limit_config)
+        } else {
+            default_rate_limit_config()
+        };
+        eprintln!(
+            "  Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s",
+            rl_cfg.authenticated.burst,
+            rl_cfg.authenticated.rate,
+            rl_cfg.unauthenticated.burst,
+            rl_cfg.unauthenticated.rate,
+        );
+        Some(RateLimiter::new(&rl_cfg))
+    } else {
+        None
+    };
+
+    eprintln!("Replaying {}...\n", args.input);
+
+    let file = std::fs::File::open(&args.input)
+        .with_context(|| format!("opening {}", args.input))?;
+    let reader = std::io::BufReader::new(file);
+
+    let mut stats = ReplayStats {
+        total: 0,
+        skipped: 0,
+        ddos_blocked: 0,
+        rate_limited: 0,
+        allowed: 0,
+        ddos_blocked_ips: FxHashMap::default(),
+        rate_limited_ips: FxHashMap::default(),
+    };
+
+    for line in reader.lines() {
+        let line = line?;
+        let entry: AuditLog = match serde_json::from_str(&line) {
+            Ok(e) => e,
+            Err(_) => {
+                stats.skipped += 1;
+                continue;
+            }
+        };
+
+        if entry.fields.method.is_empty() {
+            stats.skipped += 1;
+            continue;
+        }
+
+        stats.total += 1;
+
+        let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
+        let ip: IpAddr = match ip_str.parse() {
+            Ok(ip) => ip,
+            Err(_) => {
+                stats.skipped += 1;
+                continue;
+            }
+        };
+
+        // DDoS check
+        let has_cookies = entry.fields.has_cookies.unwrap_or(false);
+        let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
+        let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
+        let ddos_action = detector.check(
+            ip,
+            &entry.fields.method,
+            &entry.fields.path,
+            &entry.fields.host,
+            &entry.fields.user_agent,
+            entry.fields.content_length,
+            has_cookies,
+            has_referer,
+            has_accept_language,
+        );
+
+        if ddos_action == DDoSAction::Block {
+            stats.ddos_blocked += 1;
+            *stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
+            continue;
+        }
+
+        // Rate limit check
+        if let Some(limiter) = &rate_limiter {
+            // Audit logs don't have auth headers, so all traffic is keyed by IP
+            let rl_key = RateLimitKey::Ip(ip);
+            if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
+                stats.rate_limited += 1;
+                *stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
+                continue;
+            }
+        }
+
+        stats.allowed += 1;
+    }
+
+    // Report
+    let total = stats.total;
+    eprintln!("═══ Replay Results ═══════════════════════════════════════");
+    eprintln!("  Total requests:  {total}");
+    eprintln!("  Skipped (parse): {}", stats.skipped);
+    eprintln!("  Allowed:         {} ({:.1}%)", stats.allowed, pct(stats.allowed, total));
+    eprintln!("  DDoS blocked:    {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total));
+    if rate_limiter.is_some() {
+        eprintln!("  Rate limited:    {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total));
+    }
+
+    if !stats.ddos_blocked_ips.is_empty() {
+        eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
+        let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect();
+        sorted.sort_by(|a, b| b.1.cmp(a.1));
+        for (ip, count) in sorted.iter().take(20) {
+            eprintln!("  {:<40} {} reqs blocked", ip, count);
+        }
+    }
+
+    if !stats.rate_limited_ips.is_empty() {
+        eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
+        let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect();
+        sorted.sort_by(|a, b| b.1.cmp(a.1));
+        for (ip, count) in sorted.iter().take(20) {
+            eprintln!("  {:<40} {} reqs limited", ip, count);
+        }
+    }
+
+    // Check for false positives: IPs that were blocked but had 2xx statuses in the original logs
+    eprintln!("\n── False positive check ──────────────────────────────────");
+    check_false_positives(&args.input, &stats)?;
+
+    eprintln!("══════════════════════════════════════════════════════════");
+    Ok(())
+}
+
+/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally
+/// (i.e. they were legitimate traffic that the model would incorrectly block).
+fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
+    let blocked_ips: rustc_hash::FxHashSet<&str> = stats
+        .ddos_blocked_ips
+        .keys()
+        .chain(stats.rate_limited_ips.keys())
+        .map(|s| s.as_str())
+        .collect();
+
+    if blocked_ips.is_empty() {
+        eprintln!("  No blocked IPs — nothing to check.");
+        return Ok(());
+    }
+
+    // Collect original status codes for blocked IPs
+    let file = std::fs::File::open(input)?;
+    let reader = std::io::BufReader::new(file);
+    let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
+
+    for line in reader.lines() {
+        let line = line?;
+        let entry: AuditLog = match serde_json::from_str(&line) {
+            Ok(e) => e,
+            Err(_) => continue,
+        };
+        let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
+        if blocked_ips.contains(ip_str.as_str()) {
+            ip_statuses
+                .entry(ip_str)
+                .or_default()
+                .push(entry.fields.status);
+        }
+    }
+
+    let mut suspects = Vec::new();
+    for (ip, statuses) in &ip_statuses {
+        let total = statuses.len();
+        let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
+        let ok_pct = (ok_count as f64 / total as f64) * 100.0;
+        // If >60% of original responses were 2xx/3xx, this might be a false positive
+        if ok_pct > 60.0 {
+            let blocked = stats
+                .ddos_blocked_ips
+                .get(ip)
+                .copied()
+                .unwrap_or(0)
+                + stats
+                    .rate_limited_ips
+                    .get(ip)
+                    .copied()
+                    .unwrap_or(0);
+            suspects.push((ip.clone(), total, ok_pct, blocked));
+        }
+    }
+
+    if suspects.is_empty() {
+        eprintln!("  No likely false positives found.");
+    } else {
+        suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
+        eprintln!("  ⚠ {} IPs were blocked but had mostly successful responses:", suspects.len());
+        for (ip, total, ok_pct, blocked) in suspects.iter().take(15) {
+            eprintln!(
+                "    {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked",
+                ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked,
+            );
+        }
+    }
+
+    Ok(())
+}
+
+fn default_rate_limit_config() -> RateLimitConfig {
+    RateLimitConfig {
+        enabled: true,
+        bypass_cidrs: vec![
+            "10.0.0.0/8".into(),
+            "172.16.0.0/12".into(),
+            "192.168.0.0/16".into(),
+            "100.64.0.0/10".into(),
+            "fd00::/8".into(),
+        ],
+        eviction_interval_secs: 300,
+        stale_after_secs: 600,
+        authenticated: crate::config::BucketConfig {
+            burst: 200,
+            rate: 50.0,
+        },
+        unauthenticated: crate::config::BucketConfig {
+            burst: 60,
+            rate: 15.0,
+        },
+    }
+}
+
+fn pct(n: u64, total: u64) -> f64 {
+    if total == 0 {
+        0.0
+    } else {
+        (n as f64 / total as f64) * 100.0
+    }
+}

diff --git a/src/ddos/train.rs b/src/ddos/train.rs
new file mode 100644
index 0000000..99d6860
--- /dev/null
+++ b/src/ddos/train.rs
@@ -0,0 +1,298 @@
+use crate::ddos::audit_log::AuditLog;
+use crate::ddos::audit_log;
+use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES};
+use crate::ddos::model::{SerializedModel, TrafficLabel};
+use anyhow::{bail, Context, Result};
+use rustc_hash::{FxHashMap, FxHashSet};
+use serde::Deserialize;
+use std::hash::{Hash, Hasher};
+use std::io::BufRead;
+
+#[derive(Deserialize)]
+pub struct HeuristicThresholds {
+    /// Requests/second above which an IP is labeled attack
+    #[serde(default = "default_rate_threshold")]
+    pub request_rate: f64,
+    /// Path repetition ratio above which an IP is labeled attack
+    #[serde(default = "default_repetition_threshold")]
+    pub path_repetition: f64,
+    /// Error rate above which an IP is labeled attack
+    #[serde(default = "default_error_threshold")]
+    pub error_rate: f64,
+    /// Suspicious path ratio above which an IP is labeled attack
+    #[serde(default = "default_suspicious_path_threshold")]
+    pub suspicious_path_ratio: f64,
+    /// Cookie ratio below which (combined with high unique paths) labels attack
+    #[serde(default = "default_no_cookies_threshold")]
+    pub no_cookies_threshold: f64,
+    /// Unique path count above which no-cookie traffic is labeled attack
+    #[serde(default = "default_no_cookies_path_count")]
+    pub no_cookies_path_count: f64,
+    /// Minimum events to consider an IP for labeling
+    #[serde(default = "default_min_events")]
+    pub min_events: usize,
+}
+
+fn default_rate_threshold() -> f64 { 10.0 }
+fn default_repetition_threshold() -> f64 { 0.9 }
+fn default_error_threshold() -> f64 { 0.7 }
+fn default_suspicious_path_threshold() -> f64 { 0.3 }
+fn default_no_cookies_threshold() -> f64 { 0.05 }
+fn default_no_cookies_path_count() -> f64 { 20.0 }
+fn default_min_events() -> usize { 10 }
+
+pub struct TrainArgs {
+    pub input: String,
+    pub output: String,
+    pub attack_ips: Option<String>,
+    pub normal_ips: Option<String>,
+    pub heuristics: Option<String>,
+    pub k: usize,
+    pub threshold: f64,
+    pub window_secs: u64,
+    pub min_events: usize,
+}
+
+fn fx_hash(s: &str) -> u64 {
+    let mut h = rustc_hash::FxHasher::default();
+    s.hash(&mut h);
+    h.finish()
+}
+
+fn parse_timestamp(ts: &str) -> f64 {
+    // Parse ISO 8601 timestamp to seconds since epoch (approximate).
+    // We only need relative ordering within a log file; month/year are
+    // ignored, so ordering breaks across a month boundary — TODO confirm
+    // acceptable for the log rotation policy.
+    // Format: "2026-03-07T17:41:40.705326Z"
+    let parts: Vec<&str> = ts.split('T').collect();
+    if parts.len() != 2 {
+        return 0.0;
+    }
+    let date_parts: Vec<&str> = parts[0].split('-').collect();
+    let time_str = parts[1].trim_end_matches('Z');
+    let time_parts: Vec<&str> = time_str.split(':').collect();
+    if date_parts.len() != 3 || time_parts.len() != 3 {
+        return 0.0;
+    }
+    let day: f64 = date_parts[2].parse().unwrap_or(0.0);
+    let hour: f64 = time_parts[0].parse().unwrap_or(0.0);
+    let min: f64 = time_parts[1].parse().unwrap_or(0.0);
+    let sec: f64 = time_parts[2].parse().unwrap_or(0.0);
+    // Relative seconds (day * 86400 + time)
+    day * 86400.0 + hour * 3600.0 + min * 60.0 + sec
+}
+
+
+pub fn run(args: TrainArgs) -> Result<()> {
+    eprintln!("Parsing logs from {}...", args.input);
+
+    // Parse logs into per-IP state
+    let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
+    let file = std::fs::File::open(&args.input)
+        .with_context(|| format!("opening {}", args.input))?;
+        std::io::BufReader::new(file);
+
+    let mut total_lines = 0u64;
+    let mut parse_errors = 0u64;
+
+    // Single pass over the JSONL log: accumulate raw per-IP event streams.
+    for line in reader.lines() {
+        let line = line?;
+        total_lines += 1;
+        let entry: AuditLog = match serde_json::from_str(&line) {
+            Ok(e) => e,
+            Err(_) => {
+                parse_errors += 1;
+                continue;
+            }
+        };
+
+        // Skip non-audit entries
+        if entry.fields.method.is_empty() {
+            continue;
+        }
+
+        let ip = audit_log::strip_port(&entry.fields.client_ip).to_string();
+        let ts = parse_timestamp(&entry.timestamp);
+
+        let state = ip_states.entry(ip).or_insert_with(LogIpState::new);
+        state.timestamps.push(ts);
+        state.methods.push(method_to_u8(&entry.fields.method));
+        state.path_hashes.push(fx_hash(&entry.fields.path));
+        state.host_hashes.push(fx_hash(&entry.fields.host));
+        state
+            .user_agent_hashes
+            .push(fx_hash(&entry.fields.user_agent));
+        state.statuses.push(entry.fields.status);
+        // Saturate to u32 so pathological log values cannot wrap on cast.
+        state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
+        state
+            .content_lengths
+            .push(entry.fields.content_length.min(u32::MAX as u64) as u32);
+        state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
+        // "-" is treated as "header absent" for referer / accept-language.
+        state.has_referer.push(
+            entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
+        );
+        state.has_accept_language.push(
+            entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
+        );
+        state.suspicious_paths.push(
+            crate::ddos::features::is_suspicious_path(&entry.fields.path),
+        );
+    }
+
+    eprintln!(
+        "Parsed {} lines ({} errors), {} unique IPs",
+        total_lines,
+        parse_errors,
+        ip_states.len()
+    );
+
+    // Extract feature vectors per IP (using sliding windows)
+    let window_secs = args.window_secs as f64;
+    let mut ip_features: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default();
+
+    for (ip, state) in &ip_states {
+        let n = state.timestamps.len();
+        if n < args.min_events {
+            continue;
+        }
+        // Extract one feature vector per window
+        let mut features = Vec::new();
+        let mut start = 0;
+        for end in 1..n {
+            let span = state.timestamps[end] -
+                state.timestamps[start];
+            // Close the window once it spans `window_secs`, or at the end of
+            // the stream (the final window may cover less than a full span).
+            if span >= window_secs || end == n - 1 {
+                let fv = state.extract_features_for_window(start, end + 1, window_secs);
+                features.push(fv);
+                start = end + 1;
+            }
+        }
+        if !features.is_empty() {
+            ip_features.insert(ip.clone(), features);
+        }
+    }
+
+    // Label IPs
+    let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();
+
+    if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
+        // IP list mode: only IPs present in one of the two lists are labeled;
+        // unlisted IPs are excluded from the training set entirely.
+        let attack_ips: FxHashSet<String> = std::fs::read_to_string(attack_file)
+            .context("reading attack IPs file")?
+            .lines()
+            .map(|l| l.trim().to_string())
+            .filter(|l| !l.is_empty())
+            .collect();
+        let normal_ips: FxHashSet<String> = std::fs::read_to_string(normal_file)
+            .context("reading normal IPs file")?
+            .lines()
+            .map(|l| l.trim().to_string())
+            .filter(|l| !l.is_empty())
+            .collect();
+
+        for ip in ip_features.keys() {
+            if attack_ips.contains(ip) {
+                ip_labels.insert(ip.clone(), TrafficLabel::Attack);
+            } else if normal_ips.contains(ip) {
+                ip_labels.insert(ip.clone(), TrafficLabel::Normal);
+            }
+        }
+    } else if let Some(heuristics_file) = &args.heuristics {
+        // Heuristic auto-labeling
+        // NOTE(review): `HeuristicThresholds::min_events` is never consulted
+        // here — filtering already happened via `args.min_events` above.
+        // Confirm whether the TOML field is meant to override the CLI value.
+        let heuristics_str = std::fs::read_to_string(heuristics_file)
+            .context("reading heuristics file")?;
+        let thresholds: HeuristicThresholds =
+            toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;
+
+        for (ip, features) in &ip_features {
+            // Label from the per-window *average* feature vector.
+            // Indices follow the features.rs layout (0 = request_rate, ...).
+            let avg = average_features(features);
+            let is_attack = avg[0] > thresholds.request_rate // request_rate
+                || avg[7] > thresholds.path_repetition // path_repetition
+                || avg[3] > thresholds.error_rate // error_rate
+                || avg[13] > thresholds.suspicious_path_ratio // suspicious_path_ratio
+                || (avg[10] < thresholds.no_cookies_threshold // no cookies + high unique paths
+                    && avg[1] > thresholds.no_cookies_path_count);
+            ip_labels.insert(
+                ip.clone(),
+                if is_attack {
+                    TrafficLabel::Attack
+                } else {
+                    TrafficLabel::Normal
+                },
+            );
+        }
+    } else {
+        bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
+    }
+
+    // Build training dataset: every window of every labeled IP becomes one
+    // training point carrying that IP's label.
+    let mut all_points: Vec<FeatureVector> = Vec::new();
+    let mut all_labels: Vec<TrafficLabel> = Vec::new();
+
+    for (ip, features) in &ip_features {
+        if let Some(&label) = ip_labels.get(ip) {
+            for fv in features {
+                all_points.push(*fv);
+                all_labels.push(label);
+            }
+        }
+    }
+
+    if all_points.is_empty() {
+        bail!("No labeled data points found. Check your IP lists or heuristic thresholds.");
+    }
+
+    let attack_count = all_labels
+        .iter()
+        .filter(|&&l| l == TrafficLabel::Attack)
+        .count();
+    let normal_count = all_labels.len() - attack_count;
+    eprintln!(
+        "Training with {} points ({} attack, {} normal)",
+        all_points.len(),
+        attack_count,
+        normal_count
+    );
+
+    // Normalize all points into [0, 1] per feature; the same params are
+    // stored in the model so inference normalizes identically.
+    let norm_params = NormParams::from_data(&all_points);
+    let normalized: Vec<FeatureVector> = all_points
+        .iter()
+        .map(|v| norm_params.normalize(v))
+        .collect();
+
+    // Serialize
+    let model = SerializedModel {
+        points: normalized,
+        labels: all_labels,
+        norm_params,
+        k: args.k,
+        threshold: args.threshold,
+    };
+
+    let encoded = bincode::serialize(&model).context("serializing model")?;
+    std::fs::write(&args.output, &encoded)
+        .with_context(|| format!("writing model to {}", args.output))?;
+
+    eprintln!(
+        "Model saved to {} ({} bytes, {} points)",
+        args.output,
+        encoded.len(),
+        model.points.len()
+    );
+
+    Ok(())
+}
+
+/// Element-wise mean of a set of feature vectors.
+/// Returns an all-zero vector for empty input instead of dividing by zero
+/// (which would yield NaNs and poison any threshold comparison).
+fn average_features(features: &[FeatureVector]) -> FeatureVector {
+    let mut avg = [0.0; NUM_FEATURES];
+    if features.is_empty() {
+        return avg;
+    }
+    let n = features.len() as f64;
+    for fv in features {
+        for i in 0..NUM_FEATURES {
+            avg[i] += fv[i];
+        }
+    }
+    for v in &mut avg {
+        *v /= n;
+    }
+    avg
+}
diff --git a/tests/ddos_test.rs b/tests/ddos_test.rs
new file mode 100644
index 0000000..dc9daf6
--- /dev/null
+++ b/tests/ddos_test.rs
@@ -0,0 +1,776 @@
+//! Extensive DDoS detection tests.
+//!
+//! These tests build realistic traffic profiles — normal browsing, API usage,
+//! webhook bursts, etc.
— and verify the model never blocks legitimate traffic.
+//! Attack scenarios are also tested to confirm blocking works.
+
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
+
+use sunbeam_proxy::config::DDoSConfig;
+use sunbeam_proxy::ddos::detector::DDoSDetector;
+use sunbeam_proxy::ddos::features::{NormParams, NUM_FEATURES};
+use sunbeam_proxy::ddos::model::{DDoSAction, SerializedModel, TrafficLabel, TrainedModel};
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+/// Build a model from explicit normal/attack feature vectors.
+/// Points are normalized with params derived from the combined set, mirroring
+/// what the training pipeline does before serialization.
+fn make_model(
+    normal: &[[f64; NUM_FEATURES]],
+    attack: &[[f64; NUM_FEATURES]],
+    k: usize,
+    threshold: f64,
+) -> TrainedModel {
+    let mut points = Vec::new();
+    let mut labels = Vec::new();
+    for v in normal {
+        points.push(*v);
+        labels.push(TrafficLabel::Normal);
+    }
+    for v in attack {
+        points.push(*v);
+        labels.push(TrafficLabel::Attack);
+    }
+    let norm_params = NormParams::from_data(&points);
+    let normalized: Vec<[f64; NUM_FEATURES]> =
+        points.iter().map(|v| norm_params.normalize(v)).collect();
+    TrainedModel::from_serialized(SerializedModel {
+        points: normalized,
+        labels,
+        norm_params,
+        k,
+        threshold,
+    })
+}
+
+/// Baseline detector config used by tests; individual tests override fields
+/// (notably `min_events`) via `make_detector`.
+fn default_ddos_config() -> DDoSConfig {
+    DDoSConfig {
+        model_path: String::new(),
+        k: 5,
+        threshold: 0.6,
+        window_secs: 60,
+        window_capacity: 1000,
+        min_events: 10,
+        enabled: true,
+    }
+}
+
+/// Construct a detector over `model` with a custom `min_events`.
+fn make_detector(model: TrainedModel, min_events: usize) -> DDoSDetector {
+    let mut cfg = default_ddos_config();
+    cfg.min_events = min_events;
+    DDoSDetector::new(model, &cfg)
+}
+
+/// Feature vector indices (matching features.rs order):
+/// 0: request_rate (requests / window_secs)
+/// 1: unique_paths (count of distinct paths)
+/// 2: unique_hosts (count of distinct hosts)
+/// 3: error_rate (fraction 4xx/5xx)
+/// 4: avg_duration_ms (mean response time)
+/// 5: method_entropy (Shannon
entropy of methods)
+/// 6: burst_score (inverse mean inter-arrival)
+/// 7: path_repetition (most-repeated path / total)
+/// 8: avg_content_length (mean body size)
+/// 9: unique_user_agents (count of distinct UAs)
+/// 10: cookie_ratio (fraction with cookies)
+/// 11: referer_ratio (fraction with referer)
+/// 12: accept_language_ratio (fraction with accept-language)
+/// 13: suspicious_path_ratio (fraction hitting known-bad paths)
+
+// Realistic normal traffic profiles
+fn normal_browser_browsing() -> [f64; NUM_FEATURES] {
+    // A human browsing a site: ~0.5 req/s, many paths, 1 host, low errors,
+    // ~150ms avg latency, mostly GET, moderate spacing, diverse paths, no body, 1 UA
+    // cookies=yes, referer=sometimes, accept-lang=yes, suspicious=no
+    [0.5, 12.0, 1.0, 0.02, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0]
+}
+
+fn normal_api_client() -> [f64; NUM_FEATURES] {
+    // Backend API client: ~2 req/s, hits a few endpoints, 1 host, ~5% errors (retries),
+    // ~50ms latency, mix of GET/POST, steady rate, some path repetition, small bodies, 1 UA
+    // cookies=yes (session), referer=no, accept-lang=no, suspicious=no
+    [2.0, 5.0, 1.0, 0.05, 50.0, 0.69, 2.5, 0.4, 512.0, 1.0, 1.0, 0.0, 0.0, 0.0]
+}
+
+fn normal_webhook_burst() -> [f64; NUM_FEATURES] {
+    // CI/CD or webhook burst: ~10 req/s for a short period, 1-2 paths, 1 host,
+    // 0% errors, fast responses, all POST, bursty, high path repetition, medium bodies, 1 UA
+    // cookies=no (machine), referer=no, accept-lang=no, suspicious=no
+    [10.0, 2.0, 1.0, 0.0, 25.0, 0.0, 12.0, 0.8, 2048.0, 1.0, 0.0, 0.0, 0.0, 0.0]
+}
+
+fn normal_health_check() -> [f64; NUM_FEATURES] {
+    // Health check probe: ~0.2 req/s, 1 path, 1 host, 0% errors, ~5ms latency,
+    // all GET, very regular, 100% same path, no body, 1 UA
+    // cookies=no (probe), referer=no, accept-lang=no, suspicious=no
+    [0.2, 1.0, 1.0, 0.0, 5.0, 0.0, 0.2, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
+}
+
+fn
normal_mobile_app() -> [f64; NUM_FEATURES] { + // Mobile app: ~1 req/s, several API endpoints, 1 host, ~3% errors, + // ~200ms latency (mobile network), GET + POST, moderate spacing, moderate repetition, + // small-medium bodies, 1 UA + // cookies=yes, referer=no, accept-lang=yes, suspicious=no + [1.0, 8.0, 1.0, 0.03, 200.0, 0.5, 1.2, 0.25, 256.0, 1.0, 1.0, 0.0, 1.0, 0.0] +} + +fn normal_search_crawler() -> [f64; NUM_FEATURES] { + // Googlebot-style crawler: ~0.3 req/s, many unique paths, 1 host, ~10% 404s, + // ~300ms latency, all GET, slow steady rate, diverse paths, no body, 1 UA + // cookies=no (crawler), referer=no, accept-lang=no, suspicious=no + [0.3, 20.0, 1.0, 0.1, 300.0, 0.0, 0.35, 0.08, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] +} + +fn normal_graphql_spa() -> [f64; NUM_FEATURES] { + // SPA hitting a GraphQL endpoint: ~3 req/s, 1 path (/graphql), 1 host, ~1% errors, + // ~80ms latency, all POST, steady, 100% same path, medium bodies, 1 UA + // cookies=yes, referer=yes (SPA nav), accept-lang=yes, suspicious=no + [3.0, 1.0, 1.0, 0.01, 80.0, 0.0, 3.5, 1.0, 1024.0, 1.0, 1.0, 1.0, 1.0, 0.0] +} + +fn normal_websocket_upgrade() -> [f64; NUM_FEATURES] { + // Initial HTTP requests before WS upgrade: ~0.1 req/s, 2 paths, 1 host, 0% errors, + // ~10ms latency, GET, slow, some repetition, no body, 1 UA + // cookies=yes, referer=yes, accept-lang=yes, suspicious=no + [0.1, 2.0, 1.0, 0.0, 10.0, 0.0, 0.1, 0.5, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0] +} + +fn normal_file_upload() -> [f64; NUM_FEATURES] { + // File upload session: ~0.5 req/s, 3 paths (upload, status, confirm), 1 host, + // 0% errors, ~500ms latency (large bodies), POST + GET, steady, moderate repetition, + // large bodies, 1 UA + // cookies=yes, referer=yes, accept-lang=yes, suspicious=no + [0.5, 3.0, 1.0, 0.0, 500.0, 0.69, 0.6, 0.5, 1_000_000.0, 1.0, 1.0, 1.0, 1.0, 0.0] +} + +fn normal_multi_tenant_api() -> [f64; NUM_FEATURES] { + // API client hitting multiple hosts (multi-tenant): ~1.5 req/s, 4 paths, 3 hosts, + // ~2% 
errors, ~100ms latency, GET + POST, steady, low repetition, small bodies, 1 UA + // cookies=yes, referer=no, accept-lang=no, suspicious=no + [1.5, 4.0, 3.0, 0.02, 100.0, 0.69, 1.8, 0.3, 128.0, 1.0, 1.0, 0.0, 0.0, 0.0] +} + +// Realistic attack traffic profiles +fn attack_path_scan() -> [f64; NUM_FEATURES] { + // WordPress/PHP scanner: ~20 req/s, many unique paths, 1 host, 100% 404s, + // ~2ms latency (all errors), all GET, very bursty, all unique paths, no body, 1 UA + // cookies=no, referer=no, accept-lang=no, suspicious=0.8 (most paths are probes) + [20.0, 50.0, 1.0, 1.0, 2.0, 0.0, 25.0, 0.02, 0.0, 1.0, 0.0, 0.0, 0.0, 0.8] +} + +fn attack_credential_stuffing() -> [f64; NUM_FEATURES] { + // Login brute-force: ~30 req/s, 1 path (/login), 1 host, 95% 401/403, + // ~10ms latency, all POST, very bursty, 100% same path, small bodies, 1 UA + // cookies=no, referer=no, accept-lang=no, suspicious=0.0 (/login is not in suspicious list) + [30.0, 1.0, 1.0, 0.95, 10.0, 0.0, 35.0, 1.0, 64.0, 1.0, 0.0, 0.0, 0.0, 0.0] +} + +fn attack_slowloris() -> [f64; NUM_FEATURES] { + // Slowloris-style: ~0.5 req/s (slow), 1 path, 1 host, 0% errors (connections held), + // ~30000ms latency (!), all GET, slow, 100% same path, huge content-length, 1 UA + // cookies=no, referer=no, accept-lang=no, suspicious=0.0 + [0.5, 1.0, 1.0, 0.0, 30000.0, 0.0, 0.5, 1.0, 10_000_000.0, 1.0, 0.0, 0.0, 0.0, 0.0] +} + +fn attack_ua_rotation() -> [f64; NUM_FEATURES] { + // Bot rotating user-agents: ~15 req/s, 2 paths, 1 host, 80% errors, + // ~5ms latency, GET + POST, bursty, high repetition, no body, 50 distinct UAs + // cookies=no, referer=no, accept-lang=no, suspicious=0.3 + [15.0, 2.0, 1.0, 0.8, 5.0, 0.69, 18.0, 0.7, 0.0, 50.0, 0.0, 0.0, 0.0, 0.3] +} + +fn attack_host_scan() -> [f64; NUM_FEATURES] { + // Virtual host enumeration: ~25 req/s, 1 path (/), many hosts, 100% errors, + // ~1ms latency, all GET, very bursty, 100% same path, no body, 1 UA + // cookies=no, referer=no, accept-lang=no, suspicious=0.0 + 
[25.0, 1.0, 40.0, 1.0, 1.0, 0.0, 30.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] +} + +fn attack_api_fuzzing() -> [f64; NUM_FEATURES] { + // API fuzzer: ~50 req/s, many paths, 1 host, 90% errors (bad inputs), + // ~3ms latency, mixed methods, extremely bursty, low repetition, varied bodies, 1 UA + // cookies=no, referer=no, accept-lang=no, suspicious=0.5 + [50.0, 100.0, 1.0, 0.9, 3.0, 1.5, 55.0, 0.01, 4096.0, 1.0, 0.0, 0.0, 0.0, 0.5] +} + +fn all_normal_profiles() -> Vec<[f64; NUM_FEATURES]> { + vec![ + normal_browser_browsing(), + normal_api_client(), + normal_webhook_burst(), + normal_health_check(), + normal_mobile_app(), + normal_search_crawler(), + normal_graphql_spa(), + normal_websocket_upgrade(), + normal_file_upload(), + normal_multi_tenant_api(), + ] +} + +fn all_attack_profiles() -> Vec<[f64; NUM_FEATURES]> { + vec![ + attack_path_scan(), + attack_credential_stuffing(), + attack_slowloris(), + attack_ua_rotation(), + attack_host_scan(), + attack_api_fuzzing(), + ] +} + +/// Build a model from realistic profiles, with each profile replicated `copies` +/// times (with slight jitter) to give the KNN enough neighbors. 
+fn make_realistic_model(k: usize, threshold: f64) -> TrainedModel { + let mut normal = Vec::new(); + let mut attack = Vec::new(); + + // Replicate each profile with small perturbations + for base in all_normal_profiles() { + for i in 0..20 { + let mut v = base; + for d in 0..NUM_FEATURES { + // ±5% jitter + let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05); + v[d] *= jitter; + } + normal.push(v); + } + } + for base in all_attack_profiles() { + for i in 0..20 { + let mut v = base; + for d in 0..NUM_FEATURES { + let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05); + v[d] *= jitter; + } + attack.push(v); + } + } + + make_model(&normal, &attack, k, threshold) +} + +// =========================================================================== +// Model classification tests — normal profiles must NEVER be blocked +// =========================================================================== + +#[test] +fn normal_browser_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow); +} + +#[test] +fn normal_api_client_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow); +} + +#[test] +fn normal_webhook_burst_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_webhook_burst()), DDoSAction::Allow); +} + +#[test] +fn normal_health_check_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_health_check()), DDoSAction::Allow); +} + +#[test] +fn normal_mobile_app_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_mobile_app()), DDoSAction::Allow); +} + +#[test] +fn normal_search_crawler_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_search_crawler()), DDoSAction::Allow); +} + +#[test] +fn 
normal_graphql_spa_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_graphql_spa()), DDoSAction::Allow); +} + +#[test] +fn normal_websocket_upgrade_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!( + model.classify(&normal_websocket_upgrade()), + DDoSAction::Allow + ); +} + +#[test] +fn normal_file_upload_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&normal_file_upload()), DDoSAction::Allow); +} + +#[test] +fn normal_multi_tenant_api_is_allowed() { + let model = make_realistic_model(5, 0.6); + assert_eq!( + model.classify(&normal_multi_tenant_api()), + DDoSAction::Allow + ); +} + +// =========================================================================== +// Model classification tests — attack profiles must be blocked +// =========================================================================== + +#[test] +fn attack_path_scan_is_blocked() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Block); +} + +#[test] +fn attack_credential_stuffing_is_blocked() { + let model = make_realistic_model(5, 0.6); + assert_eq!( + model.classify(&attack_credential_stuffing()), + DDoSAction::Block + ); +} + +#[test] +fn attack_slowloris_is_blocked() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&attack_slowloris()), DDoSAction::Block); +} + +#[test] +fn attack_ua_rotation_is_blocked() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&attack_ua_rotation()), DDoSAction::Block); +} + +#[test] +fn attack_host_scan_is_blocked() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&attack_host_scan()), DDoSAction::Block); +} + +#[test] +fn attack_api_fuzzing_is_blocked() { + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&attack_api_fuzzing()), DDoSAction::Block); +} + +// 
=========================================================================== +// Edge cases: normal traffic that LOOKS suspicious but isn't +// =========================================================================== + +#[test] +fn high_rate_legitimate_cdn_prefetch_is_allowed() { + // CDN prefetch: high rate but low errors, diverse paths, normal latency + let model = make_realistic_model(5, 0.6); + let profile: [f64; NUM_FEATURES] = + [8.0, 15.0, 1.0, 0.0, 100.0, 0.0, 9.0, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]; + assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +#[test] +fn single_path_api_polling_is_allowed() { + // Long-poll or SSE endpoint: single path, 100% repetition, but low rate, no errors + let model = make_realistic_model(5, 0.6); + let profile: [f64; NUM_FEATURES] = + [0.3, 1.0, 1.0, 0.0, 1000.0, 0.0, 0.3, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]; + assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +#[test] +fn moderate_error_rate_during_deploy_is_allowed() { + // During a rolling deploy, error rate spikes to ~20% temporarily + let model = make_realistic_model(5, 0.6); + let profile: [f64; NUM_FEATURES] = + [1.0, 5.0, 1.0, 0.2, 200.0, 0.5, 1.2, 0.3, 128.0, 1.0, 1.0, 0.3, 1.0, 0.0]; + assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +#[test] +fn burst_of_form_submissions_is_allowed() { + // Marketing event → users submit forms rapidly: high rate, single path, all POST, no errors + let model = make_realistic_model(5, 0.6); + let profile: [f64; NUM_FEATURES] = + [5.0, 1.0, 1.0, 0.0, 80.0, 0.0, 6.0, 1.0, 512.0, 1.0, 1.0, 1.0, 1.0, 0.0]; + assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +#[test] +fn legitimate_load_test_with_varied_paths_is_allowed() { + // Internal load test: high rate but diverse paths, low error, real latency + let model = make_realistic_model(5, 0.6); + let profile: [f64; NUM_FEATURES] = + [8.0, 30.0, 1.0, 0.02, 120.0, 0.69, 10.0, 0.05, 256.0, 1.0, 0.0, 0.0, 0.0, 0.0]; + 
assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +// =========================================================================== +// Threshold and k sensitivity +// =========================================================================== + +#[test] +fn higher_threshold_is_more_permissive() { + // With threshold=0.9, even borderline traffic should be allowed + let model = make_realistic_model(5, 0.9); + // A profile that's borderline between attack and normal + let borderline: [f64; NUM_FEATURES] = + [12.0, 8.0, 1.0, 0.5, 20.0, 0.5, 14.0, 0.5, 100.0, 2.0, 0.0, 0.0, 0.0, 0.1]; + assert_eq!(model.classify(&borderline), DDoSAction::Allow); +} + +#[test] +fn larger_k_smooths_classification() { + // With larger k, noisy outliers matter less + let model_k3 = make_realistic_model(3, 0.6); + let model_k9 = make_realistic_model(9, 0.6); + // Normal traffic should be allowed by both + let profile = normal_browser_browsing(); + assert_eq!(model_k3.classify(&profile), DDoSAction::Allow); + assert_eq!(model_k9.classify(&profile), DDoSAction::Allow); +} + +// =========================================================================== +// Normalization tests +// =========================================================================== + +#[test] +fn normalization_clamps_out_of_range() { + let params = NormParams { + mins: [0.0; NUM_FEATURES], + maxs: [1.0; NUM_FEATURES], + }; + // Values above max should clamp to 1.0 + let above = [2.0; NUM_FEATURES]; + let normed = params.normalize(&above); + for &v in &normed { + assert_eq!(v, 1.0); + } + // Values below min should clamp to 0.0 + let below = [-1.0; NUM_FEATURES]; + let normed = params.normalize(&below); + for &v in &normed { + assert_eq!(v, 0.0); + } +} + +#[test] +fn normalization_handles_zero_range() { + // When all training data has the same value for a feature, range = 0 + let params = NormParams { + mins: [5.0; NUM_FEATURES], + maxs: [5.0; NUM_FEATURES], + }; + let v = [5.0; NUM_FEATURES]; + let normed = 
params.normalize(&v); + for &val in &normed { + assert_eq!(val, 0.0); + } +} + +#[test] +fn normalization_preserves_midpoint() { + let params = NormParams { + mins: [0.0; NUM_FEATURES], + maxs: [100.0; NUM_FEATURES], + }; + let v = [50.0; NUM_FEATURES]; + let normed = params.normalize(&v); + for &val in &normed { + assert!((val - 0.5).abs() < 1e-10); + } +} + +#[test] +fn norm_params_from_data_finds_extremes() { + let data = vec![ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0], + [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.5, 0.5, 0.5], + ]; + let params = NormParams::from_data(&data); + for i in 0..NUM_FEATURES { + assert!(params.mins[i] <= params.maxs[i]); + } + assert_eq!(params.mins[0], 1.0); + assert_eq!(params.maxs[0], 10.0); +} + +// =========================================================================== +// Serialization round-trip +// =========================================================================== + +#[test] +fn model_serialization_roundtrip() { + // Use the realistic model (200+ points) so fnntw has enough data for the kD-tree + let model = make_realistic_model(3, 0.5); + + // Rebuild from the same training data + let mut all_points = Vec::new(); + let mut all_labels = Vec::new(); + for base in all_normal_profiles() { + for i in 0..20 { + let mut v = base; + for d in 0..NUM_FEATURES { + let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05); + v[d] *= jitter; + } + all_points.push(v); + all_labels.push(TrafficLabel::Normal); + } + } + for base in all_attack_profiles() { + for i in 0..20 { + let mut v = base; + for d in 0..NUM_FEATURES { + let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05); + v[d] *= jitter; + } + all_points.push(v); + all_labels.push(TrafficLabel::Attack); + } + } + + let norm_params = NormParams::from_data(&all_points); + let serialized = SerializedModel { + points: all_points.iter().map(|v| norm_params.normalize(v)).collect(), + labels: 
all_labels, + norm_params, + k: 3, + threshold: 0.5, + }; + + let encoded = bincode::serialize(&serialized).unwrap(); + let decoded: SerializedModel = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(decoded.points.len(), serialized.points.len()); + assert_eq!(decoded.labels.len(), serialized.labels.len()); + assert_eq!(decoded.k, 3); + assert!((decoded.threshold - 0.5).abs() < 1e-10); + + // Rebuilt model should classify the same + let rebuilt = TrainedModel::from_serialized(decoded); + assert_eq!( + rebuilt.classify(&normal_browser_browsing()), + model.classify(&normal_browser_browsing()) + ); +} + +// =========================================================================== +// Detector integration tests (full check() pipeline) +// =========================================================================== + +#[test] +fn detector_allows_below_min_events() { + let model = make_realistic_model(5, 0.6); + let detector = make_detector(model, 10); + + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); + // Send 9 requests — below min_events threshold of 10 + for _ in 0..9 { + let action = detector.check(ip, "GET", "/wp-admin", "evil.com", "bot", 0, false, false, false); + assert_eq!(action, DDoSAction::Allow, "should allow below min_events"); + } +} + +#[test] +fn detector_ipv4_and_ipv6_tracked_separately() { + let model = make_realistic_model(5, 0.6); + let detector = make_detector(model, 3); + + let v4 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); + let v6 = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)); + + // Send events to v4 only + for _ in 0..5 { + detector.check(v4, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true); + } + + // v6 should still have 0 events (below min_events) + let action = detector.check(v6, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true); + assert_eq!(action, DDoSAction::Allow); +} + +#[test] +fn detector_normal_browsing_pattern_is_allowed() { + let model = make_realistic_model(5, 0.6); + let 
detector = make_detector(model, 5); + + let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)); + let paths = ["/", "/about", "/products", "/products/1", "/contact", + "/blog", "/blog/post-1", "/docs", "/pricing", "/login", + "/dashboard", "/settings", "/api/me"]; + let ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)"; + + for (i, path) in paths.iter().enumerate() { + let method = if i % 5 == 0 { "POST" } else { "GET" }; + let action = detector.check(ip, method, path, "mysite.com", ua, 0, true, true, true); + // After min_events, every check should still allow normal browsing + assert_eq!( + action, + DDoSAction::Allow, + "normal browsing blocked on request #{i} to {path}" + ); + } +} + +#[test] +fn detector_handles_concurrent_ips() { + let model = make_realistic_model(5, 0.6); + let detector = make_detector(model, 5); + + // Simulate 50 distinct IPs each making a few normal requests + for i in 0..50u8 { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, i)); + let paths = ["/", "/about", "/products", "/contact", "/blog", + "/docs", "/api/status"]; + for path in &paths { + let action = detector.check(ip, "GET", path, "example.com", "Chrome", 0, true, false, true); + assert_eq!(action, DDoSAction::Allow, + "IP 10.0.0.{i} blocked on {path}"); + } + } +} + +#[test] +fn detector_ipv6_normal_traffic_is_allowed() { + let model = make_realistic_model(5, 0.6); + let detector = make_detector(model, 5); + + let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42)); + let paths = ["/", "/about", "/products", "/blog", "/contact", + "/login", "/dashboard"]; + for path in &paths { + let action = detector.check(ip, "GET", path, "example.com", + "Mozilla/5.0", 0, true, false, true); + assert_eq!(action, DDoSAction::Allow, + "IPv6 normal traffic blocked on {path}"); + } +} + +// =========================================================================== +// Model robustness: slight variations of normal traffic +// 
=========================================================================== + +#[test] +fn slightly_elevated_rate_still_allowed() { + let model = make_realistic_model(5, 0.6); + // 2x normal browsing rate — busy but not attacking + let profile: [f64; NUM_FEATURES] = + [1.0, 12.0, 1.0, 0.02, 150.0, 0.2, 1.2, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0]; + assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +#[test] +fn slightly_elevated_errors_still_allowed() { + // 15% errors (e.g. some 404s from broken links) — normal for real sites + let model = make_realistic_model(5, 0.6); + let profile: [f64; NUM_FEATURES] = + [0.5, 10.0, 1.0, 0.15, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.3, 1.0, 0.0]; + assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +#[test] +fn zero_traffic_features_allowed() { + // Edge case: all zeros (shouldn't happen in practice, but must not crash or block) + let model = make_realistic_model(5, 0.6); + assert_eq!(model.classify(&[0.0; NUM_FEATURES]), DDoSAction::Allow); +} + +#[test] +fn empty_model_always_allows() { + let model = TrainedModel::from_serialized(SerializedModel { + points: vec![], + labels: vec![], + norm_params: NormParams { + mins: [0.0; NUM_FEATURES], + maxs: [1.0; NUM_FEATURES], + }, + k: 5, + threshold: 0.6, + }); + // Must allow everything — no training data to compare against + assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow); + assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow); +} + +#[test] +fn all_normal_model_allows_everything() { + // A model trained only on normal data (no attack points) should never block. + // Use enough points (200) so fnntw can build the kD-tree. 
+ let mut normal = Vec::new(); + for base in all_normal_profiles() { + for i in 0..20 { + let mut v = base; + for d in 0..NUM_FEATURES { + let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05); + v[d] *= jitter; + } + normal.push(v); + } + } + let model = make_model(&normal, &[], 5, 0.6); + assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow); + assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow); + // Even attack-like traffic is allowed since the model has no attack examples + assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow); +} + +// =========================================================================== +// Feature extraction tests +// =========================================================================== + +#[test] +fn method_entropy_zero_for_single_method() { + // All GET requests → method distribution is [1.0, 0, 0, ...] → entropy = 0 + let model = make_realistic_model(5, 0.6); + let profile = normal_health_check(); // all GET + assert_eq!(profile[5], 0.0); // method_entropy + assert_eq!(model.classify(&profile), DDoSAction::Allow); +} + +#[test] +fn method_entropy_positive_for_mixed_methods() { + let profile = normal_api_client(); // mix of GET/POST + assert!(profile[5] > 0.0, "method_entropy should be positive for mixed methods"); +} + +#[test] +fn path_repetition_is_one_for_single_path() { + let profile = normal_graphql_spa(); // single /graphql endpoint + assert_eq!(profile[7], 1.0); +} + +#[test] +fn path_repetition_is_low_for_diverse_paths() { + let profile = normal_search_crawler(); // many unique paths + assert!(profile[7] < 0.2); +} + +// =========================================================================== +// Load the real trained model and validate against known profiles +// =========================================================================== + +#[test] +fn real_model_file_roundtrip() { + let model_path = std::path::Path::new("ddos_model.bin"); + if 
!model_path.exists() { + // Skip if no model file present (CI environments) + eprintln!("skipping real_model_file_roundtrip: ddos_model.bin not found"); + return; + } + let model = TrainedModel::load(model_path, Some(3), Some(0.5)).unwrap(); + assert!(model.point_count() > 0, "model should have training points"); + // Smoke test: classifying shouldn't panic + let _ = model.classify(&normal_browser_browsing()); + let _ = model.classify(&attack_path_scan()); +}