feat(ddos): add KNN-based DDoS detection module

14-feature vector extraction, KNN classifier using fnntw, per-IP
sliding window aggregation, and heuristic auto-labeling for training.
Includes replay subcommand for offline evaluation and integration tests.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:19 +00:00
parent e16299068f
commit 007865fbe7
8 changed files with 2189 additions and 0 deletions

83
src/ddos/audit_log.rs Normal file
View File

@@ -0,0 +1,83 @@
use serde::Deserialize;
#[derive(Deserialize)]
pub struct AuditLog {
pub timestamp: String,
pub fields: AuditFields,
}
#[derive(Deserialize)]
pub struct AuditFields {
pub method: String,
pub host: String,
pub path: String,
pub client_ip: String,
#[serde(deserialize_with = "flexible_u16")]
pub status: u16,
#[serde(deserialize_with = "flexible_u64")]
pub duration_ms: u64,
#[serde(default)]
pub backend: String,
#[serde(default)]
pub content_length: u64,
#[serde(default = "default_ua")]
pub user_agent: String,
#[serde(default)]
pub query: String,
#[serde(default)]
pub has_cookies: Option<bool>,
#[serde(default)]
pub referer: Option<String>,
#[serde(default)]
pub accept_language: Option<String>,
/// Optional ground-truth label from external datasets (e.g. CSIC 2010).
/// Values: "attack", "normal". When present, trainers should use this
/// instead of heuristic labeling.
#[serde(default)]
pub label: Option<String>,
}
fn default_ua() -> String {
"-".to_string()
}
pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<u64, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
Num(u64),
Str(String),
}
match StringOrNum::deserialize(deserializer)? {
StringOrNum::Num(n) => Ok(n),
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
}
}
pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<u16, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
Num(u16),
Str(String),
}
match StringOrNum::deserialize(deserializer)? {
StringOrNum::Num(n) => Ok(n),
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
}
}
/// Strip the port suffix from a socket address string.
pub fn strip_port(addr: &str) -> &str {
if addr.starts_with('[') {
addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr)
} else if let Some(pos) = addr.rfind(':') {
&addr[..pos]
} else {
addr
}
}

100
src/ddos/detector.rs Normal file
View File

@@ -0,0 +1,100 @@
use crate::config::DDoSConfig;
use crate::ddos::features::{method_to_u8, IpState, RequestEvent};
use crate::ddos::model::{DDoSAction, TrainedModel};
use rustc_hash::FxHashMap;
use std::hash::{Hash, Hasher};
use std::net::IpAddr;
use std::sync::RwLock;
use std::time::Instant;
const NUM_SHARDS: usize = 256;
pub struct DDoSDetector {
model: TrainedModel,
shards: Vec<RwLock<FxHashMap<IpAddr, IpState>>>,
window_secs: u64,
window_capacity: usize,
min_events: usize,
}
fn shard_index(ip: &IpAddr) -> usize {
let mut h = rustc_hash::FxHasher::default();
ip.hash(&mut h);
h.finish() as usize % NUM_SHARDS
}
impl DDoSDetector {
pub fn new(model: TrainedModel, config: &DDoSConfig) -> Self {
let shards = (0..NUM_SHARDS)
.map(|_| RwLock::new(FxHashMap::default()))
.collect();
Self {
model,
shards,
window_secs: config.window_secs,
window_capacity: config.window_capacity,
min_events: config.min_events,
}
}
/// Record an incoming request and classify the IP.
/// Called from request_filter (before upstream).
pub fn check(
&self,
ip: IpAddr,
method: &str,
path: &str,
host: &str,
user_agent: &str,
content_length: u64,
has_cookies: bool,
has_referer: bool,
has_accept_language: bool,
) -> DDoSAction {
let event = RequestEvent {
timestamp: Instant::now(),
method: method_to_u8(method),
path_hash: fx_hash(path),
host_hash: fx_hash(host),
user_agent_hash: fx_hash(user_agent),
status: 0,
duration_ms: 0,
content_length: content_length.min(u32::MAX as u64) as u32,
has_cookies,
has_referer,
has_accept_language,
suspicious_path: crate::ddos::features::is_suspicious_path(path),
};
let idx = shard_index(&ip);
let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner());
let state = shard
.entry(ip)
.or_insert_with(|| IpState::new(self.window_capacity));
state.push(event);
if state.len() < self.min_events {
return DDoSAction::Allow;
}
let features = state.extract_features(self.window_secs);
self.model.classify(&features)
}
/// Feed response data back into the IP's event history.
/// Called from logging() after the response is sent.
pub fn record_response(&self, _ip: IpAddr, _status: u16, _duration_ms: u32) {
// Status/duration from check() are 0-initialized; the next request
// will have fresh data. This is intentionally a no-op for now.
}
pub fn point_count(&self) -> usize {
self.model.point_count()
}
}
fn fx_hash(s: &str) -> u64 {
let mut h = rustc_hash::FxHasher::default();
s.hash(&mut h);
h.finish()
}

467
src/ddos/features.rs Normal file
View File

