feat(ddos): add KNN-based DDoS detection module
Adds 14-feature vector extraction, a KNN classifier built on fnntw, per-IP sliding-window aggregation, and heuristic auto-labeling for training. Also includes a replay subcommand for offline evaluation, plus integration tests.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
83
src/ddos/audit_log.rs
Normal file
83
src/ddos/audit_log.rs
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
/// One parsed audit-log entry (a single JSON line of the access log).
#[derive(Deserialize)]
pub struct AuditLog {
    /// Event timestamp as emitted by the logger (kept as a string; not parsed here).
    pub timestamp: String,
    /// Request/response details for this entry.
    pub fields: AuditFields,
}
|
||||||
|
|
||||||
|
/// Request/response fields of one audit-log line.
///
/// Producer quirks are absorbed here: numeric fields may arrive as numbers
/// or as numeric strings (see the `flexible_*` helpers), and most optional
/// fields default rather than fail deserialization.
#[derive(Deserialize)]
pub struct AuditFields {
    pub method: String,
    pub host: String,
    pub path: String,
    /// Client address, possibly carrying a port suffix (see `strip_port`).
    pub client_ip: String,
    /// HTTP status; accepts a JSON number or a numeric string.
    #[serde(deserialize_with = "flexible_u16")]
    pub status: u16,
    /// Request duration; accepts a JSON number or a numeric string.
    #[serde(deserialize_with = "flexible_u64")]
    pub duration_ms: u64,
    #[serde(default)]
    pub backend: String,
    #[serde(default)]
    pub content_length: u64,
    /// Defaults to "-" when absent (common access-log convention).
    #[serde(default = "default_ua")]
    pub user_agent: String,
    #[serde(default)]
    pub query: String,
    #[serde(default)]
    pub has_cookies: Option<bool>,
    #[serde(default)]
    pub referer: Option<String>,
    #[serde(default)]
    pub accept_language: Option<String>,
    /// Optional ground-truth label from external datasets (e.g. CSIC 2010).
    /// Values: "attack", "normal". When present, trainers should use this
    /// instead of heuristic labeling.
    #[serde(default)]
    pub label: Option<String>,
}
|
||||||
|
|
||||||
|
/// Fallback user-agent value used when a log entry omits the field.
fn default_ua() -> String {
    String::from("-")
}
|
||||||
|
|
||||||
|
pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<u64, D::Error> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum StringOrNum {
|
||||||
|
Num(u64),
|
||||||
|
Str(String),
|
||||||
|
}
|
||||||
|
match StringOrNum::deserialize(deserializer)? {
|
||||||
|
StringOrNum::Num(n) => Ok(n),
|
||||||
|
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<u16, D::Error> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum StringOrNum {
|
||||||
|
Num(u16),
|
||||||
|
Str(String),
|
||||||
|
}
|
||||||
|
match StringOrNum::deserialize(deserializer)? {
|
||||||
|
StringOrNum::Num(n) => Ok(n),
|
||||||
|
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Strip the port suffix from a socket address string.
///
/// Handles three shapes:
/// * `"[::1]:8080"` -> `"::1"` (bracketed IPv6 with port)
/// * `"1.2.3.4:80"` -> `"1.2.3.4"` (IPv4 or hostname with port)
/// * `"1.2.3.4"` / `"::1"` -> unchanged (no port present)
///
/// Bug fix: a bare unbracketed IPv6 address (more than one `:`) is now
/// returned unchanged. The previous `rfind(':')` approach chopped off its
/// last group (e.g. `"::1"` -> `"::"`), yielding the wrong address.
pub fn strip_port(addr: &str) -> &str {
    if addr.starts_with('[') {
        // Bracketed IPv6: take everything between '[' and ']'.
        addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr)
    } else if addr.bytes().filter(|&b| b == b':').count() == 1 {
        // Exactly one colon: host:port. Drop the port.
        match addr.rfind(':') {
            Some(pos) => &addr[..pos],
            None => addr,
        }
    } else {
        // Either no colon (bare IPv4 / hostname) or several colons
        // (bare IPv6 without brackets): nothing to strip.
        addr
    }
}
|
||||||
100
src/ddos/detector.rs
Normal file
100
src/ddos/detector.rs
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
use crate::config::DDoSConfig;
|
||||||
|
use crate::ddos::features::{method_to_u8, IpState, RequestEvent};
|
||||||
|
use crate::ddos::model::{DDoSAction, TrainedModel};
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::RwLock;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
/// Number of lock shards for the per-IP state map; sharding spreads
/// write contention across many smaller locks.
const NUM_SHARDS: usize = 256;

/// Online DDoS detector: per-IP sliding-window state plus a trained KNN model.
pub struct DDoSDetector {
    /// Trained KNN model used to classify feature vectors.
    model: TrainedModel,
    /// Per-IP event state, sharded by IP hash to limit lock contention.
    shards: Vec<RwLock<FxHashMap<IpAddr, IpState>>>,
    /// Length of the sliding window, in seconds.
    window_secs: u64,
    /// Maximum events retained per IP (ring-buffer capacity).
    window_capacity: usize,
    /// Minimum observed events per IP before classification is attempted.
    min_events: usize,
}
|
||||||
|
|
||||||
|
fn shard_index(ip: &IpAddr) -> usize {
|
||||||
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
|
ip.hash(&mut h);
|
||||||
|
h.finish() as usize % NUM_SHARDS
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DDoSDetector {
    /// Build a detector from a trained model and runtime configuration.
    pub fn new(model: TrainedModel, config: &DDoSConfig) -> Self {
        let shards = (0..NUM_SHARDS)
            .map(|_| RwLock::new(FxHashMap::default()))
            .collect();
        Self {
            model,
            shards,
            window_secs: config.window_secs,
            window_capacity: config.window_capacity,
            min_events: config.min_events,
        }
    }

    /// Record an incoming request and classify the IP.
    /// Called from request_filter (before upstream).
    ///
    /// Returns `DDoSAction::Allow` until at least `min_events` events have
    /// been buffered for this IP; after that, the window's feature vector is
    /// classified by the KNN model.
    pub fn check(
        &self,
        ip: IpAddr,
        method: &str,
        path: &str,
        host: &str,
        user_agent: &str,
        content_length: u64,
        has_cookies: bool,
        has_referer: bool,
        has_accept_language: bool,
    ) -> DDoSAction {
        // Status and duration are unknown at request time; they stay 0
        // (see record_response below).
        let event = RequestEvent {
            timestamp: Instant::now(),
            method: method_to_u8(method),
            path_hash: fx_hash(path),
            host_hash: fx_hash(host),
            user_agent_hash: fx_hash(user_agent),
            status: 0,
            duration_ms: 0,
            // Saturate rather than truncate oversized lengths into u32.
            content_length: content_length.min(u32::MAX as u64) as u32,
            has_cookies,
            has_referer,
            has_accept_language,
            suspicious_path: crate::ddos::features::is_suspicious_path(path),
        };

        let idx = shard_index(&ip);
        // A poisoned lock only means another thread panicked mid-update; the
        // map is still usable, so recover the guard instead of propagating.
        let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner());
        let state = shard
            .entry(ip)
            .or_insert_with(|| IpState::new(self.window_capacity));
        state.push(event);

        // Too few samples to form a meaningful feature vector yet.
        if state.len() < self.min_events {
            return DDoSAction::Allow;
        }

        let features = state.extract_features(self.window_secs);
        self.model.classify(&features)
    }

    /// Feed response data back into the IP's event history.
    /// Called from logging() after the response is sent.
    pub fn record_response(&self, _ip: IpAddr, _status: u16, _duration_ms: u32) {
        // Status/duration from check() are 0-initialized; the next request
        // will have fresh data. This is intentionally a no-op for now.
    }

    /// Number of training points in the loaded model.
    pub fn point_count(&self) -> usize {
        self.model.point_count()
    }
}
|
||||||
|
|
||||||
|
fn fx_hash(s: &str) -> u64 {
|
||||||
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
|
s.hash(&mut h);
|
||||||
|
h.finish()
|
||||||
|
}
|
||||||
467
src/ddos/features.rs
Normal file
467
src/ddos/features.rs
Normal file
@@ -0,0 +1,467 @@
|
|||||||
|
use rustc_hash::FxHashSet;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
/// Number of per-IP features fed to the KNN classifier.
pub const NUM_FEATURES: usize = 14;
/// Fixed-size feature vector; index meanings are documented inline in
/// `IpState::extract_features`.
pub type FeatureVector = [f64; NUM_FEATURES];
|
||||||
|
|
||||||
|
/// One observed request for an IP, stored in the per-IP ring buffer.
/// Strings are stored as Fx hashes to keep each event small and avoid
/// retaining request data.
#[derive(Clone)]
pub struct RequestEvent {
    pub timestamp: Instant,
    /// GET=0, POST=1, PUT=2, DELETE=3, HEAD=4, PATCH=5, OPTIONS=6, other=7
    pub method: u8,
    pub path_hash: u64,
    pub host_hash: u64,
    pub user_agent_hash: u64,
    /// Response status; 0 when recorded at request time (before upstream).
    pub status: u16,
    /// Response duration; 0 when recorded at request time.
    pub duration_ms: u32,
    pub content_length: u32,
    pub has_cookies: bool,
    pub has_referer: bool,
    pub has_accept_language: bool,
    /// True if the path matched SUSPICIOUS_FRAGMENTS at ingest time.
    pub suspicious_path: bool,
}
|
||||||
|
|
||||||
|
/// Known-bad path fragments that scanners/bots probe for.
const SUSPICIOUS_FRAGMENTS: &[&str] = &[
    ".env", ".git/", ".git\\", ".bak", ".sql", ".tar", ".zip",
    "wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
    "phpinfo", "phpmyadmin", "php-info", ".php",
    "cgi-bin", "shell", "eval-stdin",
    "/vendor/", "/telescope/", "/actuator/",
    "/.htaccess", "/.htpasswd",
    "/debug/", "/config.", "/admin/",
    "yarn.lock", "yarn-debug", "package.json", "composer.json",
];

/// True if `path` contains any known-bad fragment (case-insensitive substring
/// match against `SUSPICIOUS_FRAGMENTS`).
pub fn is_suspicious_path(path: &str) -> bool {
    let lowered = path.to_ascii_lowercase();
    for fragment in SUSPICIOUS_FRAGMENTS {
        if lowered.contains(fragment) {
            return true;
        }
    }
    false
}
|
||||||
|
|
||||||
|
/// Per-IP ring buffer of recent request events.
pub struct IpState {
    /// Backing storage; grows up to `capacity`, then entries are overwritten.
    events: Vec<RequestEvent>,
    /// Next write position once the buffer is full.
    cursor: usize,
    /// Lifetime total of pushed events (monotonic; not buffer occupancy).
    count: usize,
    /// Maximum number of retained events.
    capacity: usize,
}
|
||||||
|
|
||||||
|
impl IpState {
|
||||||
|
pub fn new(capacity: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
events: Vec::with_capacity(capacity),
|
||||||
|
cursor: 0,
|
||||||
|
count: 0,
|
||||||
|
capacity,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, event: RequestEvent) {
|
||||||
|
if self.events.len() < self.capacity {
|
||||||
|
self.events.push(event);
|
||||||
|
} else {
|
||||||
|
self.events[self.cursor] = event;
|
||||||
|
}
|
||||||
|
self.cursor = (self.cursor + 1) % self.capacity;
|
||||||
|
self.count += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.events.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prune events older than `window` from the logical view.
|
||||||
|
/// Returns a slice of active events (not necessarily contiguous in ring buffer,
|
||||||
|
/// so we collect into a Vec).
|
||||||
|
fn active_events(&self, window_secs: u64) -> Vec<&RequestEvent> {
|
||||||
|
let now = Instant::now();
|
||||||
|
let cutoff = std::time::Duration::from_secs(window_secs);
|
||||||
|
self.events
|
||||||
|
.iter()
|
||||||
|
.filter(|e| now.duration_since(e.timestamp) <= cutoff)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_features(&self, window_secs: u64) -> FeatureVector {
|
||||||
|
let events = self.active_events(window_secs);
|
||||||
|
let n = events.len() as f64;
|
||||||
|
if n < 1.0 {
|
||||||
|
return [0.0; NUM_FEATURES];
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0: request_rate (requests / window_secs)
|
||||||
|
let request_rate = n / window_secs as f64;
|
||||||
|
|
||||||
|
// 1: unique_paths
|
||||||
|
let unique_paths = {
|
||||||
|
let mut set = FxHashSet::default();
|
||||||
|
for e in &events {
|
||||||
|
set.insert(e.path_hash);
|
||||||
|
}
|
||||||
|
set.len() as f64
|
||||||
|
};
|
||||||
|
|
||||||
|
// 2: unique_hosts
|
||||||
|
let unique_hosts = {
|
||||||
|
let mut set = FxHashSet::default();
|
||||||
|
for e in &events {
|
||||||
|
set.insert(e.host_hash);
|
||||||
|
}
|
||||||
|
set.len() as f64
|
||||||
|
};
|
||||||
|
|
||||||
|
// 3: error_rate (fraction of 4xx/5xx)
|
||||||
|
let errors = events.iter().filter(|e| e.status >= 400).count() as f64;
|
||||||
|
let error_rate = errors / n;
|
||||||
|
|
||||||
|
// 4: avg_duration_ms
|
||||||
|
let avg_duration_ms =
|
||||||
|
events.iter().map(|e| e.duration_ms as f64).sum::<f64>() / n;
|
||||||
|
|
||||||
|
// 5: method_entropy (Shannon entropy of method distribution)
|
||||||
|
let method_entropy = {
|
||||||
|
let mut counts = [0u32; 8];
|
||||||
|
for e in &events {
|
||||||
|
counts[e.method as usize % 8] += 1;
|
||||||
|
}
|
||||||
|
let mut entropy = 0.0f64;
|
||||||
|
for &c in &counts {
|
||||||
|
if c > 0 {
|
||||||
|
let p = c as f64 / n;
|
||||||
|
entropy -= p * p.ln();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entropy
|
||||||
|
};
|
||||||
|
|
||||||
|
// 6: burst_score (inverse mean inter-arrival time)
|
||||||
|
let burst_score = if events.len() >= 2 {
|
||||||
|
let mut timestamps: Vec<Instant> =
|
||||||
|
events.iter().map(|e| e.timestamp).collect();
|
||||||
|
timestamps.sort();
|
||||||
|
let total_span = timestamps
|
||||||
|
.last()
|
||||||
|
.unwrap()
|
||||||
|
.duration_since(*timestamps.first().unwrap())
|
||||||
|
.as_secs_f64();
|
||||||
|
if total_span > 0.0 {
|
||||||
|
(events.len() - 1) as f64 / total_span
|
||||||
|
} else {
|
||||||
|
n // all events at same instant = maximum burstiness
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
// 7: path_repetition (ratio of most-repeated path to total)
|
||||||
|
let path_repetition = {
|
||||||
|
let mut counts = rustc_hash::FxHashMap::default();
|
||||||
|
for e in &events {
|
||||||
|
*counts.entry(e.path_hash).or_insert(0u32) += 1;
|
||||||
|
}
|
||||||
|
let max_count = counts.values().copied().max().unwrap_or(0) as f64;
|
||||||
|
max_count / n
|
||||||
|
};
|
||||||
|
|
||||||
|
// 8: avg_content_length
|
||||||
|
let avg_content_length =
|
||||||
|
events.iter().map(|e| e.content_length as f64).sum::<f64>() / n;
|
||||||
|
|
||||||
|
// 9: unique_user_agents
|
||||||
|
let unique_user_agents = {
|
||||||
|
let mut set = FxHashSet::default();
|
||||||
|
for e in &events {
|
||||||
|
set.insert(e.user_agent_hash);
|
||||||
|
}
|
||||||
|
set.len() as f64
|
||||||
|
};
|
||||||
|
|
||||||
|
// 10: cookie_ratio (fraction of requests that have cookies)
|
||||||
|
let cookie_ratio =
|
||||||
|
events.iter().filter(|e| e.has_cookies).count() as f64 / n;
|
||||||
|
|
||||||
|
// 11: referer_ratio (fraction of requests with a referer)
|
||||||
|
let referer_ratio =
|
||||||
|
events.iter().filter(|e| e.has_referer).count() as f64 / n;
|
||||||
|
|
||||||
|
// 12: accept_language_ratio (fraction with accept-language)
|
||||||
|
let accept_language_ratio =
|
||||||
|
events.iter().filter(|e| e.has_accept_language).count() as f64 / n;
|
||||||
|
|
||||||
|
// 13: suspicious_path_ratio (fraction hitting known-bad paths)
|
||||||
|
let suspicious_path_ratio =
|
||||||
|
events.iter().filter(|e| e.suspicious_path).count() as f64 / n;
|
||||||
|
|
||||||
|
[
|
||||||
|
request_rate,
|
||||||
|
unique_paths,
|
||||||
|
unique_hosts,
|
||||||
|
error_rate,
|
||||||
|
avg_duration_ms,
|
||||||
|
method_entropy,
|
||||||
|
burst_score,
|
||||||
|
path_repetition,
|
||||||
|
avg_content_length,
|
||||||
|
unique_user_agents,
|
||||||
|
cookie_ratio,
|
||||||
|
referer_ratio,
|
||||||
|
accept_language_ratio,
|
||||||
|
suspicious_path_ratio,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode an HTTP method name as a compact integer code.
/// GET=0, POST=1, PUT=2, DELETE=3, HEAD=4, PATCH=5, OPTIONS=6, anything else=7.
/// Matching is case-sensitive, as methods are per RFC 9110.
pub fn method_to_u8(method: &str) -> u8 {
    const KNOWN: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH", "OPTIONS"];
    KNOWN
        .iter()
        .position(|&m| m == method)
        .map(|i| i as u8)
        .unwrap_or(7)
}
|
||||||
|
|
||||||
|
/// Per-feature min/max bounds used for min-max normalization of feature
/// vectors; computed from the training set and serialized with the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormParams {
    /// Per-feature minimum observed in the training data.
    pub mins: [f64; NUM_FEATURES],
    /// Per-feature maximum observed in the training data.
    pub maxs: [f64; NUM_FEATURES],
}
|
||||||
|
|
||||||
|
impl NormParams {
|
||||||
|
pub fn from_data(vectors: &[FeatureVector]) -> Self {
|
||||||
|
let mut mins = [f64::MAX; NUM_FEATURES];
|
||||||
|
let mut maxs = [f64::MIN; NUM_FEATURES];
|
||||||
|
for v in vectors {
|
||||||
|
for i in 0..NUM_FEATURES {
|
||||||
|
mins[i] = mins[i].min(v[i]);
|
||||||
|
maxs[i] = maxs[i].max(v[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self { mins, maxs }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn normalize(&self, v: &FeatureVector) -> FeatureVector {
|
||||||
|
let mut out = [0.0; NUM_FEATURES];
|
||||||
|
for i in 0..NUM_FEATURES {
|
||||||
|
let range = self.maxs[i] - self.mins[i];
|
||||||
|
out[i] = if range > 0.0 {
|
||||||
|
((v[i] - self.mins[i]) / range).clamp(0.0, 1.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature extraction from parsed log entries (used by training pipeline).
/// Unlike IpState which uses Instant, this uses f64 timestamps from log parsing.
///
/// Column-oriented layout: index `i` across all vectors describes the same
/// request. Callers are expected to keep the vectors the same length.
pub struct LogIpState {
    /// Per-request timestamps in seconds (f64, parsed from the log).
    pub timestamps: Vec<f64>,
    /// Method codes as produced by `method_to_u8`.
    pub methods: Vec<u8>,
    pub path_hashes: Vec<u64>,
    pub host_hashes: Vec<u64>,
    pub user_agent_hashes: Vec<u64>,
    pub statuses: Vec<u16>,
    pub durations: Vec<u32>,
    pub content_lengths: Vec<u32>,
    pub has_cookies: Vec<bool>,
    pub has_referer: Vec<bool>,
    pub has_accept_language: Vec<bool>,
    pub suspicious_paths: Vec<bool>,
}
|
||||||
|
|
||||||
|
impl LogIpState {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
timestamps: Vec::new(),
|
||||||
|
methods: Vec::new(),
|
||||||
|
path_hashes: Vec::new(),
|
||||||
|
host_hashes: Vec::new(),
|
||||||
|
user_agent_hashes: Vec::new(),
|
||||||
|
statuses: Vec::new(),
|
||||||
|
durations: Vec::new(),
|
||||||
|
content_lengths: Vec::new(),
|
||||||
|
has_cookies: Vec::new(),
|
||||||
|
has_referer: Vec::new(),
|
||||||
|
has_accept_language: Vec::new(),
|
||||||
|
suspicious_paths: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_features_for_window(
|
||||||
|
&self,
|
||||||
|
start: usize,
|
||||||
|
end: usize,
|
||||||
|
window_secs: f64,
|
||||||
|
) -> FeatureVector {
|
||||||
|
let n = (end - start) as f64;
|
||||||
|
if n < 1.0 {
|
||||||
|
return [0.0; NUM_FEATURES];
|
||||||
|
}
|
||||||
|
|
||||||
|
let request_rate = n / window_secs;
|
||||||
|
|
||||||
|
let unique_paths = {
|
||||||
|
let mut set = FxHashSet::default();
|
||||||
|
for i in start..end {
|
||||||
|
set.insert(self.path_hashes[i]);
|
||||||
|
}
|
||||||
|
set.len() as f64
|
||||||
|
};
|
||||||
|
|
||||||
|
let unique_hosts = {
|
||||||
|
let mut set = FxHashSet::default();
|
||||||
|
for i in start..end {
|
||||||
|
set.insert(self.host_hashes[i]);
|
||||||
|
}
|
||||||
|
set.len() as f64
|
||||||
|
};
|
||||||
|
|
||||||
|
let errors = self.statuses[start..end]
|
||||||
|
.iter()
|
||||||
|
.filter(|&&s| s >= 400)
|
||||||
|
.count() as f64;
|
||||||
|
let error_rate = errors / n;
|
||||||
|
|
||||||
|
let avg_duration_ms =
|
||||||
|
self.durations[start..end].iter().map(|&d| d as f64).sum::<f64>() / n;
|
||||||
|
|
||||||
|
let method_entropy = {
|
||||||
|
let mut counts = [0u32; 8];
|
||||||
|
for i in start..end {
|
||||||
|
counts[self.methods[i] as usize % 8] += 1;
|
||||||
|
}
|
||||||
|
let mut entropy = 0.0f64;
|
||||||
|
for &c in &counts {
|
||||||
|
if c > 0 {
|
||||||
|
let p = c as f64 / n;
|
||||||
|
entropy -= p * p.ln();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entropy
|
||||||
|
};
|
||||||
|
|
||||||
|
let burst_score = if (end - start) >= 2 {
|
||||||
|
let total_span =
|
||||||
|
self.timestamps[end - 1] - self.timestamps[start];
|
||||||
|
if total_span > 0.0 {
|
||||||
|
(end - start - 1) as f64 / total_span
|
||||||
|
} else {
|
||||||
|
n
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let path_repetition = {
|
||||||
|
let mut counts = rustc_hash::FxHashMap::default();
|
||||||
|
for i in start..end {
|
||||||
|
*counts.entry(self.path_hashes[i]).or_insert(0u32) += 1;
|
||||||
|
}
|
||||||
|
let max_count = counts.values().copied().max().unwrap_or(0) as f64;
|
||||||
|
max_count / n
|
||||||
|
};
|
||||||
|
|
||||||
|
let avg_content_length = self.content_lengths[start..end]
|
||||||
|
.iter()
|
||||||
|
.map(|&c| c as f64)
|
||||||
|
.sum::<f64>()
|
||||||
|
/ n;
|
||||||
|
|
||||||
|
let unique_user_agents = {
|
||||||
|
let mut set = FxHashSet::default();
|
||||||
|
for i in start..end {
|
||||||
|
set.insert(self.user_agent_hashes[i]);
|
||||||
|
}
|
||||||
|
set.len() as f64
|
||||||
|
};
|
||||||
|
|
||||||
|
let cookie_ratio =
|
||||||
|
self.has_cookies[start..end].iter().filter(|&&v| v).count() as f64 / n;
|
||||||
|
let referer_ratio =
|
||||||
|
self.has_referer[start..end].iter().filter(|&&v| v).count() as f64 / n;
|
||||||
|
let accept_language_ratio =
|
||||||
|
self.has_accept_language[start..end].iter().filter(|&&v| v).count() as f64 / n;
|
||||||
|
let suspicious_path_ratio =
|
||||||
|
self.suspicious_paths[start..end].iter().filter(|&&v| v).count() as f64 / n;
|
||||||
|
|
||||||
|
[
|
||||||
|
request_rate,
|
||||||
|
unique_paths,
|
||||||
|
unique_hosts,
|
||||||
|
error_rate,
|
||||||
|
avg_duration_ms,
|
||||||
|
method_entropy,
|
||||||
|
burst_score,
|
||||||
|
path_repetition,
|
||||||
|
avg_content_length,
|
||||||
|
unique_user_agents,
|
||||||
|
cookie_ratio,
|
||||||
|
referer_ratio,
|
||||||
|
accept_language_ratio,
|
||||||
|
suspicious_path_ratio,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_hash::FxHasher;
    use std::hash::{Hash, Hasher};

    // Test-local string hasher matching the detector's fx_hash helper.
    fn fx(s: &str) -> u64 {
        let mut h = FxHasher::default();
        s.hash(&mut h);
        h.finish()
    }

    // A single buffered event should yield degenerate-but-defined features:
    // all ratios are either 0 or 1 and path repetition is total.
    #[test]
    fn test_single_event_features() {
        let mut state = IpState::new(100);
        state.push(RequestEvent {
            timestamp: Instant::now(),
            method: 0,
            path_hash: fx("/"),
            host_hash: fx("example.com"),
            user_agent_hash: fx("curl/7.0"),
            status: 200,
            duration_ms: 10,
            content_length: 0,
            has_cookies: true,
            has_referer: false,
            has_accept_language: true,
            suspicious_path: false,
        });
        let features = state.extract_features(60);
        // request_rate = 1/60
        assert!(features[0] > 0.0);
        // error_rate = 0
        assert_eq!(features[3], 0.0);
        // path_repetition = 1.0 (only one path)
        assert_eq!(features[7], 1.0);
        // cookie_ratio = 1.0 (single event with cookies)
        assert_eq!(features[10], 1.0);
        // referer_ratio = 0.0
        assert_eq!(features[11], 0.0);
        // accept_language_ratio = 1.0
        assert_eq!(features[12], 1.0);
        // suspicious_path_ratio = 0.0
        assert_eq!(features[13], 0.0);
    }

    // Min-max scaling of the midpoint between two training vectors must map
    // every feature to 0.5.
    #[test]
    fn test_norm_params() {
        let data = vec![[0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [1.0, 20.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]];
        let params = NormParams::from_data(&data);
        let normalized = params.normalize(&[0.5, 15.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]);
        for &v in &normalized {
            assert!((v - 0.5).abs() < 1e-10);
        }
    }
}
|
||||||
6
src/ddos/mod.rs
Normal file
6
src/ddos/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
//! KNN-based DDoS detection: audit-log parsing, per-IP feature extraction,
//! model training/serialization, online detection, and offline replay.

pub mod audit_log;
pub mod detector;
pub mod features;
pub mod model;
pub mod replay;
pub mod train;
|
||||||
168
src/ddos/model.rs
Normal file
168
src/ddos/model.rs
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
use crate::ddos::features::{FeatureVector, NormParams, NUM_FEATURES};
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// Classification label attached to each training point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrafficLabel {
    /// Benign traffic.
    Normal,
    /// Attack/abusive traffic.
    Attack,
}
|
||||||
|
|
||||||
|
/// On-disk representation of a trained model (bincode-serialized;
/// see `TrainedModel::load`).
#[derive(Serialize, Deserialize)]
pub struct SerializedModel {
    /// Normalized training feature vectors.
    pub points: Vec<FeatureVector>,
    /// Label for each point, index-aligned with `points`.
    pub labels: Vec<TrafficLabel>,
    /// Min/max bounds used to normalize queries the same way as the points.
    pub norm_params: NormParams,
    /// Number of neighbors consulted per classification.
    pub k: usize,
    /// Fraction of attack-labeled neighbors at which a query is blocked.
    pub threshold: f64,
}
|
||||||
|
|
||||||
|
/// In-memory KNN model used for classification.
pub struct TrainedModel {
    /// Stored points (normalized). The kD-tree borrows these.
    points: Vec<[f64; NUM_FEATURES]>,
    /// Label for each point, index-aligned with `points`.
    labels: Vec<TrafficLabel>,
    /// Normalization applied to queries before the nearest-neighbor search.
    norm_params: NormParams,
    /// Number of neighbors consulted per classification.
    k: usize,
    /// Attack-neighbor fraction at or above which a query is blocked.
    threshold: f64,
}
|
||||||
|
|
||||||
|
/// Verdict returned to the proxy for a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DDoSAction {
    /// Let the request through.
    Allow,
    /// Reject the request as part of an attack.
    Block,
}
|
||||||
|
|
||||||
|
impl TrainedModel {
    /// Load a bincode-serialized model from disk.
    ///
    /// `k_override`/`threshold_override`, when `Some`, replace the values
    /// stored in the model file (used by the replay CLI flags).
    pub fn load(path: &Path, k_override: Option<usize>, threshold_override: Option<f64>) -> Result<Self> {
        let data = std::fs::read(path)
            .with_context(|| format!("reading model from {}", path.display()))?;
        let model: SerializedModel =
            bincode::deserialize(&data).context("deserializing model")?;
        Ok(Self {
            points: model.points,
            labels: model.labels,
            norm_params: model.norm_params,
            k: k_override.unwrap_or(model.k),
            threshold: threshold_override.unwrap_or(model.threshold),
        })
    }

    /// Build a model directly from its serialized form (no overrides).
    pub fn from_serialized(model: SerializedModel) -> Self {
        Self {
            points: model.points,
            labels: model.labels,
            norm_params: model.norm_params,
            k: model.k,
            threshold: model.threshold,
        }
    }

    /// Classify a raw (un-normalized) feature vector.
    ///
    /// Majority-vote KNN: the query is normalized, its k nearest training
    /// points are found, and the request is blocked when the fraction of
    /// attack-labeled neighbors reaches `threshold`. Any failure path
    /// (empty model, tree build/query error) fails open to `Allow`.
    pub fn classify(&self, features: &FeatureVector) -> DDoSAction {
        let normalized = self.norm_params.normalize(features);

        if self.points.is_empty() {
            return DDoSAction::Allow;
        }

        // Build tree on-the-fly for query. In production with many queries,
        // we'd cache this, but the tree build is fast for <100K points.
        // fnntw::Tree borrows data, so we build it here.
        let tree = match fnntw::Tree::<'_, f64, NUM_FEATURES>::new(&self.points, 32) {
            Ok(t) => t,
            Err(_) => return DDoSAction::Allow,
        };

        // Never ask for more neighbors than there are points.
        let k = self.k.min(self.points.len());
        let result = tree.query_nearest_k(&normalized, k);
        match result {
            Ok((_distances, indices)) => {
                let attack_count = indices
                    .iter()
                    .filter(|&&idx| self.labels[idx as usize] == TrafficLabel::Attack)
                    .count();
                let attack_frac = attack_count as f64 / k as f64;
                if attack_frac >= self.threshold {
                    DDoSAction::Block
                } else {
                    DDoSAction::Allow
                }
            }
            Err(_) => DDoSAction::Allow,
        }
    }

    /// Normalization bounds the model was trained with.
    pub fn norm_params(&self) -> &NormParams {
        &self.norm_params
    }

    /// Number of training points in the model.
    pub fn point_count(&self) -> usize {
        self.points.len()
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    // An empty model must fail open (Allow) rather than panic or block.
    #[test]
    fn test_classify_empty_model() {
        let model = TrainedModel {
            points: vec![],
            labels: vec![],
            norm_params: NormParams {
                mins: [0.0; NUM_FEATURES],
                maxs: [1.0; NUM_FEATURES],
            },
            k: 5,
            threshold: 0.6,
        };
        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
    }

    // Deterministic spread of n points across the unit hypercube, used as a
    // synthetic training set.
    fn make_test_points(n: usize) -> Vec<FeatureVector> {
        (0..n)
            .map(|i| {
                let mut v = [0.0; NUM_FEATURES];
                for d in 0..NUM_FEATURES {
                    v[d] = ((i * (d + 1)) as f64 / n as f64) % 1.0;
                }
                v
            })
            .collect()
    }

    // With every training point labeled Attack, any query must be blocked.
    #[test]
    fn test_classify_all_attack() {
        let points = make_test_points(100);
        let labels = vec![TrafficLabel::Attack; 100];
        let model = TrainedModel {
            points,
            labels,
            norm_params: NormParams {
                mins: [0.0; NUM_FEATURES],
                maxs: [1.0; NUM_FEATURES],
            },
            k: 5,
            threshold: 0.6,
        };
        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Block);
    }

    // With every training point labeled Normal, any query must be allowed.
    #[test]
    fn test_classify_all_normal() {
        let points = make_test_points(100);
        let labels = vec![TrafficLabel::Normal; 100];
        let model = TrainedModel {
            points,
            labels,
            norm_params: NormParams {
                mins: [0.0; NUM_FEATURES],
                maxs: [1.0; NUM_FEATURES],
            },
            k: 5,
            threshold: 0.6,
        };
        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
    }
}
|
||||||
291
src/ddos/replay.rs
Normal file
291
src/ddos/replay.rs
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
use crate::config::{DDoSConfig, RateLimitConfig};
|
||||||
|
use crate::ddos::audit_log::{self, AuditLog};
|
||||||
|
use crate::ddos::detector::DDoSDetector;
|
||||||
|
use crate::ddos::model::{DDoSAction, TrainedModel};
|
||||||
|
use crate::rate_limit::key::RateLimitKey;
|
||||||
|
use crate::rate_limit::limiter::{RateLimitResult, RateLimiter};
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
|
use std::io::BufRead;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// CLI arguments for the offline replay subcommand.
pub struct ReplayArgs {
    /// Path to the JSON-lines audit log to replay.
    pub input: String,
    /// Path to the serialized model file.
    pub model_path: String,
    /// Optional proxy config file (used for rate-limit settings).
    pub config_path: Option<String>,
    /// KNN neighbor count (overrides the value stored in the model).
    pub k: usize,
    /// Block threshold (overrides the value stored in the model).
    pub threshold: f64,
    /// Sliding-window length in seconds.
    pub window_secs: u64,
    /// Minimum events per IP before classification.
    pub min_events: usize,
    /// Whether to also simulate the rate limiter during replay.
    pub rate_limit: bool,
}
|
||||||
|
|
||||||
|
/// Aggregate counters collected while replaying a log.
struct ReplayStats {
    /// Entries successfully parsed and evaluated.
    total: u64,
    /// Entries dropped (unparseable JSON, empty method, bad IP).
    skipped: u64,
    /// Entries the DDoS detector blocked.
    ddos_blocked: u64,
    /// Entries the rate limiter rejected (only when rate limiting is on).
    rate_limited: u64,
    /// Entries that passed both checks.
    allowed: u64,
    /// Per-IP counts of DDoS blocks.
    ddos_blocked_ips: FxHashMap<String, u64>,
    /// Per-IP counts of rate-limit rejections.
    rate_limited_ips: FxHashMap<String, u64>,
}
|
||||||
|
|
||||||
|
pub fn run(args: ReplayArgs) -> Result<()> {
|
||||||
|
eprintln!("Loading model from {}...", args.model_path);
|
||||||
|
let model = TrainedModel::load(
|
||||||
|
std::path::Path::new(&args.model_path),
|
||||||
|
Some(args.k),
|
||||||
|
Some(args.threshold),
|
||||||
|
)
|
||||||
|
.with_context(|| format!("loading model from {}", args.model_path))?;
|
||||||
|
eprintln!(" {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold);
|
||||||
|
|
||||||
|
let ddos_cfg = DDoSConfig {
|
||||||
|
model_path: args.model_path.clone(),
|
||||||
|
k: args.k,
|
||||||
|
threshold: args.threshold,
|
||||||
|
window_secs: args.window_secs,
|
||||||
|
window_capacity: 1000,
|
||||||
|
min_events: args.min_events,
|
||||||
|
enabled: true,
|
||||||
|
};
|
||||||
|
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
|
||||||
|
|
||||||
|
// Optionally set up rate limiter
|
||||||
|
let rate_limiter = if args.rate_limit {
|
||||||
|
let rl_cfg = if let Some(cfg_path) = &args.config_path {
|
||||||
|
let cfg = crate::config::Config::load(cfg_path)?;
|
||||||
|
cfg.rate_limit.unwrap_or_else(default_rate_limit_config)
|
||||||
|
} else {
|
||||||
|
default_rate_limit_config()
|
||||||
|
};
|
||||||
|
eprintln!(
|
||||||
|
" Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s",
|
||||||
|
rl_cfg.authenticated.burst,
|
||||||
|
rl_cfg.authenticated.rate,
|
||||||
|
rl_cfg.unauthenticated.burst,
|
||||||
|
rl_cfg.unauthenticated.rate,
|
||||||
|
);
|
||||||
|
Some(RateLimiter::new(&rl_cfg))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
eprintln!("Replaying {}...\n", args.input);
|
||||||
|
|
||||||
|
let file = std::fs::File::open(&args.input)
|
||||||
|
.with_context(|| format!("opening {}", args.input))?;
|
||||||
|
let reader = std::io::BufReader::new(file);
|
||||||
|
|
||||||
|
let mut stats = ReplayStats {
|
||||||
|
total: 0,
|
||||||
|
skipped: 0,
|
||||||
|
ddos_blocked: 0,
|
||||||
|
rate_limited: 0,
|
||||||
|
allowed: 0,
|
||||||
|
ddos_blocked_ips: FxHashMap::default(),
|
||||||
|
rate_limited_ips: FxHashMap::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line?;
|
||||||
|
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(_) => {
|
||||||
|
stats.skipped += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if entry.fields.method.is_empty() {
|
||||||
|
stats.skipped += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
stats.total += 1;
|
||||||
|
|
||||||
|
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||||
|
let ip: IpAddr = match ip_str.parse() {
|
||||||
|
Ok(ip) => ip,
|
||||||
|
Err(_) => {
|
||||||
|
stats.skipped += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// DDoS check
|
||||||
|
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
|
||||||
|
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
|
||||||
|
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
|
||||||
|
let ddos_action = detector.check(
|
||||||
|
ip,
|
||||||
|
&entry.fields.method,
|
||||||
|
&entry.fields.path,
|
||||||
|
&entry.fields.host,
|
||||||
|
&entry.fields.user_agent,
|
||||||
|
entry.fields.content_length,
|
||||||
|
has_cookies,
|
||||||
|
has_referer,
|
||||||
|
has_accept_language,
|
||||||
|
);
|
||||||
|
|
||||||
|
if ddos_action == DDoSAction::Block {
|
||||||
|
stats.ddos_blocked += 1;
|
||||||
|
*stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rate limit check
|
||||||
|
if let Some(limiter) = &rate_limiter {
|
||||||
|
// Audit logs don't have auth headers, so all traffic is keyed by IP
|
||||||
|
let rl_key = RateLimitKey::Ip(ip);
|
||||||
|
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
|
||||||
|
stats.rate_limited += 1;
|
||||||
|
*stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats.allowed += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Report
|
||||||
|
let total = stats.total;
|
||||||
|
eprintln!("═══ Replay Results ═══════════════════════════════════════");
|
||||||
|
eprintln!(" Total requests: {total}");
|
||||||
|
eprintln!(" Skipped (parse): {}", stats.skipped);
|
||||||
|
eprintln!(" Allowed: {} ({:.1}%)", stats.allowed, pct(stats.allowed, total));
|
||||||
|
eprintln!(" DDoS blocked: {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total));
|
||||||
|
if rate_limiter.is_some() {
|
||||||
|
eprintln!(" Rate limited: {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stats.ddos_blocked_ips.is_empty() {
|
||||||
|
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
|
||||||
|
let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect();
|
||||||
|
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||||
|
for (ip, count) in sorted.iter().take(20) {
|
||||||
|
eprintln!(" {:<40} {} reqs blocked", ip, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stats.rate_limited_ips.is_empty() {
|
||||||
|
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
|
||||||
|
let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect();
|
||||||
|
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||||
|
for (ip, count) in sorted.iter().take(20) {
|
||||||
|
eprintln!(" {:<40} {} reqs limited", ip, count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for false positives: IPs that were blocked but had 2xx statuses in the original logs
|
||||||
|
eprintln!("\n── False positive check ──────────────────────────────────");
|
||||||
|
check_false_positives(&args.input, &stats)?;
|
||||||
|
|
||||||
|
eprintln!("══════════════════════════════════════════════════════════");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally
|
||||||
|
/// (i.e. they were legitimate traffic that the model would incorrectly block).
|
||||||
|
fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
|
||||||
|
let blocked_ips: rustc_hash::FxHashSet<&str> = stats
|
||||||
|
.ddos_blocked_ips
|
||||||
|
.keys()
|
||||||
|
.chain(stats.rate_limited_ips.keys())
|
||||||
|
.map(|s| s.as_str())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if blocked_ips.is_empty() {
|
||||||
|
eprintln!(" No blocked IPs — nothing to check.");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect original status codes for blocked IPs
|
||||||
|
let file = std::fs::File::open(input)?;
|
||||||
|
let reader = std::io::BufReader::new(file);
|
||||||
|
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
|
||||||
|
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line?;
|
||||||
|
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||||
|
if blocked_ips.contains(ip_str.as_str()) {
|
||||||
|
ip_statuses
|
||||||
|
.entry(ip_str)
|
||||||
|
.or_default()
|
||||||
|
.push(entry.fields.status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut suspects = Vec::new();
|
||||||
|
for (ip, statuses) in &ip_statuses {
|
||||||
|
let total = statuses.len();
|
||||||
|
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
|
||||||
|
let ok_pct = (ok_count as f64 / total as f64) * 100.0;
|
||||||
|
// If >60% of original responses were 2xx/3xx, this might be a false positive
|
||||||
|
if ok_pct > 60.0 {
|
||||||
|
let blocked = stats
|
||||||
|
.ddos_blocked_ips
|
||||||
|
.get(ip)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0)
|
||||||
|
+ stats
|
||||||
|
.rate_limited_ips
|
||||||
|
.get(ip)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0);
|
||||||
|
suspects.push((ip.clone(), total, ok_pct, blocked));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if suspects.is_empty() {
|
||||||
|
eprintln!(" No likely false positives found.");
|
||||||
|
} else {
|
||||||
|
suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
|
||||||
|
eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses:", suspects.len());
|
||||||
|
for (ip, total, ok_pct, blocked) in suspects.iter().take(15) {
|
||||||
|
eprintln!(
|
||||||
|
" {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked",
|
||||||
|
ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_rate_limit_config() -> RateLimitConfig {
|
||||||
|
RateLimitConfig {
|
||||||
|
enabled: true,
|
||||||
|
bypass_cidrs: vec![
|
||||||
|
"10.0.0.0/8".into(),
|
||||||
|
"172.16.0.0/12".into(),
|
||||||
|
"192.168.0.0/16".into(),
|
||||||
|
"100.64.0.0/10".into(),
|
||||||
|
"fd00::/8".into(),
|
||||||
|
],
|
||||||
|
eviction_interval_secs: 300,
|
||||||
|
stale_after_secs: 600,
|
||||||
|
authenticated: crate::config::BucketConfig {
|
||||||
|
burst: 200,
|
||||||
|
rate: 50.0,
|
||||||
|
},
|
||||||
|
unauthenticated: crate::config::BucketConfig {
|
||||||
|
burst: 60,
|
||||||
|
rate: 15.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Percentage of `n` out of `total`, returning 0.0 for an empty total
/// instead of NaN.
fn pct(n: u64, total: u64) -> f64 {
    match total {
        0 => 0.0,
        t => (n as f64 / t as f64) * 100.0,
    }
}
|
||||||
298
src/ddos/train.rs
Normal file
298
src/ddos/train.rs
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
use crate::ddos::audit_log::AuditLog;
|
||||||
|
use crate::ddos::audit_log;
|
||||||
|
use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES};
|
||||||
|
use crate::ddos::model::{SerializedModel, TrafficLabel};
|
||||||
|
use anyhow::{bail, Context, Result};
|
||||||
|
use rustc_hash::{FxHashMap, FxHashSet};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::io::BufRead;
|
||||||
|
|
||||||
|
/// Tunable cutoffs for heuristic (unsupervised) labeling of per-IP traffic.
///
/// Deserialized from the TOML file passed via `--heuristics`; every field has
/// a serde default, so a partial file only overrides the thresholds it names.
/// An IP whose averaged feature vector crosses ANY single cutoff is labeled
/// `Attack` (see the heuristic labeling branch in `run`).
#[derive(Deserialize)]
pub struct HeuristicThresholds {
    /// Requests/second above which an IP is labeled attack
    #[serde(default = "default_rate_threshold")]
    pub request_rate: f64,
    /// Path repetition ratio above which an IP is labeled attack
    #[serde(default = "default_repetition_threshold")]
    pub path_repetition: f64,
    /// Error rate above which an IP is labeled attack
    #[serde(default = "default_error_threshold")]
    pub error_rate: f64,
    /// Suspicious path ratio above which an IP is labeled attack
    #[serde(default = "default_suspicious_path_threshold")]
    pub suspicious_path_ratio: f64,
    /// Cookie ratio below which (combined with high unique paths) labels attack
    #[serde(default = "default_no_cookies_threshold")]
    pub no_cookies_threshold: f64,
    /// Unique path count above which no-cookie traffic is labeled attack
    #[serde(default = "default_no_cookies_path_count")]
    pub no_cookies_path_count: f64,
    /// Minimum events to consider an IP for labeling
    #[serde(default = "default_min_events")]
    pub min_events: usize,
}
|
||||||
|
|
||||||
|
// Serde defaults for `HeuristicThresholds`. Chosen so ordinary traffic stays
// comfortably below every cutoff; override any of them via the TOML file.
fn default_rate_threshold() -> f64 {
    10.0 // requests/second
}
fn default_repetition_threshold() -> f64 {
    0.9 // fraction of requests hitting the single most-repeated path
}
fn default_error_threshold() -> f64 {
    0.7 // fraction of 4xx/5xx responses
}
fn default_suspicious_path_threshold() -> f64 {
    0.3 // fraction of requests to known-probe paths
}
fn default_no_cookies_threshold() -> f64 {
    0.05 // cookie ratio below this counts as "cookieless"
}
fn default_no_cookies_path_count() -> f64 {
    20.0 // unique paths above this, combined with no cookies, flags attack
}
fn default_min_events() -> usize {
    10 // ignore IPs with fewer events than this
}
|
||||||
|
|
||||||
|
/// Arguments for the `train` subcommand.
pub struct TrainArgs {
    /// Path to the JSONL audit log used as training input.
    pub input: String,
    /// Destination path for the serialized (bincode) model.
    pub output: String,
    /// Optional file of known-attacker IPs, one per line (pair with `normal_ips`).
    pub attack_ips: Option<String>,
    /// Optional file of known-good IPs, one per line (pair with `attack_ips`).
    pub normal_ips: Option<String>,
    /// Optional TOML file of `HeuristicThresholds` for auto-labeling.
    pub heuristics: Option<String>,
    /// KNN neighbor count stored into the model.
    pub k: usize,
    /// Attack-vote threshold stored into the model.
    pub threshold: f64,
    /// Sliding-window length in seconds for feature extraction.
    pub window_secs: u64,
    /// Minimum events an IP needs before it contributes training windows.
    pub min_events: usize,
}
|
||||||
|
|
||||||
|
fn fx_hash(s: &str) -> u64 {
|
||||||
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
|
s.hash(&mut h);
|
||||||
|
h.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert an ISO-8601 timestamp ("2026-03-07T17:41:40.705326Z") into a
/// monotonic seconds value used only for relative ordering and second-level
/// differences within a log file.
///
/// The value is NOT seconds-since-epoch: the date is folded into a
/// `(year * 12 + month) * 31 + day` day index (every month treated as 31
/// days), which over-approximates cross-month gaps but strictly preserves
/// ordering across month and year boundaries. The previous day-only version
/// inverted order at those boundaries (Jan 31 sorted after Feb 1), which
/// corrupted window spans for logs straddling a boundary.
///
/// Returns 0.0 for any malformed input.
fn parse_timestamp(ts: &str) -> f64 {
    let parts: Vec<&str> = ts.split('T').collect();
    if parts.len() != 2 {
        return 0.0;
    }
    let date_parts: Vec<&str> = parts[0].split('-').collect();
    let time_str = parts[1].trim_end_matches('Z');
    let time_parts: Vec<&str> = time_str.split(':').collect();
    if date_parts.len() != 3 || time_parts.len() != 3 {
        return 0.0;
    }
    let year: f64 = date_parts[0].parse().unwrap_or(0.0);
    let month: f64 = date_parts[1].parse().unwrap_or(0.0);
    let day: f64 = date_parts[2].parse().unwrap_or(0.0);
    let hour: f64 = time_parts[0].parse().unwrap_or(0.0);
    let min: f64 = time_parts[1].parse().unwrap_or(0.0);
    // The seconds field may carry a fractional part ("40.705326"); f64
    // parsing handles it directly.
    let sec: f64 = time_parts[2].parse().unwrap_or(0.0);
    // Monotonic day index: day <= 31, so bumping the month always moves the
    // index forward even from the 31st to the 1st.
    let day_index = (year * 12.0 + month) * 31.0 + day;
    day_index * 86400.0 + hour * 3600.0 + min * 60.0 + sec
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Train a KNN DDoS model from a JSONL audit log.
///
/// Pipeline: parse the log into per-IP event streams → slice each stream
/// into feature-vector windows → label each IP (explicit IP lists or
/// heuristic thresholds) → normalize → serialize the model with bincode.
///
/// NOTE(review): `AuditFields.label` (the optional ground-truth label from
/// external datasets, documented in audit_log.rs as preferred over heuristic
/// labeling) is never consulted here — labeling always comes from the IP
/// lists or heuristics. Confirm whether it should take precedence.
pub fn run(args: TrainArgs) -> Result<()> {
    eprintln!("Parsing logs from {}...", args.input);

    // Parse logs into per-IP state
    let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
    let file = std::fs::File::open(&args.input)
        .with_context(|| format!("opening {}", args.input))?;
    let reader = std::io::BufReader::new(file);

    let mut total_lines = 0u64;
    let mut parse_errors = 0u64;

    for line in reader.lines() {
        let line = line?;
        total_lines += 1;
        let entry: AuditLog = match serde_json::from_str(&line) {
            Ok(e) => e,
            Err(_) => {
                parse_errors += 1;
                continue;
            }
        };

        // Skip non-audit entries (lines that deserialized but carry no request)
        if entry.fields.method.is_empty() {
            continue;
        }

        let ip = audit_log::strip_port(&entry.fields.client_ip).to_string();
        let ts = parse_timestamp(&entry.timestamp);

        // Append this event to the IP's parallel per-field vectors.
        // Strings are reduced to FxHasher tokens: uniqueness counting only.
        let state = ip_states.entry(ip).or_insert_with(LogIpState::new);
        state.timestamps.push(ts);
        state.methods.push(method_to_u8(&entry.fields.method));
        state.path_hashes.push(fx_hash(&entry.fields.path));
        state.host_hashes.push(fx_hash(&entry.fields.host));
        state
            .user_agent_hashes
            .push(fx_hash(&entry.fields.user_agent));
        state.statuses.push(entry.fields.status);
        // Durations/content lengths are clamped into u32 range to keep the
        // per-event footprint small.
        state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
        state
            .content_lengths
            .push(entry.fields.content_length.min(u32::MAX as u64) as u32);
        state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
        // "-" is the placeholder for an absent header, so it counts as missing.
        state.has_referer.push(
            entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
        );
        state.has_accept_language.push(
            entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
        );
        state.suspicious_paths.push(
            crate::ddos::features::is_suspicious_path(&entry.fields.path),
        );
    }

    eprintln!(
        "Parsed {} lines ({} errors), {} unique IPs",
        total_lines,
        parse_errors,
        ip_states.len()
    );

    // Extract feature vectors per IP (using sliding windows)
    let window_secs = args.window_secs as f64;
    let mut ip_features: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default();

    for (ip, state) in &ip_states {
        let n = state.timestamps.len();
        // IPs with too few events produce unreliable rate features; drop them.
        if n < args.min_events {
            continue;
        }
        // Extract one feature vector per window: walk forward until the
        // elapsed span reaches window_secs (events are assumed in log order),
        // emit a window [start, end], and begin the next window after it.
        // NOTE(review): the final window may span less than window_secs —
        // confirm extract_features_for_window handles the short tail.
        let mut features = Vec::new();
        let mut start = 0;
        for end in 1..n {
            let span = state.timestamps[end] - state.timestamps[start];
            if span >= window_secs || end == n - 1 {
                let fv = state.extract_features_for_window(start, end + 1, window_secs);
                features.push(fv);
                start = end + 1;
            }
        }
        if !features.is_empty() {
            ip_features.insert(ip.clone(), features);
        }
    }

    // Label IPs: either from explicit attack/normal IP lists, or by applying
    // heuristic thresholds to each IP's average feature vector.
    let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();

    if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
        // IP list mode: one IP per line, blank lines ignored. IPs in neither
        // list get no label and contribute no training points.
        let attack_ips: FxHashSet<String> = std::fs::read_to_string(attack_file)
            .context("reading attack IPs file")?
            .lines()
            .map(|l| l.trim().to_string())
            .filter(|l| !l.is_empty())
            .collect();
        let normal_ips: FxHashSet<String> = std::fs::read_to_string(normal_file)
            .context("reading normal IPs file")?
            .lines()
            .map(|l| l.trim().to_string())
            .filter(|l| !l.is_empty())
            .collect();

        // Attack list wins when an IP appears in both.
        for ip in ip_features.keys() {
            if attack_ips.contains(ip) {
                ip_labels.insert(ip.clone(), TrafficLabel::Attack);
            } else if normal_ips.contains(ip) {
                ip_labels.insert(ip.clone(), TrafficLabel::Normal);
            }
        }
    } else if let Some(heuristics_file) = &args.heuristics {
        // Heuristic auto-labeling
        let heuristics_str = std::fs::read_to_string(heuristics_file)
            .context("reading heuristics file")?;
        let thresholds: HeuristicThresholds =
            toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;

        for (ip, features) in &ip_features {
            // Average the IP's per-window feature vectors for labeling;
            // crossing ANY single threshold marks the IP as an attacker.
            let avg = average_features(features);
            let is_attack = avg[0] > thresholds.request_rate // request_rate
                || avg[7] > thresholds.path_repetition // path_repetition
                || avg[3] > thresholds.error_rate // error_rate
                || avg[13] > thresholds.suspicious_path_ratio // suspicious_path_ratio
                || (avg[10] < thresholds.no_cookies_threshold // no cookies + high unique paths
                    && avg[1] > thresholds.no_cookies_path_count);
            ip_labels.insert(
                ip.clone(),
                if is_attack {
                    TrafficLabel::Attack
                } else {
                    TrafficLabel::Normal
                },
            );
        }
    } else {
        bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
    }

    // Build training dataset: every window of every labeled IP becomes one
    // training point carrying that IP's label.
    let mut all_points: Vec<FeatureVector> = Vec::new();
    let mut all_labels: Vec<TrafficLabel> = Vec::new();

    for (ip, features) in &ip_features {
        if let Some(&label) = ip_labels.get(ip) {
            for fv in features {
                all_points.push(*fv);
                all_labels.push(label);
            }
        }
    }

    if all_points.is_empty() {
        bail!("No labeled data points found. Check your IP lists or heuristic thresholds.");
    }

    let attack_count = all_labels
        .iter()
        .filter(|&&l| l == TrafficLabel::Attack)
        .count();
    let normal_count = all_labels.len() - attack_count;
    eprintln!(
        "Training with {} points ({} attack, {} normal)",
        all_points.len(),
        attack_count,
        normal_count
    );

    // Normalize: fit normalization parameters on the raw points, store them
    // in the model so inference applies the identical transform.
    let norm_params = NormParams::from_data(&all_points);
    let normalized: Vec<FeatureVector> = all_points
        .iter()
        .map(|v| norm_params.normalize(v))
        .collect();

    // Serialize the full model: points, labels, normalization, and the k /
    // threshold hyperparameters it was built for.
    let model = SerializedModel {
        points: normalized,
        labels: all_labels,
        norm_params,
        k: args.k,
        threshold: args.threshold,
    };

    let encoded = bincode::serialize(&model).context("serializing model")?;
    std::fs::write(&args.output, &encoded)
        .with_context(|| format!("writing model to {}", args.output))?;

    eprintln!(
        "Model saved to {} ({} bytes, {} points)",
        args.output,
        encoded.len(),
        model.points.len()
    );

    Ok(())
}
|
||||||
|
|
||||||
|
fn average_features(features: &[FeatureVector]) -> FeatureVector {
|
||||||
|
let n = features.len() as f64;
|
||||||
|
let mut avg = [0.0; NUM_FEATURES];
|
||||||
|
for fv in features {
|
||||||
|
for i in 0..NUM_FEATURES {
|
||||||
|
avg[i] += fv[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for v in &mut avg {
|
||||||
|
*v /= n;
|
||||||
|
}
|
||||||
|
avg
|
||||||
|
}
|
||||||
776
tests/ddos_test.rs
Normal file
776
tests/ddos_test.rs
Normal file
@@ -0,0 +1,776 @@
|
|||||||
|
//! Extensive DDoS detection tests.
|
||||||
|
//!
|
||||||
|
//! These tests build realistic traffic profiles — normal browsing, API usage,
|
||||||
|
//! webhook bursts, etc. — and verify the model never blocks legitimate traffic.
|
||||||
|
//! Attack scenarios are also tested to confirm blocking works.
|
||||||
|
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
|
use sunbeam_proxy::config::DDoSConfig;
|
||||||
|
use sunbeam_proxy::ddos::detector::DDoSDetector;
|
||||||
|
use sunbeam_proxy::ddos::features::{NormParams, NUM_FEATURES};
|
||||||
|
use sunbeam_proxy::ddos::model::{DDoSAction, SerializedModel, TrafficLabel, TrainedModel};
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Build a model from explicit normal/attack feature vectors.
|
||||||
|
fn make_model(
|
||||||
|
normal: &[[f64; NUM_FEATURES]],
|
||||||
|
attack: &[[f64; NUM_FEATURES]],
|
||||||
|
k: usize,
|
||||||
|
threshold: f64,
|
||||||
|
) -> TrainedModel {
|
||||||
|
let mut points = Vec::new();
|
||||||
|
let mut labels = Vec::new();
|
||||||
|
for v in normal {
|
||||||
|
points.push(*v);
|
||||||
|
labels.push(TrafficLabel::Normal);
|
||||||
|
}
|
||||||
|
for v in attack {
|
||||||
|
points.push(*v);
|
||||||
|
labels.push(TrafficLabel::Attack);
|
||||||
|
}
|
||||||
|
let norm_params = NormParams::from_data(&points);
|
||||||
|
let normalized: Vec<[f64; NUM_FEATURES]> =
|
||||||
|
points.iter().map(|v| norm_params.normalize(v)).collect();
|
||||||
|
TrainedModel::from_serialized(SerializedModel {
|
||||||
|
points: normalized,
|
||||||
|
labels,
|
||||||
|
norm_params,
|
||||||
|
k,
|
||||||
|
threshold,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_ddos_config() -> DDoSConfig {
|
||||||
|
DDoSConfig {
|
||||||
|
model_path: String::new(),
|
||||||
|
k: 5,
|
||||||
|
threshold: 0.6,
|
||||||
|
window_secs: 60,
|
||||||
|
window_capacity: 1000,
|
||||||
|
min_events: 10,
|
||||||
|
enabled: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_detector(model: TrainedModel, min_events: usize) -> DDoSDetector {
|
||||||
|
let mut cfg = default_ddos_config();
|
||||||
|
cfg.min_events = min_events;
|
||||||
|
DDoSDetector::new(model, &cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature vector indices (matching features.rs order):
|
||||||
|
/// 0: request_rate (requests / window_secs)
|
||||||
|
/// 1: unique_paths (count of distinct paths)
|
||||||
|
/// 2: unique_hosts (count of distinct hosts)
|
||||||
|
/// 3: error_rate (fraction 4xx/5xx)
|
||||||
|
/// 4: avg_duration_ms (mean response time)
|
||||||
|
/// 5: method_entropy (Shannon entropy of methods)
|
||||||
|
/// 6: burst_score (inverse mean inter-arrival)
|
||||||
|
/// 7: path_repetition (most-repeated path / total)
|
||||||
|
/// 8: avg_content_length (mean body size)
|
||||||
|
/// 9: unique_user_agents (count of distinct UAs)
|
||||||
|
/// 10: cookie_ratio (fraction with cookies)
|
||||||
|
/// 11: referer_ratio (fraction with referer)
|
||||||
|
/// 12: accept_language_ratio (fraction with accept-language)
|
||||||
|
/// 13: suspicious_path_ratio (fraction hitting known-bad paths)
|
||||||
|
/// (Indices mirror the extraction order in src/ddos/features.rs.)
|
||||||
|
|
||||||
|
// Realistic normal traffic profiles
|
||||||
|
fn normal_browser_browsing() -> [f64; NUM_FEATURES] {
|
||||||
|
// A human browsing a site: ~0.5 req/s, many paths, 1 host, low errors,
|
||||||
|
// ~150ms avg latency, mostly GET, moderate spacing, diverse paths, no body, 1 UA
|
||||||
|
// cookies=yes, referer=sometimes, accept-lang=yes, suspicious=no
|
||||||
|
[0.5, 12.0, 1.0, 0.02, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_api_client() -> [f64; NUM_FEATURES] {
|
||||||
|
// Backend API client: ~2 req/s, hits a few endpoints, 1 host, ~5% errors (retries),
|
||||||
|
// ~50ms latency, mix of GET/POST, steady rate, some path repetition, small bodies, 1 UA
|
||||||
|
// cookies=yes (session), referer=no, accept-lang=no, suspicious=no
|
||||||
|
[2.0, 5.0, 1.0, 0.05, 50.0, 0.69, 2.5, 0.4, 512.0, 1.0, 1.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_webhook_burst() -> [f64; NUM_FEATURES] {
|
||||||
|
// CI/CD or webhook burst: ~10 req/s for a short period, 1-2 paths, 1 host,
|
||||||
|
// 0% errors, fast responses, all POST, bursty, high path repetition, medium bodies, 1 UA
|
||||||
|
// cookies=no (machine), referer=no, accept-lang=no, suspicious=no
|
||||||
|
[10.0, 2.0, 1.0, 0.0, 25.0, 0.0, 12.0, 0.8, 2048.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_health_check() -> [f64; NUM_FEATURES] {
|
||||||
|
// Health check probe: ~0.2 req/s, 1 path, 1 host, 0% errors, ~5ms latency,
|
||||||
|
// all GET, very regular, 100% same path, no body, 1 UA
|
||||||
|
// cookies=no (probe), referer=no, accept-lang=no, suspicious=no
|
||||||
|
[0.2, 1.0, 1.0, 0.0, 5.0, 0.0, 0.2, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_mobile_app() -> [f64; NUM_FEATURES] {
|
||||||
|
// Mobile app: ~1 req/s, several API endpoints, 1 host, ~3% errors,
|
||||||
|
// ~200ms latency (mobile network), GET + POST, moderate spacing, moderate repetition,
|
||||||
|
// small-medium bodies, 1 UA
|
||||||
|
// cookies=yes, referer=no, accept-lang=yes, suspicious=no
|
||||||
|
[1.0, 8.0, 1.0, 0.03, 200.0, 0.5, 1.2, 0.25, 256.0, 1.0, 1.0, 0.0, 1.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_search_crawler() -> [f64; NUM_FEATURES] {
|
||||||
|
// Googlebot-style crawler: ~0.3 req/s, many unique paths, 1 host, ~10% 404s,
|
||||||
|
// ~300ms latency, all GET, slow steady rate, diverse paths, no body, 1 UA
|
||||||
|
// cookies=no (crawler), referer=no, accept-lang=no, suspicious=no
|
||||||
|
[0.3, 20.0, 1.0, 0.1, 300.0, 0.0, 0.35, 0.08, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_graphql_spa() -> [f64; NUM_FEATURES] {
|
||||||
|
// SPA hitting a GraphQL endpoint: ~3 req/s, 1 path (/graphql), 1 host, ~1% errors,
|
||||||
|
// ~80ms latency, all POST, steady, 100% same path, medium bodies, 1 UA
|
||||||
|
// cookies=yes, referer=yes (SPA nav), accept-lang=yes, suspicious=no
|
||||||
|
[3.0, 1.0, 1.0, 0.01, 80.0, 0.0, 3.5, 1.0, 1024.0, 1.0, 1.0, 1.0, 1.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_websocket_upgrade() -> [f64; NUM_FEATURES] {
|
||||||
|
// Initial HTTP requests before WS upgrade: ~0.1 req/s, 2 paths, 1 host, 0% errors,
|
||||||
|
// ~10ms latency, GET, slow, some repetition, no body, 1 UA
|
||||||
|
// cookies=yes, referer=yes, accept-lang=yes, suspicious=no
|
||||||
|
[0.1, 2.0, 1.0, 0.0, 10.0, 0.0, 0.1, 0.5, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_file_upload() -> [f64; NUM_FEATURES] {
|
||||||
|
// File upload session: ~0.5 req/s, 3 paths (upload, status, confirm), 1 host,
|
||||||
|
// 0% errors, ~500ms latency (large bodies), POST + GET, steady, moderate repetition,
|
||||||
|
// large bodies, 1 UA
|
||||||
|
// cookies=yes, referer=yes, accept-lang=yes, suspicious=no
|
||||||
|
[0.5, 3.0, 1.0, 0.0, 500.0, 0.69, 0.6, 0.5, 1_000_000.0, 1.0, 1.0, 1.0, 1.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normal_multi_tenant_api() -> [f64; NUM_FEATURES] {
|
||||||
|
// API client hitting multiple hosts (multi-tenant): ~1.5 req/s, 4 paths, 3 hosts,
|
||||||
|
// ~2% errors, ~100ms latency, GET + POST, steady, low repetition, small bodies, 1 UA
|
||||||
|
// cookies=yes, referer=no, accept-lang=no, suspicious=no
|
||||||
|
[1.5, 4.0, 3.0, 0.02, 100.0, 0.69, 1.8, 0.3, 128.0, 1.0, 1.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Realistic attack traffic profiles
|
||||||
|
fn attack_path_scan() -> [f64; NUM_FEATURES] {
|
||||||
|
// WordPress/PHP scanner: ~20 req/s, many unique paths, 1 host, 100% 404s,
|
||||||
|
// ~2ms latency (all errors), all GET, very bursty, all unique paths, no body, 1 UA
|
||||||
|
// cookies=no, referer=no, accept-lang=no, suspicious=0.8 (most paths are probes)
|
||||||
|
[20.0, 50.0, 1.0, 1.0, 2.0, 0.0, 25.0, 0.02, 0.0, 1.0, 0.0, 0.0, 0.0, 0.8]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attack_credential_stuffing() -> [f64; NUM_FEATURES] {
|
||||||
|
// Login brute-force: ~30 req/s, 1 path (/login), 1 host, 95% 401/403,
|
||||||
|
// ~10ms latency, all POST, very bursty, 100% same path, small bodies, 1 UA
|
||||||
|
// cookies=no, referer=no, accept-lang=no, suspicious=0.0 (/login is not in suspicious list)
|
||||||
|
[30.0, 1.0, 1.0, 0.95, 10.0, 0.0, 35.0, 1.0, 64.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attack_slowloris() -> [f64; NUM_FEATURES] {
|
||||||
|
// Slowloris-style: ~0.5 req/s (slow), 1 path, 1 host, 0% errors (connections held),
|
||||||
|
// ~30000ms latency (!), all GET, slow, 100% same path, huge content-length, 1 UA
|
||||||
|
// cookies=no, referer=no, accept-lang=no, suspicious=0.0
|
||||||
|
[0.5, 1.0, 1.0, 0.0, 30000.0, 0.0, 0.5, 1.0, 10_000_000.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attack_ua_rotation() -> [f64; NUM_FEATURES] {
|
||||||
|
// Bot rotating user-agents: ~15 req/s, 2 paths, 1 host, 80% errors,
|
||||||
|
// ~5ms latency, GET + POST, bursty, high repetition, no body, 50 distinct UAs
|
||||||
|
// cookies=no, referer=no, accept-lang=no, suspicious=0.3
|
||||||
|
[15.0, 2.0, 1.0, 0.8, 5.0, 0.69, 18.0, 0.7, 0.0, 50.0, 0.0, 0.0, 0.0, 0.3]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attack_host_scan() -> [f64; NUM_FEATURES] {
|
||||||
|
// Virtual host enumeration: ~25 req/s, 1 path (/), many hosts, 100% errors,
|
||||||
|
// ~1ms latency, all GET, very bursty, 100% same path, no body, 1 UA
|
||||||
|
// cookies=no, referer=no, accept-lang=no, suspicious=0.0
|
||||||
|
[25.0, 1.0, 40.0, 1.0, 1.0, 0.0, 30.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attack_api_fuzzing() -> [f64; NUM_FEATURES] {
|
||||||
|
// API fuzzer: ~50 req/s, many paths, 1 host, 90% errors (bad inputs),
|
||||||
|
// ~3ms latency, mixed methods, extremely bursty, low repetition, varied bodies, 1 UA
|
||||||
|
// cookies=no, referer=no, accept-lang=no, suspicious=0.5
|
||||||
|
[50.0, 100.0, 1.0, 0.9, 3.0, 1.5, 55.0, 0.01, 4096.0, 1.0, 0.0, 0.0, 0.0, 0.5]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn all_normal_profiles() -> Vec<[f64; NUM_FEATURES]> {
|
||||||
|
vec![
|
||||||
|
normal_browser_browsing(),
|
||||||
|
normal_api_client(),
|
||||||
|
normal_webhook_burst(),
|
||||||
|
normal_health_check(),
|
||||||
|
normal_mobile_app(),
|
||||||
|
normal_search_crawler(),
|
||||||
|
normal_graphql_spa(),
|
||||||
|
normal_websocket_upgrade(),
|
||||||
|
normal_file_upload(),
|
||||||
|
normal_multi_tenant_api(),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn all_attack_profiles() -> Vec<[f64; NUM_FEATURES]> {
|
||||||
|
vec![
|
||||||
|
attack_path_scan(),
|
||||||
|
attack_credential_stuffing(),
|
||||||
|
attack_slowloris(),
|
||||||
|
attack_ua_rotation(),
|
||||||
|
attack_host_scan(),
|
||||||
|
attack_api_fuzzing(),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a model from realistic profiles, with each profile replicated `copies`
|
||||||
|
/// times (with slight jitter) to give the KNN enough neighbors.
|
||||||
|
fn make_realistic_model(k: usize, threshold: f64) -> TrainedModel {
|
||||||
|
let mut normal = Vec::new();
|
||||||
|
let mut attack = Vec::new();
|
||||||
|
|
||||||
|
// Replicate each profile with small perturbations
|
||||||
|
for base in all_normal_profiles() {
|
||||||
|
for i in 0..20 {
|
||||||
|
let mut v = base;
|
||||||
|
for d in 0..NUM_FEATURES {
|
||||||
|
// ±5% jitter
|
||||||
|
let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05);
|
||||||
|
v[d] *= jitter;
|
||||||
|
}
|
||||||
|
normal.push(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for base in all_attack_profiles() {
|
||||||
|
for i in 0..20 {
|
||||||
|
let mut v = base;
|
||||||
|
for d in 0..NUM_FEATURES {
|
||||||
|
let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05);
|
||||||
|
v[d] *= jitter;
|
||||||
|
}
|
||||||
|
attack.push(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
make_model(&normal, &attack, k, threshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===========================================================================
// Model classification tests — normal profiles must NEVER be blocked
// ===========================================================================

#[test]
fn normal_browser_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_browser_browsing());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_api_client_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_api_client());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_webhook_burst_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_webhook_burst());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_health_check_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_health_check());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_mobile_app_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_mobile_app());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_search_crawler_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_search_crawler());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_graphql_spa_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_graphql_spa());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_websocket_upgrade_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_websocket_upgrade());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_file_upload_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_file_upload());
    assert_eq!(verdict, DDoSAction::Allow);
}

#[test]
fn normal_multi_tenant_api_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&normal_multi_tenant_api());
    assert_eq!(verdict, DDoSAction::Allow);
}
// ===========================================================================
// Model classification tests — attack profiles must be blocked
// ===========================================================================

#[test]
fn attack_path_scan_is_blocked() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&attack_path_scan());
    assert_eq!(verdict, DDoSAction::Block);
}

#[test]
fn attack_credential_stuffing_is_blocked() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&attack_credential_stuffing());
    assert_eq!(verdict, DDoSAction::Block);
}

#[test]
fn attack_slowloris_is_blocked() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&attack_slowloris());
    assert_eq!(verdict, DDoSAction::Block);
}

#[test]
fn attack_ua_rotation_is_blocked() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&attack_ua_rotation());
    assert_eq!(verdict, DDoSAction::Block);
}

#[test]
fn attack_host_scan_is_blocked() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&attack_host_scan());
    assert_eq!(verdict, DDoSAction::Block);
}

#[test]
fn attack_api_fuzzing_is_blocked() {
    let model = make_realistic_model(5, 0.6);
    let verdict = model.classify(&attack_api_fuzzing());
    assert_eq!(verdict, DDoSAction::Block);
}
// ===========================================================================
// Edge cases: normal traffic that LOOKS suspicious but isn't
// ===========================================================================

#[test]
fn high_rate_legitimate_cdn_prefetch_is_allowed() {
    // CDN prefetch traffic: the rate is high, but errors are near zero,
    // paths are diverse, and latency is in the normal band.
    let features: [f64; NUM_FEATURES] =
        [8.0, 15.0, 1.0, 0.0, 100.0, 0.0, 9.0, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}

#[test]
fn single_path_api_polling_is_allowed() {
    // Long-poll / SSE endpoint: one path with 100% repetition, but the rate
    // is low and there are no errors.
    let features: [f64; NUM_FEATURES] =
        [0.3, 1.0, 1.0, 0.0, 1000.0, 0.0, 0.3, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0];
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}

#[test]
fn moderate_error_rate_during_deploy_is_allowed() {
    // A rolling deploy can push the error rate to ~20% for a short window;
    // that alone must not trigger a block.
    let features: [f64; NUM_FEATURES] =
        [1.0, 5.0, 1.0, 0.2, 200.0, 0.5, 1.2, 0.3, 128.0, 1.0, 1.0, 0.3, 1.0, 0.0];
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}

#[test]
fn burst_of_form_submissions_is_allowed() {
    // Marketing event: users hammer a single form — high rate, one path,
    // all POST, but zero errors.
    let features: [f64; NUM_FEATURES] =
        [5.0, 1.0, 1.0, 0.0, 80.0, 0.0, 6.0, 1.0, 512.0, 1.0, 1.0, 1.0, 1.0, 0.0];
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}

#[test]
fn legitimate_load_test_with_varied_paths_is_allowed() {
    // Internal load test: very high rate, but the paths are diverse, errors
    // are rare, and latency looks like real backend work.
    let features: [f64; NUM_FEATURES] =
        [8.0, 30.0, 1.0, 0.02, 120.0, 0.69, 10.0, 0.05, 256.0, 1.0, 0.0, 0.0, 0.0, 0.0];
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}
// ===========================================================================
// Threshold and k sensitivity
// ===========================================================================

#[test]
fn higher_threshold_is_more_permissive() {
    // At threshold=0.9 even borderline traffic should pass.
    let permissive = make_realistic_model(5, 0.9);
    // A feature vector that sits between the attack and normal clusters.
    let borderline: [f64; NUM_FEATURES] =
        [12.0, 8.0, 1.0, 0.5, 20.0, 0.5, 14.0, 0.5, 100.0, 2.0, 0.0, 0.0, 0.0, 0.1];
    assert_eq!(permissive.classify(&borderline), DDoSAction::Allow);
}

#[test]
fn larger_k_smooths_classification() {
    // A larger k makes individual noisy neighbors matter less; ordinary
    // browsing must be allowed regardless of the k chosen here.
    let small_k = make_realistic_model(3, 0.6);
    let large_k = make_realistic_model(9, 0.6);
    let browsing = normal_browser_browsing();
    assert_eq!(small_k.classify(&browsing), DDoSAction::Allow);
    assert_eq!(large_k.classify(&browsing), DDoSAction::Allow);
}
// ===========================================================================
// Normalization tests
// ===========================================================================

#[test]
fn normalization_clamps_out_of_range() {
    let params = NormParams {
        mins: [0.0; NUM_FEATURES],
        maxs: [1.0; NUM_FEATURES],
    };
    // Anything above max must clamp to 1.0 …
    let normed = params.normalize(&[2.0; NUM_FEATURES]);
    assert!(normed.iter().all(|&v| v == 1.0));
    // … and anything below min must clamp to 0.0.
    let normed = params.normalize(&[-1.0; NUM_FEATURES]);
    assert!(normed.iter().all(|&v| v == 0.0));
}

#[test]
fn normalization_handles_zero_range() {
    // A degenerate feature (same value in all training data) has range 0;
    // normalization must not divide by zero and should map to 0.0.
    let params = NormParams {
        mins: [5.0; NUM_FEATURES],
        maxs: [5.0; NUM_FEATURES],
    };
    let normed = params.normalize(&[5.0; NUM_FEATURES]);
    assert!(normed.iter().all(|&v| v == 0.0));
}

#[test]
fn normalization_preserves_midpoint() {
    let params = NormParams {
        mins: [0.0; NUM_FEATURES],
        maxs: [100.0; NUM_FEATURES],
    };
    // The midpoint of [0, 100] must land at 0.5 in every dimension.
    let normed = params.normalize(&[50.0; NUM_FEATURES]);
    assert!(normed.iter().all(|&v| (v - 0.5).abs() < 1e-10));
}

#[test]
fn norm_params_from_data_finds_extremes() {
    let data = vec![
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0],
        [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.5, 0.5, 0.5],
    ];
    let params = NormParams::from_data(&data);
    // mins/maxs must be consistent in every dimension …
    for (lo, hi) in params.mins.iter().zip(params.maxs.iter()) {
        assert!(lo <= hi);
    }
    // … and dimension 0 should pick up the actual extremes.
    assert_eq!(params.mins[0], 1.0);
    assert_eq!(params.maxs[0], 10.0);
}
// ===========================================================================
// Serialization round-trip
// ===========================================================================

#[test]
fn model_serialization_roundtrip() {
    // Use the realistic model (200+ points) so fnntw has enough data for the kD-tree
    let model = make_realistic_model(3, 0.5);

    // Rebuild the same training set with the same deterministic jitter
    // formula used by `make_realistic_model`. A single local fn replaces the
    // two previously-duplicated replication loops.
    fn jittered(
        bases: Vec<[f64; NUM_FEATURES]>,
        seed_a: f64,
        seed_b: f64,
    ) -> Vec<[f64; NUM_FEATURES]> {
        let mut out = Vec::with_capacity(bases.len() * 20);
        for base in bases {
            for i in 0..20 {
                let mut v = base;
                for d in 0..NUM_FEATURES {
                    let jitter = 1.0 + ((i as f64 * seed_a + d as f64 * seed_b) % 0.1 - 0.05);
                    v[d] *= jitter;
                }
                out.push(v);
            }
        }
        out
    }

    let mut all_points = jittered(all_normal_profiles(), 0.37, 0.13);
    let n_normal = all_points.len();
    all_points.extend(jittered(all_attack_profiles(), 0.41, 0.17));

    // First n_normal points are Normal, the rest Attack.
    let mut all_labels = Vec::with_capacity(all_points.len());
    for idx in 0..all_points.len() {
        all_labels.push(if idx < n_normal {
            TrafficLabel::Normal
        } else {
            TrafficLabel::Attack
        });
    }

    let norm_params = NormParams::from_data(&all_points);
    let serialized = SerializedModel {
        points: all_points.iter().map(|v| norm_params.normalize(v)).collect(),
        labels: all_labels,
        norm_params,
        k: 3,
        threshold: 0.5,
    };

    // Round-trip through bincode and verify structural equality.
    let encoded = bincode::serialize(&serialized).unwrap();
    let decoded: SerializedModel = bincode::deserialize(&encoded).unwrap();

    assert_eq!(decoded.points.len(), serialized.points.len());
    assert_eq!(decoded.labels.len(), serialized.labels.len());
    assert_eq!(decoded.k, 3);
    assert!((decoded.threshold - 0.5).abs() < 1e-10);

    // A model rebuilt from the decoded data must classify identically.
    let rebuilt = TrainedModel::from_serialized(decoded);
    assert_eq!(
        rebuilt.classify(&normal_browser_browsing()),
        model.classify(&normal_browser_browsing())
    );
}
// ===========================================================================
// Detector integration tests (full check() pipeline)
// ===========================================================================
// NOTE(review): `detector.check(...)` takes positional args after the UA —
// an integer then three booleans. Their meaning is not visible from this
// file; presumably they map to fields like content_length / has_cookies /
// referer / accept_language — TODO confirm against the detector's signature.

#[test]
fn detector_allows_below_min_events() {
    let model = make_realistic_model(5, 0.6);
    // min_events = 10: the detector must not classify before it has seen
    // at least that many requests from an IP.
    let detector = make_detector(model, 10);

    let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
    // Send 9 requests — below min_events threshold of 10
    // Even attack-looking requests must be allowed pre-threshold.
    for _ in 0..9 {
        let action = detector.check(ip, "GET", "/wp-admin", "evil.com", "bot", 0, false, false, false);
        assert_eq!(action, DDoSAction::Allow, "should allow below min_events");
    }
}

#[test]
fn detector_ipv4_and_ipv6_tracked_separately() {
    let model = make_realistic_model(5, 0.6);
    let detector = make_detector(model, 3);

    let v4 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
    let v6 = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));

    // Send events to v4 only
    for _ in 0..5 {
        detector.check(v4, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true);
    }

    // v6 should still have 0 events (below min_events)
    let action = detector.check(v6, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true);
    assert_eq!(action, DDoSAction::Allow);
}

#[test]
fn detector_normal_browsing_pattern_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let detector = make_detector(model, 5);

    // Simulate a single user clicking through a site: diverse paths,
    // mostly GET with occasional POST, consistent browser UA.
    let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50));
    let paths = ["/", "/about", "/products", "/products/1", "/contact",
                 "/blog", "/blog/post-1", "/docs", "/pricing", "/login",
                 "/dashboard", "/settings", "/api/me"];
    let ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)";

    for (i, path) in paths.iter().enumerate() {
        let method = if i % 5 == 0 { "POST" } else { "GET" };
        let action = detector.check(ip, method, path, "mysite.com", ua, 0, true, true, true);
        // After min_events, every check should still allow normal browsing
        assert_eq!(
            action,
            DDoSAction::Allow,
            "normal browsing blocked on request #{i} to {path}"
        );
    }
}

#[test]
fn detector_handles_concurrent_ips() {
    let model = make_realistic_model(5, 0.6);
    let detector = make_detector(model, 5);

    // Simulate 50 distinct IPs each making a few normal requests
    // (exercises the per-IP sliding-window bookkeeping, not the model).
    for i in 0..50u8 {
        let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, i));
        let paths = ["/", "/about", "/products", "/contact", "/blog",
                     "/docs", "/api/status"];
        for path in &paths {
            let action = detector.check(ip, "GET", path, "example.com", "Chrome", 0, true, false, true);
            assert_eq!(action, DDoSAction::Allow,
                "IP 10.0.0.{i} blocked on {path}");
        }
    }
}

#[test]
fn detector_ipv6_normal_traffic_is_allowed() {
    let model = make_realistic_model(5, 0.6);
    let detector = make_detector(model, 5);

    // Same browsing pattern as the IPv4 test, but over an IPv6 source.
    let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42));
    let paths = ["/", "/about", "/products", "/blog", "/contact",
                 "/login", "/dashboard"];
    for path in &paths {
        let action = detector.check(ip, "GET", path, "example.com",
            "Mozilla/5.0", 0, true, false, true);
        assert_eq!(action, DDoSAction::Allow,
            "IPv6 normal traffic blocked on {path}");
    }
}
// ===========================================================================
// Model robustness: slight variations of normal traffic
// ===========================================================================

#[test]
fn slightly_elevated_rate_still_allowed() {
    // Roughly double the normal browsing rate — busy, but not an attack.
    let features: [f64; NUM_FEATURES] =
        [1.0, 12.0, 1.0, 0.02, 150.0, 0.2, 1.2, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0];
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}

#[test]
fn slightly_elevated_errors_still_allowed() {
    // ~15% errors (e.g. 404s from stale links) is routine for real sites.
    let features: [f64; NUM_FEATURES] =
        [0.5, 10.0, 1.0, 0.15, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.3, 1.0, 0.0];
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}

#[test]
fn zero_traffic_features_allowed() {
    // Degenerate all-zero vector: shouldn't occur in practice, but must
    // neither panic nor block.
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&[0.0; NUM_FEATURES]), DDoSAction::Allow);
}

#[test]
fn empty_model_always_allows() {
    // With zero training points there is nothing to compare against, so the
    // classifier must fail open.
    let model = TrainedModel::from_serialized(SerializedModel {
        points: vec![],
        labels: vec![],
        norm_params: NormParams {
            mins: [0.0; NUM_FEATURES],
            maxs: [1.0; NUM_FEATURES],
        },
        k: 5,
        threshold: 0.6,
    });
    assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow);
    assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
}

#[test]
fn all_normal_model_allows_everything() {
    // A model trained only on normal data (no attack points) should never
    // block. Use enough points (200) so fnntw can build the kD-tree.
    let mut normal = Vec::new();
    for profile in all_normal_profiles() {
        for copy in 0..20 {
            let mut point = profile;
            for dim in 0..NUM_FEATURES {
                let jitter = 1.0 + ((copy as f64 * 0.37 + dim as f64 * 0.13) % 0.1 - 0.05);
                point[dim] *= jitter;
            }
            normal.push(point);
        }
    }
    let model = make_model(&normal, &[], 5, 0.6);
    assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
    assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow);
    // Even attack-like traffic passes: there are no attack examples to match.
    assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow);
}
// ===========================================================================
// Feature extraction tests
// ===========================================================================

#[test]
fn method_entropy_zero_for_single_method() {
    // A health-check profile is all GET, so the method distribution is
    // degenerate and its entropy (feature index 5) must be exactly zero.
    let features = normal_health_check();
    assert_eq!(features[5], 0.0); // method_entropy
    let model = make_realistic_model(5, 0.6);
    assert_eq!(model.classify(&features), DDoSAction::Allow);
}

#[test]
fn method_entropy_positive_for_mixed_methods() {
    // The API-client profile mixes GET and POST, so entropy must be > 0.
    let features = normal_api_client();
    assert!(features[5] > 0.0, "method_entropy should be positive for mixed methods");
}

#[test]
fn path_repetition_is_one_for_single_path() {
    // A GraphQL SPA hits a single /graphql endpoint: path repetition
    // (feature index 7) is exactly 1.0.
    let features = normal_graphql_spa();
    assert_eq!(features[7], 1.0);
}

#[test]
fn path_repetition_is_low_for_diverse_paths() {
    // A crawler visits many unique paths, so repetition stays low.
    let features = normal_search_crawler();
    assert!(features[7] < 0.2);
}
// ===========================================================================
// Load the real trained model and validate against known profiles
// ===========================================================================

#[test]
fn real_model_file_roundtrip() {
    let model_path = std::path::Path::new("ddos_model.bin");
    // The trained model is not committed everywhere (e.g. CI), so this test
    // degrades to a no-op when the file is absent.
    if !model_path.exists() {
        eprintln!("skipping real_model_file_roundtrip: ddos_model.bin not found");
        return;
    }
    let model = TrainedModel::load(model_path, Some(3), Some(0.5)).unwrap();
    assert!(model.point_count() > 0, "model should have training points");
    // Smoke test: classification of known profiles must not panic.
    let _ = model.classify(&normal_browser_browsing());
    let _ = model.classify(&attack_path_scan());
}
Reference in New Issue
Block a user