@@ -0,0 +1,467 @@
use rustc_hash::FxHashSet;
use serde::{Deserialize, Serialize};
use std::time::Instant;
pub const NUM_FEATURES: usize = 14;
pub type FeatureVector = [f64; NUM_FEATURES];
#[derive(Clone)]
pub struct RequestEvent {
pub timestamp: Instant,
/// GET=0, POST=1, PUT=2, DELETE=3, HEAD=4, PATCH=5, OPTIONS=6, other=7
pub method: u8,
pub path_hash: u64,
pub host_hash: u64,
pub user_agent_hash: u64,
pub status: u16,
pub duration_ms: u32,
pub content_length: u32,
pub has_cookies: bool,
pub has_referer: bool,
pub has_accept_language: bool,
pub suspicious_path: bool,
}
/// Known-bad path fragments that scanners/bots probe for.
const SUSPICIOUS_FRAGMENTS: &[&str] = &[
".env", ".git/", ".git\\", ".bak", ".sql", ".tar", ".zip",
"wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
"phpinfo", "phpmyadmin", "php-info", ".php",
"cgi-bin", "shell", "eval-stdin",
"/vendor/", "/telescope/", "/actuator/",
"/.htaccess", "/.htpasswd",
"/debug/", "/config.", "/admin/",
"yarn.lock", "yarn-debug", "package.json", "composer.json",
];
pub fn is_suspicious_path(path: &str) -> bool {
let lower = path.to_ascii_lowercase();
SUSPICIOUS_FRAGMENTS.iter().any(|f| lower.contains(f))
}
pub struct IpState {
events: Vec<RequestEvent>,
cursor: usize,
count: usize,
capacity: usize,
}
impl IpState {
pub fn new(capacity: usize) -> Self {
Self {
events: Vec::with_capacity(capacity),
cursor: 0,
count: 0,
capacity,
}
}
pub fn push(&mut self, event: RequestEvent) {
if self.events.len() < self.capacity {
self.events.push(event);
} else {
self.events[self.cursor] = event;
}
self.cursor = (self.cursor + 1) % self.capacity;
self.count += 1;
}
pub fn len(&self) -> usize {
self.events.len()
}
/// Prune events older than `window` from the logical view.
/// Returns a slice of active events (not necessarily contiguous in ring buffer,
/// so we collect into a Vec).
fn active_events(&self, window_secs: u64) -> Vec<&RequestEvent> {
let now = Instant::now();
let cutoff = std::time::Duration::from_secs(window_secs);
self.events
.iter()
.filter(|e| now.duration_since(e.timestamp) <= cutoff)
.collect()
}
pub fn extract_features(&self, window_secs: u64) -> FeatureVector {
let events = self.active_events(window_secs);
let n = events.len() as f64;
if n < 1.0 {
return [0.0; NUM_FEATURES];
}
// 0: request_rate (requests / window_secs)
let request_rate = n / window_secs as f64;
// 1: unique_paths
let unique_paths = {
let mut set = FxHashSet::default();
for e in &events {
set.insert(e.path_hash);
}
set.len() as f64
};
// 2: unique_hosts
let unique_hosts = {
let mut set = FxHashSet::default();
for e in &events {
set.insert(e.host_hash);
}
set.len() as f64
};
// 3: error_rate (fraction of 4xx/5xx)
let errors = events.iter().filter(|e| e.status >= 400).count() as f64;
let error_rate = errors / n;
// 4: avg_duration_ms
let avg_duration_ms =
events.iter().map(|e| e.duration_ms as f64).sum::<f64>() / n;
// 5: method_entropy (Shannon entropy of method distribution)
let method_entropy = {
let mut counts = [0u32; 8];
for e in &events {
counts[e.method as usize % 8] += 1;
}
let mut entropy = 0.0f64;
for &c in &counts {
if c > 0 {
let p = c as f64 / n;
entropy -= p * p.ln();
}
}
entropy
};
// 6: burst_score (inverse mean inter-arrival time)
let burst_score = if events.len() >= 2 {
let mut timestamps: Vec<Instant> =
events.iter().map(|e| e.timestamp).collect();
timestamps.sort();
let total_span = timestamps
.last()
.unwrap()
.duration_since(*timestamps.first().unwrap())
.as_secs_f64();
if total_span > 0.0 {
(events.len() - 1) as f64 / total_span
} else {
n // all events at same instant = maximum burstiness
}
} else {
0.0
};
// 7: path_repetition (ratio of most-repeated path to total)
let path_repetition = {
let mut counts = rustc_hash::FxHashMap::default();
for e in &events {
*counts.entry(e.path_hash).or_insert(0u32) += 1;
}
let max_count = counts.values().copied().max().unwrap_or(0) as f64;
max_count / n
};
// 8: avg_content_length
let avg_content_length =
events.iter().map(|e| e.content_length as f64).sum::<f64>() / n;
// 9: unique_user_agents
let unique_user_agents = {
let mut set = FxHashSet::default();
for e in &events {
set.insert(e.user_agent_hash);
}
set.len() as f64
};
// 10: cookie_ratio (fraction of requests that have cookies)
let cookie_ratio =
events.iter().filter(|e| e.has_cookies).count() as f64 / n;
// 11: referer_ratio (fraction of requests with a referer)
let referer_ratio =
events.iter().filter(|e| e.has_referer).count() as f64 / n;
// 12: accept_language_ratio (fraction with accept-language)
let accept_language_ratio =
events.iter().filter(|e| e.has_accept_language).count() as f64 / n;
// 13: suspicious_path_ratio (fraction hitting known-bad paths)
let suspicious_path_ratio =
events.iter().filter(|e| e.suspicious_path).count() as f64 / n;
[
request_rate,
unique_paths,
unique_hosts,
error_rate,
avg_duration_ms,
method_entropy,
burst_score,
path_repetition,
avg_content_length,
unique_user_agents,
cookie_ratio,
referer_ratio,
accept_language_ratio,
suspicious_path_ratio,
]
}
}
pub fn method_to_u8(method: &str) -> u8 {
match method {
"GET" => 0,
"POST" => 1,
"PUT" => 2,
"DELETE" => 3,
"HEAD" => 4,
"PATCH" => 5,
"OPTIONS" => 6,
_ => 7,
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormParams {
pub mins: [f64; NUM_FEATURES],
pub maxs: [f64; NUM_FEATURES],
}
impl NormParams {
pub fn from_data(vectors: &[FeatureVector]) -> Self {
let mut mins = [f64::MAX; NUM_FEATURES];
let mut maxs = [f64::MIN; NUM_FEATURES];
for v in vectors {
for i in 0..NUM_FEATURES {
mins[i] = mins[i].min(v[i]);
maxs[i] = maxs[i].max(v[i]);
}
}
Self { mins, maxs }
}
pub fn normalize(&self, v: &FeatureVector) -> FeatureVector {
let mut out = [0.0; NUM_FEATURES];
for i in 0..NUM_FEATURES {
let range = self.maxs[i] - self.mins[i];
out[i] = if range > 0.0 {
((v[i] - self.mins[i]) / range).clamp(0.0, 1.0)
} else {
0.0
};
}
out
}
}
/// Feature extraction from parsed log entries (used by training pipeline).
/// Unlike IpState which uses Instant, this uses f64 timestamps from log parsing.
pub struct LogIpState {
pub timestamps: Vec<f64>,
pub methods: Vec<u8>,
pub path_hashes: Vec<u64>,
pub host_hashes: Vec<u64>,
pub user_agent_hashes: Vec<u64>,
pub statuses: Vec<u16>,
pub durations: Vec<u32>,
pub content_lengths: Vec<u32>,
pub has_cookies: Vec<bool>,
pub has_referer: Vec<bool>,
pub has_accept_language: Vec<bool>,
pub suspicious_paths: Vec<bool>,
}
impl LogIpState {
pub fn new() -> Self {
Self {
timestamps: Vec::new(),
methods: Vec::new(),
path_hashes: Vec::new(),
host_hashes: Vec::new(),
user_agent_hashes: Vec::new(),
statuses: Vec::new(),
durations: Vec::new(),
content_lengths: Vec::new(),
has_cookies: Vec::new(),
has_referer: Vec::new(),
has_accept_language: Vec::new(),
suspicious_paths: Vec::new(),
}
}
pub fn extract_features_for_window(
&self,
start: usize,
end: usize,
window_secs: f64,
) -> FeatureVector {
let n = (end - start) as f64;
if n < 1.0 {
return [0.0; NUM_FEATURES];
}
let request_rate = n / window_secs;
let unique_paths = {
let mut set = FxHashSet::default();
for i in start..end {
set.insert(self.path_hashes[i]);
}
set.len() as f64
};
let unique_hosts = {
let mut set = FxHashSet::default();
for i in start..end {
set.insert(self.host_hashes[i]);
}
set.len() as f64
};
let errors = self.statuses[start..end]
.iter()
.filter(|&&s| s >= 400)
.count() as f64;
let error_rate = errors / n;
let avg_duration_ms =
self.durations[start..end].iter().map(|&d| d as f64).sum::<f64>() / n;
let method_entropy = {
let mut counts = [0u32; 8];
for i in start..end {
counts[self.methods[i] as usize % 8] += 1;
}
let mut entropy = 0.0f64;
for &c in &counts {
if c > 0 {
let p = c as f64 / n;
entropy -= p * p.ln();
}
}
entropy
};
let burst_score = if (end - start) >= 2 {
let total_span =
self.timestamps[end - 1] - self.timestamps[start];
if total_span > 0.0 {
(end - start - 1) as f64 / total_span
} else {
n
}
} else {
0.0
};
let path_repetition = {
let mut counts = rustc_hash::FxHashMap::default();
for i in start..end {
*counts.entry(self.path_hashes[i]).or_insert(0u32) += 1;
}
let max_count = counts.values().copied().max().unwrap_or(0) as f64;
max_count / n
};
let avg_content_length = self.content_lengths[start..end]
.iter()
.map(|&c| c as f64)
.sum::<f64>()
/ n;
let unique_user_agents = {
let mut set = FxHashSet::default();
for i in start..end {
set.insert(self.user_agent_hashes[i]);
}
set.len() as f64
};
let cookie_ratio =
self.has_cookies[start..end].iter().filter(|&&v| v).count() as f64 / n;
let referer_ratio =
self.has_referer[start..end].iter().filter(|&&v| v).count() as f64 / n;
let accept_language_ratio =
self.has_accept_language[start..end].iter().filter(|&&v| v).count() as f64 / n;
let suspicious_path_ratio =
self.suspicious_paths[start..end].iter().filter(|&&v| v).count() as f64 / n;
[
request_rate,
unique_paths,
unique_hosts,
error_rate,
avg_duration_ms,
method_entropy,
burst_score,
path_repetition,
avg_content_length,
unique_user_agents,
cookie_ratio,
referer_ratio,
accept_language_ratio,
suspicious_path_ratio,
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use rustc_hash::FxHasher;
use std::hash::{Hash, Hasher};
fn fx(s: &str) -> u64 {
let mut h = FxHasher::default();
s.hash(&mut h);
h.finish()
}
#[test]
fn test_single_event_features() {
let mut state = IpState::new(100);
state.push(RequestEvent {
timestamp: Instant::now(),
method: 0,
path_hash: fx("/"),
host_hash: fx("example.com"),
user_agent_hash: fx("curl/7.0"),
status: 200,
duration_ms: 10,
content_length: 0,
has_cookies: true,
has_referer: false,
has_accept_language: true,
suspicious_path: false,
});
let features = state.extract_features(60);
// request_rate = 1/60
assert!(features[0] > 0.0);
// error_rate = 0
assert_eq!(features[3], 0.0);
// path_repetition = 1.0 (only one path)
assert_eq!(features[7], 1.0);
// cookie_ratio = 1.0 (single event with cookies)
assert_eq!(features[10], 1.0);
// referer_ratio = 0.0
assert_eq!(features[11], 0.0);
// accept_language_ratio = 1.0
assert_eq!(features[12], 1.0);
// suspicious_path_ratio = 0.0
assert_eq!(features[13], 0.0);
}
#[test]
fn test_norm_params() {
let data = vec![[0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 20.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]];
let params = NormParams::from_data(&data);
let normalized = params.normalize(&[0.5, 15.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]);
for &v in &normalized {
assert!((v - 0.5).abs() < 1e-10);
}
}
}

6
src/ddos/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub mod audit_log;
pub mod detector;
pub mod features;
pub mod model;
pub mod replay;
pub mod train;

168
src/ddos/model.rs Normal file
View File

@@ -0,0 +1,168 @@
use crate::ddos::features::{FeatureVector, NormParams, NUM_FEATURES};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrafficLabel {
Normal,
Attack,
}
#[derive(Serialize, Deserialize)]
pub struct SerializedModel {
pub points: Vec<FeatureVector>,
pub labels: Vec<TrafficLabel>,
pub norm_params: NormParams,
pub k: usize,
pub threshold: f64,
}
pub struct TrainedModel {
/// Stored points (normalized). The kD-tree borrows these.
points: Vec<[f64; NUM_FEATURES]>,
labels: Vec<TrafficLabel>,
norm_params: NormParams,
k: usize,
threshold: f64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DDoSAction {
Allow,
Block,
}
impl TrainedModel {
pub fn load(path: &Path, k_override: Option<usize>, threshold_override: Option<f64>) -> Result<Self> {
let data = std::fs::read(path)
.with_context(|| format!("reading model from {}", path.display()))?;
let model: SerializedModel =
bincode::deserialize(&data).context("deserializing model")?;
Ok(Self {
points: model.points,
labels: model.labels,
norm_params: model.norm_params,
k: k_override.unwrap_or(model.k),
threshold: threshold_override.unwrap_or(model.threshold),
})
}
pub fn from_serialized(model: SerializedModel) -> Self {
Self {
points: model.points,
labels: model.labels,
norm_params: model.norm_params,
k: model.k,
threshold: model.threshold,
}
}
pub fn classify(&self, features: &FeatureVector) -> DDoSAction {
let normalized = self.norm_params.normalize(features);
if self.points.is_empty() {
return DDoSAction::Allow;
}
// Build tree on-the-fly for query. In production with many queries,
// we'd cache this, but the tree build is fast for <100K points.
// fnntw::Tree borrows data, so we build it here.
let tree = match fnntw::Tree::<'_, f64, NUM_FEATURES>::new(&self.points, 32) {
Ok(t) => t,
Err(_) => return DDoSAction::Allow,
};
let k = self.k.min(self.points.len());
let result = tree.query_nearest_k(&normalized, k);
match result {
Ok((_distances, indices)) => {
let attack_count = indices
.iter()
.filter(|&&idx| self.labels[idx as usize] == TrafficLabel::Attack)
.count();
let attack_frac = attack_count as f64 / k as f64;
if attack_frac >= self.threshold {
DDoSAction::Block
} else {
DDoSAction::Allow
}
}
Err(_) => DDoSAction::Allow,
}
}
pub fn norm_params(&self) -> &NormParams {
&self.norm_params
}
pub fn point_count(&self) -> usize {
self.points.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_classify_empty_model() {
let model = TrainedModel {
points: vec![],
labels: vec![],
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
};
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
}
fn make_test_points(n: usize) -> Vec<FeatureVector> {
(0..n)
.map(|i| {
let mut v = [0.0; NUM_FEATURES];
for d in 0..NUM_FEATURES {
v[d] = ((i * (d + 1)) as f64 / n as f64) % 1.0;
}
v
})
.collect()
}
#[test]
fn test_classify_all_attack() {
let points = make_test_points(100);
let labels = vec![TrafficLabel::Attack; 100];
let model = TrainedModel {
points,
labels,
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
};
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Block);
}
#[test]
fn test_classify_all_normal() {
let points = make_test_points(100);
let labels = vec![TrafficLabel::Normal; 100];
let model = TrainedModel {
points,
labels,
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
};
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
}
}

291
src/ddos/replay.rs Normal file
View File

@@ -0,0 +1,291 @@
use crate::config::{DDoSConfig, RateLimitConfig};
use crate::ddos::audit_log::{self, AuditLog};
use crate::ddos::detector::DDoSDetector;
use crate::ddos::model::{DDoSAction, TrainedModel};
use crate::rate_limit::key::RateLimitKey;
use crate::rate_limit::limiter::{RateLimitResult, RateLimiter};
use anyhow::{Context, Result};
use rustc_hash::FxHashMap;
use std::io::BufRead;
use std::net::IpAddr;
use std::sync::Arc;
pub struct ReplayArgs {
pub input: String,
pub model_path: String,
pub config_path: Option<String>,
pub k: usize,
pub threshold: f64,
pub window_secs: u64,
pub min_events: usize,
pub rate_limit: bool,
}
struct ReplayStats {
total: u64,
skipped: u64,
ddos_blocked: u64,
rate_limited: u64,
allowed: u64,
ddos_blocked_ips: FxHashMap<String, u64>,
rate_limited_ips: FxHashMap<String, u64>,
}
pub fn run(args: ReplayArgs) -> Result<()> {
eprintln!("Loading model from {}...", args.model_path);
let model = TrainedModel::load(
std::path::Path::new(&args.model_path),
Some(args.k),
Some(args.threshold),
)
.with_context(|| format!("loading model from {}", args.model_path))?;
eprintln!(" {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold);
let ddos_cfg = DDoSConfig {
model_path: args.model_path.clone(),
k: args.k,
threshold: args.threshold,
window_secs: args.window_secs,
window_capacity: 1000,
min_events: args.min_events,
enabled: true,
};
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
// Optionally set up rate limiter
let rate_limiter = if args.rate_limit {
let rl_cfg = if let Some(cfg_path) = &args.config_path {
let cfg = crate::config::Config::load(cfg_path)?;
cfg.rate_limit.unwrap_or_else(default_rate_limit_config)
} else {
default_rate_limit_config()
};
eprintln!(
" Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s",
rl_cfg.authenticated.burst,
rl_cfg.authenticated.rate,
rl_cfg.unauthenticated.burst,
rl_cfg.unauthenticated.rate,
);
Some(RateLimiter::new(&rl_cfg))
} else {
None
};
eprintln!("Replaying {}...\n", args.input);
let file = std::fs::File::open(&args.input)
.with_context(|| format!("opening {}", args.input))?;
let reader = std::io::BufReader::new(file);
let mut stats = ReplayStats {
total: 0,
skipped: 0,
ddos_blocked: 0,
rate_limited: 0,
allowed: 0,
ddos_blocked_ips: FxHashMap::default(),
rate_limited_ips: FxHashMap::default(),
};
for line in reader.lines() {
let line = line?;
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => {
stats.skipped += 1;
continue;
}
};
if entry.fields.method.is_empty() {
stats.skipped += 1;
continue;
}
stats.total += 1;
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
let ip: IpAddr = match ip_str.parse() {
Ok(ip) => ip,
Err(_) => {
stats.skipped += 1;
continue;
}
};
// DDoS check
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
let ddos_action = detector.check(
ip,
&entry.fields.method,
&entry.fields.path,
&entry.fields.host,
&entry.fields.user_agent,
entry.fields.content_length,
has_cookies,
has_referer,
has_accept_language,
);
if ddos_action == DDoSAction::Block {
stats.ddos_blocked += 1;
*stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
continue;
}
// Rate limit check
if let Some(limiter) = &rate_limiter {
// Audit logs don't have auth headers, so all traffic is keyed by IP
let rl_key = RateLimitKey::Ip(ip);
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
stats.rate_limited += 1;
*stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
continue;
}
}
stats.allowed += 1;
}
// Report
let total = stats.total;
eprintln!("═══ Replay Results ═══════════════════════════════════════");
eprintln!(" Total requests: {total}");
eprintln!(" Skipped (parse): {}", stats.skipped);
eprintln!(" Allowed: {} ({:.1}%)", stats.allowed, pct(stats.allowed, total));
eprintln!(" DDoS blocked: {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total));
if rate_limiter.is_some() {
eprintln!(" Rate limited: {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total));
}
if !stats.ddos_blocked_ips.is_empty() {
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs blocked", ip, count);
}
}
if !stats.rate_limited_ips.is_empty() {
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs limited", ip, count);
}
}
// Check for false positives: IPs that were blocked but had 2xx statuses in the original logs
eprintln!("\n── False positive check ──────────────────────────────────");
check_false_positives(&args.input, &stats)?;
eprintln!("══════════════════════════════════════════════════════════");
Ok(())
}
/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally
/// (i.e. they were legitimate traffic that the model would incorrectly block).
fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
let blocked_ips: rustc_hash::FxHashSet<&str> = stats
.ddos_blocked_ips
.keys()
.chain(stats.rate_limited_ips.keys())
.map(|s| s.as_str())
.collect();
if blocked_ips.is_empty() {
eprintln!(" No blocked IPs — nothing to check.");
return Ok(());
}
// Collect original status codes for blocked IPs
let file = std::fs::File::open(input)?;
let reader = std::io::BufReader::new(file);
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
for line in reader.lines() {
let line = line?;
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => continue,
};
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
if blocked_ips.contains(ip_str.as_str()) {
ip_statuses
.entry(ip_str)
.or_default()
.push(entry.fields.status);
}
}
let mut suspects = Vec::new();
for (ip, statuses) in &ip_statuses {
let total = statuses.len();
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
let ok_pct = (ok_count as f64 / total as f64) * 100.0;
// If >60% of original responses were 2xx/3xx, this might be a false positive
if ok_pct > 60.0 {
let blocked = stats
.ddos_blocked_ips
.get(ip)
.copied()
.unwrap_or(0)
+ stats
.rate_limited_ips
.get(ip)
.copied()
.unwrap_or(0);
suspects.push((ip.clone(), total, ok_pct, blocked));
}
}
if suspects.is_empty() {
eprintln!(" No likely false positives found.");
} else {
suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
eprintln!("{} IPs were blocked but had mostly successful responses:", suspects.len());
for (ip, total, ok_pct, blocked) in suspects.iter().take(15) {
eprintln!(
" {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked",
ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked,
);
}
}
Ok(())
}
fn default_rate_limit_config() -> RateLimitConfig {
RateLimitConfig {
enabled: true,
bypass_cidrs: vec![
"10.0.0.0/8".into(),
"172.16.0.0/12".into(),
"192.168.0.0/16".into(),
"100.64.0.0/10".into(),
"fd00::/8".into(),
],
eviction_interval_secs: 300,
stale_after_secs: 600,
authenticated: crate::config::BucketConfig {
burst: 200,
rate: 50.0,
},
unauthenticated: crate::config::BucketConfig {
burst: 60,
rate: 15.0,
},
}
}
fn pct(n: u64, total: u64) -> f64 {
if total == 0 {
0.0
} else {
(n as f64 / total as f64) * 100.0
}
}

298
src/ddos/train.rs Normal file
View File

@@ -0,0 +1,298 @@
use crate::ddos::audit_log::AuditLog;
use crate::ddos::audit_log;
use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES};
use crate::ddos::model::{SerializedModel, TrafficLabel};
use anyhow::{bail, Context, Result};
use rustc_hash::{FxHashMap, FxHashSet};
use serde::Deserialize;
use std::hash::{Hash, Hasher};
use std::io::BufRead;
#[derive(Deserialize)]
pub struct HeuristicThresholds {
/// Requests/second above which an IP is labeled attack
#[serde(default = "default_rate_threshold")]
pub request_rate: f64,
/// Path repetition ratio above which an IP is labeled attack
#[serde(default = "default_repetition_threshold")]
pub path_repetition: f64,
/// Error rate above which an IP is labeled attack
#[serde(default = "default_error_threshold")]
pub error_rate: f64,
/// Suspicious path ratio above which an IP is labeled attack
#[serde(default = "default_suspicious_path_threshold")]
pub suspicious_path_ratio: f64,
/// Cookie ratio below which (combined with high unique paths) labels attack
#[serde(default = "default_no_cookies_threshold")]
pub no_cookies_threshold: f64,
/// Unique path count above which no-cookie traffic is labeled attack
#[serde(default = "default_no_cookies_path_count")]
pub no_cookies_path_count: f64,
/// Minimum events to consider an IP for labeling
#[serde(default = "default_min_events")]
pub min_events: usize,
}
fn default_rate_threshold() -> f64 { 10.0 }
fn default_repetition_threshold() -> f64 { 0.9 }
fn default_error_threshold() -> f64 { 0.7 }
fn default_suspicious_path_threshold() -> f64 { 0.3 }
fn default_no_cookies_threshold() -> f64 { 0.05 }
fn default_no_cookies_path_count() -> f64 { 20.0 }
fn default_min_events() -> usize { 10 }
pub struct TrainArgs {
pub input: String,
pub output: String,
pub attack_ips: Option<String>,
pub normal_ips: Option<String>,
pub heuristics: Option<String>,
pub k: usize,
pub threshold: f64,
pub window_secs: u64,
pub min_events: usize,
}
fn fx_hash(s: &str) -> u64 {
let mut h = rustc_hash::FxHasher::default();
s.hash(&mut h);
h.finish()
}
fn parse_timestamp(ts: &str) -> f64 {
// Parse ISO 8601 timestamp to seconds since epoch (approximate).
// We only need relative ordering within a log file.
// Format: "2026-03-07T17:41:40.705326Z"
let parts: Vec<&str> = ts.split('T').collect();
if parts.len() != 2 {
return 0.0;
}
let date_parts: Vec<&str> = parts[0].split('-').collect();
let time_str = parts[1].trim_end_matches('Z');
let time_parts: Vec<&str> = time_str.split(':').collect();
if date_parts.len() != 3 || time_parts.len() != 3 {
return 0.0;
}
let day: f64 = date_parts[2].parse().unwrap_or(0.0);
let hour: f64 = time_parts[0].parse().unwrap_or(0.0);
let min: f64 = time_parts[1].parse().unwrap_or(0.0);
let sec: f64 = time_parts[2].parse().unwrap_or(0.0);
// Relative seconds (day * 86400 + time)
day * 86400.0 + hour * 3600.0 + min * 60.0 + sec
}
pub fn run(args: TrainArgs) -> Result<()> {
eprintln!("Parsing logs from {}...", args.input);
// Parse logs into per-IP state
let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
let file = std::fs::File::open(&args.input)
.with_context(|| format!("opening {}", args.input))?;
let reader = std::io::BufReader::new(file);
let mut total_lines = 0u64;
let mut parse_errors = 0u64;
for line in reader.lines() {
let line = line?;
total_lines += 1;
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => {
parse_errors += 1;
continue;
}
};
// Skip non-audit entries
if entry.fields.method.is_empty() {
continue;
}
let ip = audit_log::strip_port(&entry.fields.client_ip).to_string();
let ts = parse_timestamp(&entry.timestamp);
let state = ip_states.entry(ip).or_insert_with(LogIpState::new);
state.timestamps.push(ts);
state.methods.push(method_to_u8(&entry.fields.method));
state.path_hashes.push(fx_hash(&entry.fields.path));
state.host_hashes.push(fx_hash(&entry.fields.host));
state
.user_agent_hashes
.push(fx_hash(&entry.fields.user_agent));
state.statuses.push(entry.fields.status);
state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
state
.content_lengths
.push(entry.fields.content_length.min(u32::MAX as u64) as u32);
state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
state.has_referer.push(
entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
);
state.has_accept_language.push(
entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
);
state.suspicious_paths.push(
crate::ddos::features::is_suspicious_path(&entry.fields.path),
);
}
eprintln!(
"Parsed {} lines ({} errors), {} unique IPs",
total_lines,
parse_errors,
ip_states.len()
);
// Extract feature vectors per IP (using sliding windows)
let window_secs = args.window_secs as f64;
let mut ip_features: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default();
for (ip, state) in &ip_states {
let n = state.timestamps.len();
if n < args.min_events {
continue;
}
// Extract one feature vector per window
let mut features = Vec::new();
let mut start = 0;
for end in 1..n {
let span = state.timestamps[end] - state.timestamps[start];
if span >= window_secs || end == n - 1 {
let fv = state.extract_features_for_window(start, end + 1, window_secs);
features.push(fv);
start = end + 1;
}
}
if !features.is_empty() {
ip_features.insert(ip.clone(), features);
}
}
// Label IPs
let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();
if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
// IP list mode
let attack_ips: FxHashSet<String> = std::fs::read_to_string(attack_file)
.context("reading attack IPs file")?
.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect();
let normal_ips: FxHashSet<String> = std::fs::read_to_string(normal_file)
.context("reading normal IPs file")?
.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect();
for ip in ip_features.keys() {
if attack_ips.contains(ip) {
ip_labels.insert(ip.clone(), TrafficLabel::Attack);
} else if normal_ips.contains(ip) {
ip_labels.insert(ip.clone(), TrafficLabel::Normal);
}
}
} else if let Some(heuristics_file) = &args.heuristics {
// Heuristic auto-labeling
let heuristics_str = std::fs::read_to_string(heuristics_file)
.context("reading heuristics file")?;
let thresholds: HeuristicThresholds =
toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;
for (ip, features) in &ip_features {
// Use the aggregate (last/max) feature vector for labeling
let avg = average_features(features);
let is_attack = avg[0] > thresholds.request_rate // request_rate
|| avg[7] > thresholds.path_repetition // path_repetition
|| avg[3] > thresholds.error_rate // error_rate
|| avg[13] > thresholds.suspicious_path_ratio // suspicious_path_ratio
|| (avg[10] < thresholds.no_cookies_threshold // no cookies + high unique paths
&& avg[1] > thresholds.no_cookies_path_count);
ip_labels.insert(
ip.clone(),
if is_attack {
TrafficLabel::Attack
} else {
TrafficLabel::Normal
},
);
}
} else {
bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
}
// Build training dataset
let mut all_points: Vec<FeatureVector> = Vec::new();
let mut all_labels: Vec<TrafficLabel> = Vec::new();
for (ip, features) in &ip_features {
if let Some(&label) = ip_labels.get(ip) {
for fv in features {
all_points.push(*fv);
all_labels.push(label);
}
}
}
if all_points.is_empty() {
bail!("No labeled data points found. Check your IP lists or heuristic thresholds.");
}
let attack_count = all_labels
.iter()
.filter(|&&l| l == TrafficLabel::Attack)
.count();
let normal_count = all_labels.len() - attack_count;
eprintln!(
"Training with {} points ({} attack, {} normal)",
all_points.len(),
attack_count,
normal_count
);
// Normalize
let norm_params = NormParams::from_data(&all_points);
let normalized: Vec<FeatureVector> = all_points
.iter()
.map(|v| norm_params.normalize(v))
.collect();
// Serialize
let model = SerializedModel {
points: normalized,
labels: all_labels,
norm_params,
k: args.k,
threshold: args.threshold,
};
let encoded = bincode::serialize(&model).context("serializing model")?;
std::fs::write(&args.output, &encoded)
.with_context(|| format!("writing model to {}", args.output))?;
eprintln!(
"Model saved to {} ({} bytes, {} points)",
args.output,
encoded.len(),
model.points.len()
);
Ok(())
}
fn average_features(features: &[FeatureVector]) -> FeatureVector {
let n = features.len() as f64;
let mut avg = [0.0; NUM_FEATURES];
for fv in features {
for i in 0..NUM_FEATURES {
avg[i] += fv[i];
}
}
for v in &mut avg {
*v /= n;
}
avg
}

776
tests/ddos_test.rs Normal file
View File

@@ -0,0 +1,776 @@
//! Extensive DDoS detection tests.
//!
//! These tests build realistic traffic profiles — normal browsing, API usage,
//! webhook bursts, etc. — and verify the model never blocks legitimate traffic.
//! Attack scenarios are also tested to confirm blocking works.
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use sunbeam_proxy::config::DDoSConfig;
use sunbeam_proxy::ddos::detector::DDoSDetector;
use sunbeam_proxy::ddos::features::{NormParams, NUM_FEATURES};
use sunbeam_proxy::ddos::model::{DDoSAction, SerializedModel, TrafficLabel, TrainedModel};
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Build a model from explicit normal/attack feature vectors.
fn make_model(
normal: &[[f64; NUM_FEATURES]],
attack: &[[f64; NUM_FEATURES]],
k: usize,
threshold: f64,
) -> TrainedModel {
let mut points = Vec::new();
let mut labels = Vec::new();
for v in normal {
points.push(*v);
labels.push(TrafficLabel::Normal);
}
for v in attack {
points.push(*v);
labels.push(TrafficLabel::Attack);
}
let norm_params = NormParams::from_data(&points);
let normalized: Vec<[f64; NUM_FEATURES]> =
points.iter().map(|v| norm_params.normalize(v)).collect();
TrainedModel::from_serialized(SerializedModel {
points: normalized,
labels,
norm_params,
k,
threshold,
})
}
fn default_ddos_config() -> DDoSConfig {
DDoSConfig {
model_path: String::new(),
k: 5,
threshold: 0.6,
window_secs: 60,
window_capacity: 1000,
min_events: 10,
enabled: true,
}
}
fn make_detector(model: TrainedModel, min_events: usize) -> DDoSDetector {
let mut cfg = default_ddos_config();
cfg.min_events = min_events;
DDoSDetector::new(model, &cfg)
}
/// Feature vector indices (matching features.rs order):
/// 0: request_rate (requests / window_secs)
/// 1: unique_paths (count of distinct paths)
/// 2: unique_hosts (count of distinct hosts)
/// 3: error_rate (fraction 4xx/5xx)
/// 4: avg_duration_ms (mean response time)
/// 5: method_entropy (Shannon entropy of methods)
/// 6: burst_score (inverse mean inter-arrival)
/// 7: path_repetition (most-repeated path / total)
/// 8: avg_content_length (mean body size)
/// 9: unique_user_agents (count of distinct UAs)
/// 10: cookie_ratio (fraction with cookies)
/// 11: referer_ratio (fraction with referer)
/// 12: accept_language_ratio (fraction with accept-language)
/// 13: suspicious_path_ratio (fraction hitting known-bad paths)
/// 9: unique_user_agents (count of distinct UAs)
// Realistic normal traffic profiles
fn normal_browser_browsing() -> [f64; NUM_FEATURES] {
// A human browsing a site: ~0.5 req/s, many paths, 1 host, low errors,
// ~150ms avg latency, mostly GET, moderate spacing, diverse paths, no body, 1 UA
// cookies=yes, referer=sometimes, accept-lang=yes, suspicious=no
[0.5, 12.0, 1.0, 0.02, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0]
}
fn normal_api_client() -> [f64; NUM_FEATURES] {
// Backend API client: ~2 req/s, hits a few endpoints, 1 host, ~5% errors (retries),
// ~50ms latency, mix of GET/POST, steady rate, some path repetition, small bodies, 1 UA
// cookies=yes (session), referer=no, accept-lang=no, suspicious=no
[2.0, 5.0, 1.0, 0.05, 50.0, 0.69, 2.5, 0.4, 512.0, 1.0, 1.0, 0.0, 0.0, 0.0]
}
fn normal_webhook_burst() -> [f64; NUM_FEATURES] {
// CI/CD or webhook burst: ~10 req/s for a short period, 1-2 paths, 1 host,
// 0% errors, fast responses, all POST, bursty, high path repetition, medium bodies, 1 UA
// cookies=no (machine), referer=no, accept-lang=no, suspicious=no
[10.0, 2.0, 1.0, 0.0, 25.0, 0.0, 12.0, 0.8, 2048.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn normal_health_check() -> [f64; NUM_FEATURES] {
// Health check probe: ~0.2 req/s, 1 path, 1 host, 0% errors, ~5ms latency,
// all GET, very regular, 100% same path, no body, 1 UA
// cookies=no (probe), referer=no, accept-lang=no, suspicious=no
[0.2, 1.0, 1.0, 0.0, 5.0, 0.0, 0.2, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn normal_mobile_app() -> [f64; NUM_FEATURES] {
// Mobile app: ~1 req/s, several API endpoints, 1 host, ~3% errors,
// ~200ms latency (mobile network), GET + POST, moderate spacing, moderate repetition,
// small-medium bodies, 1 UA
// cookies=yes, referer=no, accept-lang=yes, suspicious=no
[1.0, 8.0, 1.0, 0.03, 200.0, 0.5, 1.2, 0.25, 256.0, 1.0, 1.0, 0.0, 1.0, 0.0]
}
fn normal_search_crawler() -> [f64; NUM_FEATURES] {
// Googlebot-style crawler: ~0.3 req/s, many unique paths, 1 host, ~10% 404s,
// ~300ms latency, all GET, slow steady rate, diverse paths, no body, 1 UA
// cookies=no (crawler), referer=no, accept-lang=no, suspicious=no
[0.3, 20.0, 1.0, 0.1, 300.0, 0.0, 0.35, 0.08, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn normal_graphql_spa() -> [f64; NUM_FEATURES] {
// SPA hitting a GraphQL endpoint: ~3 req/s, 1 path (/graphql), 1 host, ~1% errors,
// ~80ms latency, all POST, steady, 100% same path, medium bodies, 1 UA
// cookies=yes, referer=yes (SPA nav), accept-lang=yes, suspicious=no
[3.0, 1.0, 1.0, 0.01, 80.0, 0.0, 3.5, 1.0, 1024.0, 1.0, 1.0, 1.0, 1.0, 0.0]
}
fn normal_websocket_upgrade() -> [f64; NUM_FEATURES] {
// Initial HTTP requests before WS upgrade: ~0.1 req/s, 2 paths, 1 host, 0% errors,
// ~10ms latency, GET, slow, some repetition, no body, 1 UA
// cookies=yes, referer=yes, accept-lang=yes, suspicious=no
[0.1, 2.0, 1.0, 0.0, 10.0, 0.0, 0.1, 0.5, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0]
}
fn normal_file_upload() -> [f64; NUM_FEATURES] {
// File upload session: ~0.5 req/s, 3 paths (upload, status, confirm), 1 host,
// 0% errors, ~500ms latency (large bodies), POST + GET, steady, moderate repetition,
// large bodies, 1 UA
// cookies=yes, referer=yes, accept-lang=yes, suspicious=no
[0.5, 3.0, 1.0, 0.0, 500.0, 0.69, 0.6, 0.5, 1_000_000.0, 1.0, 1.0, 1.0, 1.0, 0.0]
}
fn normal_multi_tenant_api() -> [f64; NUM_FEATURES] {
// API client hitting multiple hosts (multi-tenant): ~1.5 req/s, 4 paths, 3 hosts,
// ~2% errors, ~100ms latency, GET + POST, steady, low repetition, small bodies, 1 UA
// cookies=yes, referer=no, accept-lang=no, suspicious=no
[1.5, 4.0, 3.0, 0.02, 100.0, 0.69, 1.8, 0.3, 128.0, 1.0, 1.0, 0.0, 0.0, 0.0]
}
// Realistic attack traffic profiles
fn attack_path_scan() -> [f64; NUM_FEATURES] {
// WordPress/PHP scanner: ~20 req/s, many unique paths, 1 host, 100% 404s,
// ~2ms latency (all errors), all GET, very bursty, all unique paths, no body, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.8 (most paths are probes)
[20.0, 50.0, 1.0, 1.0, 2.0, 0.0, 25.0, 0.02, 0.0, 1.0, 0.0, 0.0, 0.0, 0.8]
}
fn attack_credential_stuffing() -> [f64; NUM_FEATURES] {
// Login brute-force: ~30 req/s, 1 path (/login), 1 host, 95% 401/403,
// ~10ms latency, all POST, very bursty, 100% same path, small bodies, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.0 (/login is not in suspicious list)
[30.0, 1.0, 1.0, 0.95, 10.0, 0.0, 35.0, 1.0, 64.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn attack_slowloris() -> [f64; NUM_FEATURES] {
// Slowloris-style: ~0.5 req/s (slow), 1 path, 1 host, 0% errors (connections held),
// ~30000ms latency (!), all GET, slow, 100% same path, huge content-length, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.0
[0.5, 1.0, 1.0, 0.0, 30000.0, 0.0, 0.5, 1.0, 10_000_000.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn attack_ua_rotation() -> [f64; NUM_FEATURES] {
// Bot rotating user-agents: ~15 req/s, 2 paths, 1 host, 80% errors,
// ~5ms latency, GET + POST, bursty, high repetition, no body, 50 distinct UAs
// cookies=no, referer=no, accept-lang=no, suspicious=0.3
[15.0, 2.0, 1.0, 0.8, 5.0, 0.69, 18.0, 0.7, 0.0, 50.0, 0.0, 0.0, 0.0, 0.3]
}
fn attack_host_scan() -> [f64; NUM_FEATURES] {
// Virtual host enumeration: ~25 req/s, 1 path (/), many hosts, 100% errors,
// ~1ms latency, all GET, very bursty, 100% same path, no body, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.0
[25.0, 1.0, 40.0, 1.0, 1.0, 0.0, 30.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn attack_api_fuzzing() -> [f64; NUM_FEATURES] {
// API fuzzer: ~50 req/s, many paths, 1 host, 90% errors (bad inputs),
// ~3ms latency, mixed methods, extremely bursty, low repetition, varied bodies, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.5
[50.0, 100.0, 1.0, 0.9, 3.0, 1.5, 55.0, 0.01, 4096.0, 1.0, 0.0, 0.0, 0.0, 0.5]
}
fn all_normal_profiles() -> Vec<[f64; NUM_FEATURES]> {
vec![
normal_browser_browsing(),
normal_api_client(),
normal_webhook_burst(),
normal_health_check(),
normal_mobile_app(),
normal_search_crawler(),
normal_graphql_spa(),
normal_websocket_upgrade(),
normal_file_upload(),
normal_multi_tenant_api(),
]
}
fn all_attack_profiles() -> Vec<[f64; NUM_FEATURES]> {
vec![
attack_path_scan(),
attack_credential_stuffing(),
attack_slowloris(),
attack_ua_rotation(),
attack_host_scan(),
attack_api_fuzzing(),
]
}
/// Build a model from realistic profiles, with each profile replicated `copies`
/// times (with slight jitter) to give the KNN enough neighbors.
fn make_realistic_model(k: usize, threshold: f64) -> TrainedModel {
let mut normal = Vec::new();
let mut attack = Vec::new();
// Replicate each profile with small perturbations
for base in all_normal_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
// ±5% jitter
let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05);
v[d] *= jitter;
}
normal.push(v);
}
}
for base in all_attack_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05);
v[d] *= jitter;
}
attack.push(v);
}
}
make_model(&normal, &attack, k, threshold)
}
// ===========================================================================
// Model classification tests — normal profiles must NEVER be blocked
// ===========================================================================
#[test]
fn normal_browser_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
}
#[test]
fn normal_api_client_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow);
}
#[test]
fn normal_webhook_burst_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_webhook_burst()), DDoSAction::Allow);
}
#[test]
fn normal_health_check_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_health_check()), DDoSAction::Allow);
}
#[test]
fn normal_mobile_app_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_mobile_app()), DDoSAction::Allow);
}
#[test]
fn normal_search_crawler_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_search_crawler()), DDoSAction::Allow);
}
#[test]
fn normal_graphql_spa_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_graphql_spa()), DDoSAction::Allow);
}
#[test]
fn normal_websocket_upgrade_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(
model.classify(&normal_websocket_upgrade()),
DDoSAction::Allow
);
}
#[test]
fn normal_file_upload_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_file_upload()), DDoSAction::Allow);
}
#[test]
fn normal_multi_tenant_api_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(
model.classify(&normal_multi_tenant_api()),
DDoSAction::Allow
);
}
// ===========================================================================
// Model classification tests — attack profiles must be blocked
// ===========================================================================
#[test]
fn attack_path_scan_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Block);
}
#[test]
fn attack_credential_stuffing_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(
model.classify(&attack_credential_stuffing()),
DDoSAction::Block
);
}
#[test]
fn attack_slowloris_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_slowloris()), DDoSAction::Block);
}
#[test]
fn attack_ua_rotation_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_ua_rotation()), DDoSAction::Block);
}
#[test]
fn attack_host_scan_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_host_scan()), DDoSAction::Block);
}
#[test]
fn attack_api_fuzzing_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_api_fuzzing()), DDoSAction::Block);
}
// ===========================================================================
// Edge cases: normal traffic that LOOKS suspicious but isn't
// ===========================================================================
#[test]
fn high_rate_legitimate_cdn_prefetch_is_allowed() {
// CDN prefetch: high rate but low errors, diverse paths, normal latency
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[8.0, 15.0, 1.0, 0.0, 100.0, 0.0, 9.0, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn single_path_api_polling_is_allowed() {
// Long-poll or SSE endpoint: single path, 100% repetition, but low rate, no errors
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[0.3, 1.0, 1.0, 0.0, 1000.0, 0.0, 0.3, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn moderate_error_rate_during_deploy_is_allowed() {
// During a rolling deploy, error rate spikes to ~20% temporarily
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[1.0, 5.0, 1.0, 0.2, 200.0, 0.5, 1.2, 0.3, 128.0, 1.0, 1.0, 0.3, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn burst_of_form_submissions_is_allowed() {
// Marketing event → users submit forms rapidly: high rate, single path, all POST, no errors
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[5.0, 1.0, 1.0, 0.0, 80.0, 0.0, 6.0, 1.0, 512.0, 1.0, 1.0, 1.0, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn legitimate_load_test_with_varied_paths_is_allowed() {
// Internal load test: high rate but diverse paths, low error, real latency
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[8.0, 30.0, 1.0, 0.02, 120.0, 0.69, 10.0, 0.05, 256.0, 1.0, 0.0, 0.0, 0.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
// ===========================================================================
// Threshold and k sensitivity
// ===========================================================================
#[test]
fn higher_threshold_is_more_permissive() {
// With threshold=0.9, even borderline traffic should be allowed
let model = make_realistic_model(5, 0.9);
// A profile that's borderline between attack and normal
let borderline: [f64; NUM_FEATURES] =
[12.0, 8.0, 1.0, 0.5, 20.0, 0.5, 14.0, 0.5, 100.0, 2.0, 0.0, 0.0, 0.0, 0.1];
assert_eq!(model.classify(&borderline), DDoSAction::Allow);
}
#[test]
fn larger_k_smooths_classification() {
// With larger k, noisy outliers matter less
let model_k3 = make_realistic_model(3, 0.6);
let model_k9 = make_realistic_model(9, 0.6);
// Normal traffic should be allowed by both
let profile = normal_browser_browsing();
assert_eq!(model_k3.classify(&profile), DDoSAction::Allow);
assert_eq!(model_k9.classify(&profile), DDoSAction::Allow);
}
// ===========================================================================
// Normalization tests
// ===========================================================================
#[test]
fn normalization_clamps_out_of_range() {
let params = NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
};
// Values above max should clamp to 1.0
let above = [2.0; NUM_FEATURES];
let normed = params.normalize(&above);
for &v in &normed {
assert_eq!(v, 1.0);
}
// Values below min should clamp to 0.0
let below = [-1.0; NUM_FEATURES];
let normed = params.normalize(&below);
for &v in &normed {
assert_eq!(v, 0.0);
}
}
#[test]
fn normalization_handles_zero_range() {
// When all training data has the same value for a feature, range = 0
let params = NormParams {
mins: [5.0; NUM_FEATURES],
maxs: [5.0; NUM_FEATURES],
};
let v = [5.0; NUM_FEATURES];
let normed = params.normalize(&v);
for &val in &normed {
assert_eq!(val, 0.0);
}
}
#[test]
fn normalization_preserves_midpoint() {
let params = NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [100.0; NUM_FEATURES],
};
let v = [50.0; NUM_FEATURES];
let normed = params.normalize(&v);
for &val in &normed {
assert!((val - 0.5).abs() < 1e-10);
}
}
#[test]
fn norm_params_from_data_finds_extremes() {
let data = vec![
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0],
[10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.5, 0.5, 0.5],
];
let params = NormParams::from_data(&data);
for i in 0..NUM_FEATURES {
assert!(params.mins[i] <= params.maxs[i]);
}
assert_eq!(params.mins[0], 1.0);
assert_eq!(params.maxs[0], 10.0);
}
// ===========================================================================
// Serialization round-trip
// ===========================================================================
#[test]
fn model_serialization_roundtrip() {
// Use the realistic model (200+ points) so fnntw has enough data for the kD-tree
let model = make_realistic_model(3, 0.5);
// Rebuild from the same training data
let mut all_points = Vec::new();
let mut all_labels = Vec::new();
for base in all_normal_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05);
v[d] *= jitter;
}
all_points.push(v);
all_labels.push(TrafficLabel::Normal);
}
}
for base in all_attack_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05);
v[d] *= jitter;
}
all_points.push(v);
all_labels.push(TrafficLabel::Attack);
}
}
let norm_params = NormParams::from_data(&all_points);
let serialized = SerializedModel {
points: all_points.iter().map(|v| norm_params.normalize(v)).collect(),
labels: all_labels,
norm_params,
k: 3,
threshold: 0.5,
};
let encoded = bincode::serialize(&serialized).unwrap();
let decoded: SerializedModel = bincode::deserialize(&encoded).unwrap();
assert_eq!(decoded.points.len(), serialized.points.len());
assert_eq!(decoded.labels.len(), serialized.labels.len());
assert_eq!(decoded.k, 3);
assert!((decoded.threshold - 0.5).abs() < 1e-10);
// Rebuilt model should classify the same
let rebuilt = TrainedModel::from_serialized(decoded);
assert_eq!(
rebuilt.classify(&normal_browser_browsing()),
model.classify(&normal_browser_browsing())
);
}
// ===========================================================================
// Detector integration tests (full check() pipeline)
// ===========================================================================
#[test]
fn detector_allows_below_min_events() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 10);
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
// Send 9 requests — below min_events threshold of 10
for _ in 0..9 {
let action = detector.check(ip, "GET", "/wp-admin", "evil.com", "bot", 0, false, false, false);
assert_eq!(action, DDoSAction::Allow, "should allow below min_events");
}
}
#[test]
fn detector_ipv4_and_ipv6_tracked_separately() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 3);
let v4 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
let v6 = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
// Send events to v4 only
for _ in 0..5 {
detector.check(v4, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true);
}
// v6 should still have 0 events (below min_events)
let action = detector.check(v6, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true);
assert_eq!(action, DDoSAction::Allow);
}
#[test]
fn detector_normal_browsing_pattern_is_allowed() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 5);
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50));
let paths = ["/", "/about", "/products", "/products/1", "/contact",
"/blog", "/blog/post-1", "/docs", "/pricing", "/login",
"/dashboard", "/settings", "/api/me"];
let ua = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)";
for (i, path) in paths.iter().enumerate() {
let method = if i % 5 == 0 { "POST" } else { "GET" };
let action = detector.check(ip, method, path, "mysite.com", ua, 0, true, true, true);
// After min_events, every check should still allow normal browsing
assert_eq!(
action,
DDoSAction::Allow,
"normal browsing blocked on request #{i} to {path}"
);
}
}
#[test]
fn detector_handles_concurrent_ips() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 5);
// Simulate 50 distinct IPs each making a few normal requests
for i in 0..50u8 {
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, i));
let paths = ["/", "/about", "/products", "/contact", "/blog",
"/docs", "/api/status"];
for path in &paths {
let action = detector.check(ip, "GET", path, "example.com", "Chrome", 0, true, false, true);
assert_eq!(action, DDoSAction::Allow,
"IP 10.0.0.{i} blocked on {path}");
}
}
}
#[test]
fn detector_ipv6_normal_traffic_is_allowed() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 5);
let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42));
let paths = ["/", "/about", "/products", "/blog", "/contact",
"/login", "/dashboard"];
for path in &paths {
let action = detector.check(ip, "GET", path, "example.com",
"Mozilla/5.0", 0, true, false, true);
assert_eq!(action, DDoSAction::Allow,
"IPv6 normal traffic blocked on {path}");
}
}
// ===========================================================================
// Model robustness: slight variations of normal traffic
// ===========================================================================
#[test]
fn slightly_elevated_rate_still_allowed() {
let model = make_realistic_model(5, 0.6);
// 2x normal browsing rate — busy but not attacking
let profile: [f64; NUM_FEATURES] =
[1.0, 12.0, 1.0, 0.02, 150.0, 0.2, 1.2, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn slightly_elevated_errors_still_allowed() {
// 15% errors (e.g. some 404s from broken links) — normal for real sites
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[0.5, 10.0, 1.0, 0.15, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.3, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn zero_traffic_features_allowed() {
// Edge case: all zeros (shouldn't happen in practice, but must not crash or block)
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&[0.0; NUM_FEATURES]), DDoSAction::Allow);
}
#[test]
fn empty_model_always_allows() {
let model = TrainedModel::from_serialized(SerializedModel {
points: vec![],
labels: vec![],
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
});
// Must allow everything — no training data to compare against
assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow);
assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
}
#[test]
fn all_normal_model_allows_everything() {
// A model trained only on normal data (no attack points) should never block.
// Use enough points (200) so fnntw can build the kD-tree.
let mut normal = Vec::new();
for base in all_normal_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05);
v[d] *= jitter;
}
normal.push(v);
}
}
let model = make_model(&normal, &[], 5, 0.6);
assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow);
// Even attack-like traffic is allowed since the model has no attack examples
assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow);
}
// ===========================================================================
// Feature extraction tests
// ===========================================================================
#[test]
fn method_entropy_zero_for_single_method() {
// All GET requests → method distribution is [1.0, 0, 0, ...] → entropy = 0
let model = make_realistic_model(5, 0.6);
let profile = normal_health_check(); // all GET
assert_eq!(profile[5], 0.0); // method_entropy
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn method_entropy_positive_for_mixed_methods() {
let profile = normal_api_client(); // mix of GET/POST
assert!(profile[5] > 0.0, "method_entropy should be positive for mixed methods");
}
#[test]
fn path_repetition_is_one_for_single_path() {
let profile = normal_graphql_spa(); // single /graphql endpoint
assert_eq!(profile[7], 1.0);
}
#[test]
fn path_repetition_is_low_for_diverse_paths() {
let profile = normal_search_crawler(); // many unique paths
assert!(profile[7] < 0.2);
}
// ===========================================================================
// Load the real trained model and validate against known profiles
// ===========================================================================
#[test]
fn real_model_file_roundtrip() {
let model_path = std::path::Path::new("ddos_model.bin");
if !model_path.exists() {
// Skip if no model file present (CI environments)
eprintln!("skipping real_model_file_roundtrip: ddos_model.bin not found");
return;
}
let model = TrainedModel::load(model_path, Some(3), Some(0.5)).unwrap();
assert!(model.point_count() > 0, "model should have training points");
// Smoke test: classifying shouldn't panic
let _ = model.classify(&normal_browser_browsing());
let _ = model.classify(&attack_path_scan());
}