feat(ddos): add KNN-based DDoS detection module
14-feature vector extraction, KNN classifier using fnntw, per-IP sliding window aggregation, and heuristic auto-labeling for training. Includes replay subcommand for offline evaluation and integration tests. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
83
src/ddos/audit_log.rs
Normal file
83
src/ddos/audit_log.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
/// One parsed audit-log line: a timestamp plus the per-request fields.
#[derive(Deserialize)]
pub struct AuditLog {
    /// Raw timestamp string exactly as written by the logger (not parsed here).
    pub timestamp: String,
    /// Per-request payload (method, path, client IP, status, ...).
    pub fields: AuditFields,
}
|
||||
|
||||
/// The per-request payload of an audit-log line.
///
/// Numeric fields accept either JSON numbers or numeric strings (via
/// `flexible_u16` / `flexible_u64`); most optional fields have serde
/// defaults so that older/sparser log formats still deserialize.
#[derive(Deserialize)]
pub struct AuditFields {
    /// HTTP method, e.g. "GET".
    pub method: String,
    /// Host header / virtual host the request targeted.
    pub host: String,
    /// Request path (presumably excludes the query string, which is
    /// logged separately in `query` — confirm against the logger).
    pub path: String,
    /// Client socket address as logged; may include a port suffix
    /// (callers strip it with `strip_port`).
    pub client_ip: String,
    /// HTTP response status; accepts a number or a numeric string.
    #[serde(deserialize_with = "flexible_u16")]
    pub status: u16,
    /// Request duration in milliseconds; number or numeric string.
    #[serde(deserialize_with = "flexible_u64")]
    pub duration_ms: u64,
    /// Upstream backend identifier; empty when absent.
    #[serde(default)]
    pub backend: String,
    /// Body size in bytes; 0 when absent.
    #[serde(default)]
    pub content_length: u64,
    /// User-Agent header; "-" when absent (see `default_ua`).
    #[serde(default = "default_ua")]
    pub user_agent: String,
    /// Raw query string; empty when absent.
    #[serde(default)]
    pub query: String,
    /// Whether the request carried cookies; None when not logged.
    #[serde(default)]
    pub has_cookies: Option<bool>,
    /// Referer header; None when not logged.
    #[serde(default)]
    pub referer: Option<String>,
    /// Accept-Language header; None when not logged.
    #[serde(default)]
    pub accept_language: Option<String>,
    /// Optional ground-truth label from external datasets (e.g. CSIC 2010).
    /// Values: "attack", "normal". When present, trainers should use this
    /// instead of heuristic labeling.
    #[serde(default)]
    pub label: Option<String>,
}
|
||||
|
||||
/// Placeholder user-agent used when the log line carries none.
fn default_ua() -> String {
    String::from("-")
}
|
||||
|
||||
/// Deserialize a `u64` from either a JSON number or a numeric string.
///
/// Log pipelines sometimes stringify numbers; this accepts both shapes.
pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> std::result::Result<u64, D::Error> {
    // NOTE: `untagged` tries variants in declaration order — `Num` must
    // stay before `Str`, or string inputs would shadow numeric ones.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum StringOrNum {
        Num(u64),
        Str(String),
    }
    match StringOrNum::deserialize(deserializer)? {
        StringOrNum::Num(n) => Ok(n),
        // Parse failures (non-numeric strings) surface as custom
        // deserialization errors rather than panicking.
        StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
    }
}
|
||||
|
||||
/// Deserialize a `u16` from either a JSON number or a numeric string.
///
/// Same contract as `flexible_u64`, specialized for status codes.
pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> std::result::Result<u16, D::Error> {
    // NOTE: `untagged` tries variants in declaration order — `Num` must
    // stay before `Str`, or string inputs would shadow numeric ones.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum StringOrNum {
        Num(u16),
        Str(String),
    }
    match StringOrNum::deserialize(deserializer)? {
        StringOrNum::Num(n) => Ok(n),
        // Parse failures (non-numeric strings, out-of-range values)
        // surface as custom deserialization errors.
        StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
    }
}
|
||||
|
||||
/// Strip the port suffix from a socket address string.
///
/// Handles three shapes:
/// - bracketed IPv6 with or without port: `"[::1]:8080"` / `"[::1]"` -> `"::1"`,
/// - IPv4 or hostname with port: `"1.2.3.4:80"` -> `"1.2.3.4"`,
/// - anything without a port, returned unchanged.
///
/// Fix: a bare (unbracketed) IPv6 address such as `"2001:db8::1"` contains
/// multiple colons but no port; the previous implementation truncated it at
/// the last colon, producing an unparseable address. The last colon is now
/// treated as a port separator only when it is the *only* colon.
pub fn strip_port(addr: &str) -> &str {
    if addr.starts_with('[') {
        // Bracketed IPv6: take what's inside the brackets; a malformed
        // string with no closing bracket is returned unchanged.
        addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr)
    } else if addr.bytes().filter(|&b| b == b':').count() == 1 {
        // Exactly one colon: host:port (IPv4 or hostname).
        // unwrap is safe — we just counted one ':'.
        &addr[..addr.rfind(':').unwrap()]
    } else {
        // No colon (no port) or multiple colons (bare IPv6): leave as-is.
        addr
    }
}
|
||||
100
src/ddos/detector.rs
Normal file
100
src/ddos/detector.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
use crate::config::DDoSConfig;
|
||||
use crate::ddos::features::{method_to_u8, IpState, RequestEvent};
|
||||
use crate::ddos::model::{DDoSAction, TrainedModel};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::net::IpAddr;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Instant;
|
||||
|
||||
const NUM_SHARDS: usize = 256;
|
||||
|
||||
/// Sharded per-IP DDoS detector: accumulates request events per client IP
/// and classifies each IP's sliding window with the trained KNN model.
pub struct DDoSDetector {
    /// Trained KNN model (normalized points + labels).
    model: TrainedModel,
    /// NUM_SHARDS lock-sharded maps from client IP to its event window;
    /// sharding reduces write-lock contention across worker threads.
    shards: Vec<RwLock<FxHashMap<IpAddr, IpState>>>,
    /// Sliding-window length in seconds used for feature extraction.
    window_secs: u64,
    /// Maximum events retained per IP (ring-buffer capacity).
    window_capacity: usize,
    /// Minimum events per IP before classification is attempted.
    min_events: usize,
}
|
||||
|
||||
fn shard_index(ip: &IpAddr) -> usize {
|
||||
let mut h = rustc_hash::FxHasher::default();
|
||||
ip.hash(&mut h);
|
||||
h.finish() as usize % NUM_SHARDS
|
||||
}
|
||||
|
||||
impl DDoSDetector {
    /// Build a detector from a loaded model and runtime configuration;
    /// allocates all shard maps up front.
    pub fn new(model: TrainedModel, config: &DDoSConfig) -> Self {
        let shards = (0..NUM_SHARDS)
            .map(|_| RwLock::new(FxHashMap::default()))
            .collect();
        Self {
            model,
            shards,
            window_secs: config.window_secs,
            window_capacity: config.window_capacity,
            min_events: config.min_events,
        }
    }

    /// Record an incoming request and classify the IP.
    /// Called from request_filter (before upstream).
    ///
    /// Returns `Allow` until the IP has accumulated at least `min_events`
    /// events; after that, classifies the extracted feature vector.
    pub fn check(
        &self,
        ip: IpAddr,
        method: &str,
        path: &str,
        host: &str,
        user_agent: &str,
        content_length: u64,
        has_cookies: bool,
        has_referer: bool,
        has_accept_language: bool,
    ) -> DDoSAction {
        let event = RequestEvent {
            timestamp: Instant::now(),
            method: method_to_u8(method),
            path_hash: fx_hash(path),
            host_hash: fx_hash(host),
            user_agent_hash: fx_hash(user_agent),
            // Response data isn't known yet (pre-upstream); status and
            // duration stay 0 — see record_response below.
            status: 0,
            duration_ms: 0,
            // Saturating narrowing: anything above u32::MAX clamps.
            content_length: content_length.min(u32::MAX as u64) as u32,
            has_cookies,
            has_referer,
            has_accept_language,
            suspicious_path: crate::ddos::features::is_suspicious_path(path),
        };

        let idx = shard_index(&ip);
        // Recover from lock poisoning rather than panicking: a panic in
        // another request handler must not take the detector down.
        let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner());
        let state = shard
            .entry(ip)
            .or_insert_with(|| IpState::new(self.window_capacity));
        state.push(event);

        // Too few observations to classify meaningfully — fail open.
        if state.len() < self.min_events {
            return DDoSAction::Allow;
        }

        // NOTE: feature extraction and classification run while holding
        // the shard's write lock; keep them cheap.
        let features = state.extract_features(self.window_secs);
        self.model.classify(&features)
    }

    /// Feed response data back into the IP's event history.
    /// Called from logging() after the response is sent.
    pub fn record_response(&self, _ip: IpAddr, _status: u16, _duration_ms: u32) {
        // Status/duration from check() are 0-initialized; the next request
        // will have fresh data. This is intentionally a no-op for now.
    }

    /// Number of training points in the loaded model (for diagnostics).
    pub fn point_count(&self) -> usize {
        self.model.point_count()
    }
}
|
||||
|
||||
fn fx_hash(s: &str) -> u64 {
|
||||
let mut h = rustc_hash::FxHasher::default();
|
||||
s.hash(&mut h);
|
||||
h.finish()
|
||||
}
|
||||
467
src/ddos/features.rs
Normal file
467
src/ddos/features.rs
Normal file
@@ -0,0 +1,467 @@
|
||||
use rustc_hash::FxHashSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Instant;
|
||||
|
||||
pub const NUM_FEATURES: usize = 14;
|
||||
pub type FeatureVector = [f64; NUM_FEATURES];
|
||||
|
||||
/// One observed request, reduced to the fields feature extraction needs.
/// Strings are pre-hashed so the ring buffer stores fixed-size events.
#[derive(Clone)]
pub struct RequestEvent {
    /// Arrival time (monotonic clock; used for windowing and burst score).
    pub timestamp: Instant,
    /// GET=0, POST=1, PUT=2, DELETE=3, HEAD=4, PATCH=5, OPTIONS=6, other=7
    pub method: u8,
    /// FxHash of the request path (collisions possible but rare).
    pub path_hash: u64,
    /// FxHash of the Host header.
    pub host_hash: u64,
    /// FxHash of the User-Agent header.
    pub user_agent_hash: u64,
    /// Response status; 0 when the response hasn't been seen yet.
    pub status: u16,
    /// Response duration in ms; 0 when the response hasn't been seen yet.
    pub duration_ms: u32,
    /// Request body size in bytes (clamped to u32::MAX by the caller).
    pub content_length: u32,
    /// Request carried a Cookie header.
    pub has_cookies: bool,
    /// Request carried a Referer header.
    pub has_referer: bool,
    /// Request carried an Accept-Language header.
    pub has_accept_language: bool,
    /// Path matched a known scanner/probe fragment (see is_suspicious_path).
    pub suspicious_path: bool,
}
|
||||
|
||||
/// Known-bad path fragments that scanners/bots probe for.
const SUSPICIOUS_FRAGMENTS: &[&str] = &[
    ".env", ".git/", ".git\\", ".bak", ".sql", ".tar", ".zip",
    "wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
    "phpinfo", "phpmyadmin", "php-info", ".php",
    "cgi-bin", "shell", "eval-stdin",
    "/vendor/", "/telescope/", "/actuator/",
    "/.htaccess", "/.htpasswd",
    "/debug/", "/config.", "/admin/",
    "yarn.lock", "yarn-debug", "package.json", "composer.json",
];

/// Case-insensitive substring check against the scanner fragment list.
pub fn is_suspicious_path(path: &str) -> bool {
    let haystack = path.to_ascii_lowercase();
    SUSPICIOUS_FRAGMENTS
        .iter()
        .any(|fragment| haystack.contains(fragment))
}
|
||||
|
||||
/// Fixed-capacity ring buffer of recent request events for one client IP.
pub struct IpState {
    /// Event storage; grows up to `capacity`, then entries are overwritten.
    events: Vec<RequestEvent>,
    /// Next write slot once the buffer is full.
    cursor: usize,
    /// Monotonic total of pushes (not capped by capacity).
    count: usize,
    /// Maximum number of retained events.
    capacity: usize,
}
|
||||
|
||||
impl IpState {
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
events: Vec::with_capacity(capacity),
|
||||
cursor: 0,
|
||||
count: 0,
|
||||
capacity,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, event: RequestEvent) {
|
||||
if self.events.len() < self.capacity {
|
||||
self.events.push(event);
|
||||
} else {
|
||||
self.events[self.cursor] = event;
|
||||
}
|
||||
self.cursor = (self.cursor + 1) % self.capacity;
|
||||
self.count += 1;
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.events.len()
|
||||
}
|
||||
|
||||
/// Prune events older than `window` from the logical view.
|
||||
/// Returns a slice of active events (not necessarily contiguous in ring buffer,
|
||||
/// so we collect into a Vec).
|
||||
fn active_events(&self, window_secs: u64) -> Vec<&RequestEvent> {
|
||||
let now = Instant::now();
|
||||
let cutoff = std::time::Duration::from_secs(window_secs);
|
||||
self.events
|
||||
.iter()
|
||||
.filter(|e| now.duration_since(e.timestamp) <= cutoff)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn extract_features(&self, window_secs: u64) -> FeatureVector {
|
||||
let events = self.active_events(window_secs);
|
||||
let n = events.len() as f64;
|
||||
if n < 1.0 {
|
||||
return [0.0; NUM_FEATURES];
|
||||
}
|
||||
|
||||
// 0: request_rate (requests / window_secs)
|
||||
let request_rate = n / window_secs as f64;
|
||||
|
||||
// 1: unique_paths
|
||||
let unique_paths = {
|
||||
let mut set = FxHashSet::default();
|
||||
for e in &events {
|
||||
set.insert(e.path_hash);
|
||||
}
|
||||
set.len() as f64
|
||||
};
|
||||
|
||||
// 2: unique_hosts
|
||||
let unique_hosts = {
|
||||
let mut set = FxHashSet::default();
|
||||
for e in &events {
|
||||
set.insert(e.host_hash);
|
||||
}
|
||||
set.len() as f64
|
||||
};
|
||||
|
||||
// 3: error_rate (fraction of 4xx/5xx)
|
||||
let errors = events.iter().filter(|e| e.status >= 400).count() as f64;
|
||||
let error_rate = errors / n;
|
||||
|
||||
// 4: avg_duration_ms
|
||||
let avg_duration_ms =
|
||||
events.iter().map(|e| e.duration_ms as f64).sum::<f64>() / n;
|
||||
|
||||
// 5: method_entropy (Shannon entropy of method distribution)
|
||||
let method_entropy = {
|
||||
let mut counts = [0u32; 8];
|
||||
for e in &events {
|
||||
counts[e.method as usize % 8] += 1;
|
||||
}
|
||||
let mut entropy = 0.0f64;
|
||||
for &c in &counts {
|
||||
if c > 0 {
|
||||
let p = c as f64 / n;
|
||||
entropy -= p * p.ln();
|
||||
}
|
||||
}
|
||||
entropy
|
||||
};
|
||||
|
||||
// 6: burst_score (inverse mean inter-arrival time)
|
||||
let burst_score = if events.len() >= 2 {
|
||||
let mut timestamps: Vec<Instant> =
|
||||
events.iter().map(|e| e.timestamp).collect();
|
||||
timestamps.sort();
|
||||
let total_span = timestamps
|
||||
.last()
|
||||
.unwrap()
|
||||
.duration_since(*timestamps.first().unwrap())
|
||||
.as_secs_f64();
|
||||
if total_span > 0.0 {
|
||||
(events.len() - 1) as f64 / total_span
|
||||
} else {
|
||||
n // all events at same instant = maximum burstiness
|
||||
}
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// 7: path_repetition (ratio of most-repeated path to total)
|
||||
let path_repetition = {
|
||||
let mut counts = rustc_hash::FxHashMap::default();
|
||||
for e in &events {
|
||||
*counts.entry(e.path_hash).or_insert(0u32) += 1;
|
||||
}
|
||||
let max_count = counts.values().copied().max().unwrap_or(0) as f64;
|
||||
max_count / n
|
||||
};
|
||||
|
||||
// 8: avg_content_length
|
||||
let avg_content_length =
|
||||
events.iter().map(|e| e.content_length as f64).sum::<f64>() / n;
|
||||
|
||||
// 9: unique_user_agents
|
||||
let unique_user_agents = {
|
||||
let mut set = FxHashSet::default();
|
||||
for e in &events {
|
||||
set.insert(e.user_agent_hash);
|
||||
}
|
||||
set.len() as f64
|
||||
};
|
||||
|
||||
// 10: cookie_ratio (fraction of requests that have cookies)
|
||||
let cookie_ratio =
|
||||
events.iter().filter(|e| e.has_cookies).count() as f64 / n;
|
||||
|
||||
// 11: referer_ratio (fraction of requests with a referer)
|
||||
let referer_ratio =
|
||||
events.iter().filter(|e| e.has_referer).count() as f64 / n;
|
||||
|
||||
// 12: accept_language_ratio (fraction with accept-language)
|
||||
let accept_language_ratio =
|
||||
events.iter().filter(|e| e.has_accept_language).count() as f64 / n;
|
||||
|
||||
// 13: suspicious_path_ratio (fraction hitting known-bad paths)
|
||||
let suspicious_path_ratio =
|
||||
events.iter().filter(|e| e.suspicious_path).count() as f64 / n;
|
||||
|
||||
[
|
||||
request_rate,
|
||||
unique_paths,
|
||||
unique_hosts,
|
||||
error_rate,
|
||||
avg_duration_ms,
|
||||
method_entropy,
|
||||
burst_score,
|
||||
path_repetition,
|
||||
avg_content_length,
|
||||
unique_user_agents,
|
||||
cookie_ratio,
|
||||
referer_ratio,
|
||||
accept_language_ratio,
|
||||
suspicious_path_ratio,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode an HTTP method name as a compact u8 code.
/// GET=0, POST=1, PUT=2, DELETE=3, HEAD=4, PATCH=5, OPTIONS=6, other=7.
/// Comparison is case-sensitive, matching on-the-wire method names.
pub fn method_to_u8(method: &str) -> u8 {
    const METHODS: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH", "OPTIONS"];
    METHODS
        .iter()
        .position(|&known| known == method)
        .map_or(7, |code| code as u8)
}
|
||||
|
||||
/// Per-feature min/max bounds used for min-max normalization of feature
/// vectors; computed at training time and serialized with the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormParams {
    /// Per-dimension minimum observed in the training data.
    pub mins: [f64; NUM_FEATURES],
    /// Per-dimension maximum observed in the training data.
    pub maxs: [f64; NUM_FEATURES],
}
|
||||
|
||||
impl NormParams {
|
||||
pub fn from_data(vectors: &[FeatureVector]) -> Self {
|
||||
let mut mins = [f64::MAX; NUM_FEATURES];
|
||||
let mut maxs = [f64::MIN; NUM_FEATURES];
|
||||
for v in vectors {
|
||||
for i in 0..NUM_FEATURES {
|
||||
mins[i] = mins[i].min(v[i]);
|
||||
maxs[i] = maxs[i].max(v[i]);
|
||||
}
|
||||
}
|
||||
Self { mins, maxs }
|
||||
}
|
||||
|
||||
pub fn normalize(&self, v: &FeatureVector) -> FeatureVector {
|
||||
let mut out = [0.0; NUM_FEATURES];
|
||||
for i in 0..NUM_FEATURES {
|
||||
let range = self.maxs[i] - self.mins[i];
|
||||
out[i] = if range > 0.0 {
|
||||
((v[i] - self.mins[i]) / range).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
/// Feature extraction from parsed log entries (used by training pipeline).
/// Unlike IpState which uses Instant, this uses f64 timestamps from log parsing.
///
/// Columnar layout: all Vecs are parallel — index i across every field
/// describes the same request. Callers must keep the lengths equal and
/// rows ordered by timestamp; `extract_features_for_window` indexes all
/// of them with the same range.
pub struct LogIpState {
    /// Parsed request timestamps (seconds, f64), expected ascending.
    pub timestamps: Vec<f64>,
    /// Method codes from `method_to_u8`.
    pub methods: Vec<u8>,
    /// FxHash of each request path.
    pub path_hashes: Vec<u64>,
    /// FxHash of each Host header.
    pub host_hashes: Vec<u64>,
    /// FxHash of each User-Agent header.
    pub user_agent_hashes: Vec<u64>,
    /// HTTP response statuses.
    pub statuses: Vec<u16>,
    /// Response durations in ms.
    pub durations: Vec<u32>,
    /// Request body sizes in bytes.
    pub content_lengths: Vec<u32>,
    /// Per-request cookie presence flags.
    pub has_cookies: Vec<bool>,
    /// Per-request referer presence flags.
    pub has_referer: Vec<bool>,
    /// Per-request accept-language presence flags.
    pub has_accept_language: Vec<bool>,
    /// Per-request suspicious-path flags (see is_suspicious_path).
    pub suspicious_paths: Vec<bool>,
}
|
||||
|
||||
impl LogIpState {
    /// Create an empty columnar state (all parallel Vecs empty).
    pub fn new() -> Self {
        Self {
            timestamps: Vec::new(),
            methods: Vec::new(),
            path_hashes: Vec::new(),
            host_hashes: Vec::new(),
            user_agent_hashes: Vec::new(),
            statuses: Vec::new(),
            durations: Vec::new(),
            content_lengths: Vec::new(),
            has_cookies: Vec::new(),
            has_referer: Vec::new(),
            has_accept_language: Vec::new(),
            suspicious_paths: Vec::new(),
        }
    }

    /// Extract the same 14-feature vector as `IpState::extract_features`,
    /// but over the row range `[start, end)` of the columnar data.
    ///
    /// Mirrors the online extractor's feature order exactly — keep the two
    /// implementations in sync, or training and serving will disagree.
    ///
    /// NOTE(review): assumes `start <= end` and `end <= len` of every
    /// parallel Vec (slicing panics otherwise), and that `timestamps` in
    /// the range are ascending (burst_score uses last - first). Confirm
    /// the trainer upholds these.
    pub fn extract_features_for_window(
        &self,
        start: usize,
        end: usize,
        window_secs: f64,
    ) -> FeatureVector {
        let n = (end - start) as f64;
        if n < 1.0 {
            return [0.0; NUM_FEATURES];
        }

        // 0: request rate over the configured window.
        let request_rate = n / window_secs;

        // 1: unique paths.
        let unique_paths = {
            let mut set = FxHashSet::default();
            for i in start..end {
                set.insert(self.path_hashes[i]);
            }
            set.len() as f64
        };

        // 2: unique hosts.
        let unique_hosts = {
            let mut set = FxHashSet::default();
            for i in start..end {
                set.insert(self.host_hashes[i]);
            }
            set.len() as f64
        };

        // 3: fraction of 4xx/5xx responses.
        let errors = self.statuses[start..end]
            .iter()
            .filter(|&&s| s >= 400)
            .count() as f64;
        let error_rate = errors / n;

        // 4: mean response duration (ms).
        let avg_duration_ms =
            self.durations[start..end].iter().map(|&d| d as f64).sum::<f64>() / n;

        // 5: Shannon entropy of the method distribution.
        let method_entropy = {
            let mut counts = [0u32; 8];
            for i in start..end {
                counts[self.methods[i] as usize % 8] += 1;
            }
            let mut entropy = 0.0f64;
            for &c in &counts {
                if c > 0 {
                    let p = c as f64 / n;
                    entropy -= p * p.ln();
                }
            }
            entropy
        };

        // 6: inverse mean inter-arrival time; identical timestamps
        // count as maximal burstiness.
        let burst_score = if (end - start) >= 2 {
            let total_span =
                self.timestamps[end - 1] - self.timestamps[start];
            if total_span > 0.0 {
                (end - start - 1) as f64 / total_span
            } else {
                n
            }
        } else {
            0.0
        };

        // 7: ratio of the most-repeated path to the total.
        let path_repetition = {
            let mut counts = rustc_hash::FxHashMap::default();
            for i in start..end {
                *counts.entry(self.path_hashes[i]).or_insert(0u32) += 1;
            }
            let max_count = counts.values().copied().max().unwrap_or(0) as f64;
            max_count / n
        };

        // 8: mean request body size.
        let avg_content_length = self.content_lengths[start..end]
            .iter()
            .map(|&c| c as f64)
            .sum::<f64>()
            / n;

        // 9: unique user agents.
        let unique_user_agents = {
            let mut set = FxHashSet::default();
            for i in start..end {
                set.insert(self.user_agent_hashes[i]);
            }
            set.len() as f64
        };

        // 10-13: header-presence and suspicious-path ratios.
        let cookie_ratio =
            self.has_cookies[start..end].iter().filter(|&&v| v).count() as f64 / n;
        let referer_ratio =
            self.has_referer[start..end].iter().filter(|&&v| v).count() as f64 / n;
        let accept_language_ratio =
            self.has_accept_language[start..end].iter().filter(|&&v| v).count() as f64 / n;
        let suspicious_path_ratio =
            self.suspicious_paths[start..end].iter().filter(|&&v| v).count() as f64 / n;

        [
            request_rate,
            unique_paths,
            unique_hosts,
            error_rate,
            avg_duration_ms,
            method_entropy,
            burst_score,
            path_repetition,
            avg_content_length,
            unique_user_agents,
            cookie_ratio,
            referer_ratio,
            accept_language_ratio,
            suspicious_path_ratio,
        ]
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use rustc_hash::FxHasher;
    use std::hash::{Hash, Hasher};

    /// String-hash helper mirroring the detector's fx_hash.
    fn fx(s: &str) -> u64 {
        let mut h = FxHasher::default();
        s.hash(&mut h);
        h.finish()
    }

    /// A single benign event should yield the obvious ratio features.
    #[test]
    fn test_single_event_features() {
        let mut state = IpState::new(100);
        state.push(RequestEvent {
            timestamp: Instant::now(),
            method: 0,
            path_hash: fx("/"),
            host_hash: fx("example.com"),
            user_agent_hash: fx("curl/7.0"),
            status: 200,
            duration_ms: 10,
            content_length: 0,
            has_cookies: true,
            has_referer: false,
            has_accept_language: true,
            suspicious_path: false,
        });
        let features = state.extract_features(60);
        // request_rate = 1/60
        assert!(features[0] > 0.0);
        // error_rate = 0
        assert_eq!(features[3], 0.0);
        // path_repetition = 1.0 (only one path)
        assert_eq!(features[7], 1.0);
        // cookie_ratio = 1.0 (single event with cookies)
        assert_eq!(features[10], 1.0);
        // referer_ratio = 0.0
        assert_eq!(features[11], 0.0);
        // accept_language_ratio = 1.0
        assert_eq!(features[12], 1.0);
        // suspicious_path_ratio = 0.0
        assert_eq!(features[13], 0.0);
    }

    /// Min-max normalization should map midpoints to exactly 0.5.
    #[test]
    fn test_norm_params() {
        let data = vec![[0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [1.0, 20.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]];
        let params = NormParams::from_data(&data);
        let normalized = params.normalize(&[0.5, 15.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]);
        for &v in &normalized {
            assert!((v - 0.5).abs() < 1e-10);
        }
    }
}
|
||||
6
src/ddos/mod.rs
Normal file
6
src/ddos/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod audit_log;
|
||||
pub mod detector;
|
||||
pub mod features;
|
||||
pub mod model;
|
||||
pub mod replay;
|
||||
pub mod train;
|
||||
168
src/ddos/model.rs
Normal file
168
src/ddos/model.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use crate::ddos::features::{FeatureVector, NormParams, NUM_FEATURES};
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
/// Class label attached to each training point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrafficLabel {
    /// Legitimate traffic.
    Normal,
    /// Attack/abusive traffic.
    Attack,
}
|
||||
|
||||
/// On-disk model format (bincode). `points` and `labels` are parallel:
/// `labels[i]` classifies `points[i]`.
#[derive(Serialize, Deserialize)]
pub struct SerializedModel {
    /// Normalized training feature vectors.
    pub points: Vec<FeatureVector>,
    /// Per-point class labels (same length as `points`).
    pub labels: Vec<TrafficLabel>,
    /// Min-max bounds used to normalize query vectors.
    pub norm_params: NormParams,
    /// Default number of neighbors for KNN voting.
    pub k: usize,
    /// Default attack-fraction threshold for a Block verdict.
    pub threshold: f64,
}
|
||||
|
||||
/// In-memory KNN model: normalized training points, labels, and voting
/// parameters (possibly overridden at load time).
pub struct TrainedModel {
    /// Stored points (normalized). The kD-tree borrows these.
    points: Vec<[f64; NUM_FEATURES]>,
    /// Per-point labels, parallel to `points`.
    labels: Vec<TrafficLabel>,
    /// Min-max bounds applied to every query vector before lookup.
    norm_params: NormParams,
    /// Number of neighbors consulted per classification.
    k: usize,
    /// Attack-vote fraction at or above which an IP is blocked.
    threshold: f64,
}
|
||||
|
||||
/// Verdict returned by the classifier for a single check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DDoSAction {
    /// Let the request through.
    Allow,
    /// Reject the request (IP classified as attacking).
    Block,
}
|
||||
|
||||
impl TrainedModel {
    /// Load a bincode-serialized model from disk. `k_override` and
    /// `threshold_override` replace the stored defaults when provided.
    pub fn load(path: &Path, k_override: Option<usize>, threshold_override: Option<f64>) -> Result<Self> {
        let data = std::fs::read(path)
            .with_context(|| format!("reading model from {}", path.display()))?;
        let model: SerializedModel =
            bincode::deserialize(&data).context("deserializing model")?;
        Ok(Self {
            points: model.points,
            labels: model.labels,
            norm_params: model.norm_params,
            k: k_override.unwrap_or(model.k),
            threshold: threshold_override.unwrap_or(model.threshold),
        })
    }

    /// Build a model directly from its serialized form (no overrides).
    pub fn from_serialized(model: SerializedModel) -> Self {
        Self {
            points: model.points,
            labels: model.labels,
            norm_params: model.norm_params,
            k: model.k,
            threshold: model.threshold,
        }
    }

    /// Classify a raw feature vector: normalize, find the k nearest
    /// training points, and Block when the attack-labeled fraction meets
    /// `threshold`. Fails open (Allow) on an empty model or any tree
    /// build/query error.
    ///
    /// NOTE(review): assumes `labels.len() == points.len()` — the indexing
    /// below panics otherwise; confirm the trainer upholds this when
    /// building SerializedModel.
    pub fn classify(&self, features: &FeatureVector) -> DDoSAction {
        let normalized = self.norm_params.normalize(features);

        if self.points.is_empty() {
            return DDoSAction::Allow;
        }

        // Build tree on-the-fly for query. In production with many queries,
        // we'd cache this, but the tree build is fast for <100K points.
        // fnntw::Tree borrows data, so we build it here.
        // NOTE(review): this is O(n log n) work on EVERY classify() call,
        // executed under the detector's shard lock — worth caching if
        // profiling shows it hot (nontrivial: the tree borrows `points`).
        let tree = match fnntw::Tree::<'_, f64, NUM_FEATURES>::new(&self.points, 32) {
            Ok(t) => t,
            Err(_) => return DDoSAction::Allow,
        };

        // Never ask for more neighbors than points exist.
        let k = self.k.min(self.points.len());
        let result = tree.query_nearest_k(&normalized, k);
        match result {
            Ok((_distances, indices)) => {
                // Majority-style vote: fraction of attack-labeled neighbors.
                let attack_count = indices
                    .iter()
                    .filter(|&&idx| self.labels[idx as usize] == TrafficLabel::Attack)
                    .count();
                let attack_frac = attack_count as f64 / k as f64;
                if attack_frac >= self.threshold {
                    DDoSAction::Block
                } else {
                    DDoSAction::Allow
                }
            }
            // Fail open: query errors must never block real traffic.
            Err(_) => DDoSAction::Allow,
        }
    }

    /// The normalization bounds baked into this model.
    pub fn norm_params(&self) -> &NormParams {
        &self.norm_params
    }

    /// Number of training points (for diagnostics/logging).
    pub fn point_count(&self) -> usize {
        self.points.len()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty model must fail open.
    #[test]
    fn test_classify_empty_model() {
        let model = TrainedModel {
            points: vec![],
            labels: vec![],
            norm_params: NormParams {
                mins: [0.0; NUM_FEATURES],
                maxs: [1.0; NUM_FEATURES],
            },
            k: 5,
            threshold: 0.6,
        };
        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
    }

    /// Deterministic spread of n points across the unit hypercube.
    fn make_test_points(n: usize) -> Vec<FeatureVector> {
        (0..n)
            .map(|i| {
                let mut v = [0.0; NUM_FEATURES];
                for d in 0..NUM_FEATURES {
                    v[d] = ((i * (d + 1)) as f64 / n as f64) % 1.0;
                }
                v
            })
            .collect()
    }

    /// All-attack training data must produce a Block verdict.
    #[test]
    fn test_classify_all_attack() {
        let points = make_test_points(100);
        let labels = vec![TrafficLabel::Attack; 100];
        let model = TrainedModel {
            points,
            labels,
            norm_params: NormParams {
                mins: [0.0; NUM_FEATURES],
                maxs: [1.0; NUM_FEATURES],
            },
            k: 5,
            threshold: 0.6,
        };
        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Block);
    }

    /// All-normal training data must produce an Allow verdict.
    #[test]
    fn test_classify_all_normal() {
        let points = make_test_points(100);
        let labels = vec![TrafficLabel::Normal; 100];
        let model = TrainedModel {
            points,
            labels,
            norm_params: NormParams {
                mins: [0.0; NUM_FEATURES],
                maxs: [1.0; NUM_FEATURES],
            },
            k: 5,
            threshold: 0.6,
        };
        assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
    }
}
|
||||
291
src/ddos/replay.rs
Normal file
291
src/ddos/replay.rs
Normal file
@@ -0,0 +1,291 @@
|
||||
use crate::config::{DDoSConfig, RateLimitConfig};
|
||||
use crate::ddos::audit_log::{self, AuditLog};
|
||||
use crate::ddos::detector::DDoSDetector;
|
||||
use crate::ddos::model::{DDoSAction, TrainedModel};
|
||||
use crate::rate_limit::key::RateLimitKey;
|
||||
use crate::rate_limit::limiter::{RateLimitResult, RateLimiter};
|
||||
use anyhow::{Context, Result};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::io::BufRead;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// CLI arguments for the offline replay subcommand.
pub struct ReplayArgs {
    /// Path to the JSONL audit log to replay.
    pub input: String,
    /// Path to the serialized model file.
    pub model_path: String,
    /// Optional app config path (used for the rate-limiter settings).
    pub config_path: Option<String>,
    /// KNN neighbor count override.
    pub k: usize,
    /// Attack-fraction threshold override.
    pub threshold: f64,
    /// Sliding-window length in seconds.
    pub window_secs: u64,
    /// Minimum events per IP before classification.
    pub min_events: usize,
    /// Whether to also simulate the rate limiter.
    pub rate_limit: bool,
}
|
||||
|
||||
/// Counters accumulated over one replay run.
struct ReplayStats {
    /// Entries that parsed fully and were evaluated.
    total: u64,
    /// Lines dropped for any parse failure (JSON, empty method, bad IP).
    skipped: u64,
    /// Requests the KNN detector blocked.
    ddos_blocked: u64,
    /// Requests the (optional) rate limiter rejected.
    rate_limited: u64,
    /// Requests that passed both checks.
    allowed: u64,
    /// Per-IP count of DDoS-blocked requests.
    ddos_blocked_ips: FxHashMap<String, u64>,
    /// Per-IP count of rate-limited requests.
    rate_limited_ips: FxHashMap<String, u64>,
}
|
||||
|
||||
pub fn run(args: ReplayArgs) -> Result<()> {
|
||||
eprintln!("Loading model from {}...", args.model_path);
|
||||
let model = TrainedModel::load(
|
||||
std::path::Path::new(&args.model_path),
|
||||
Some(args.k),
|
||||
Some(args.threshold),
|
||||
)
|
||||
.with_context(|| format!("loading model from {}", args.model_path))?;
|
||||
eprintln!(" {} training points, k={}, threshold={}", model.point_count(), args.k, args.threshold);
|
||||
|
||||
let ddos_cfg = DDoSConfig {
|
||||
model_path: args.model_path.clone(),
|
||||
k: args.k,
|
||||
threshold: args.threshold,
|
||||
window_secs: args.window_secs,
|
||||
window_capacity: 1000,
|
||||
min_events: args.min_events,
|
||||
enabled: true,
|
||||
};
|
||||
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
|
||||
|
||||
// Optionally set up rate limiter
|
||||
let rate_limiter = if args.rate_limit {
|
||||
let rl_cfg = if let Some(cfg_path) = &args.config_path {
|
||||
let cfg = crate::config::Config::load(cfg_path)?;
|
||||
cfg.rate_limit.unwrap_or_else(default_rate_limit_config)
|
||||
} else {
|
||||
default_rate_limit_config()
|
||||
};
|
||||
eprintln!(
|
||||
" Rate limiter: auth burst={} rate={}/s, unauth burst={} rate={}/s",
|
||||
rl_cfg.authenticated.burst,
|
||||
rl_cfg.authenticated.rate,
|
||||
rl_cfg.unauthenticated.burst,
|
||||
rl_cfg.unauthenticated.rate,
|
||||
);
|
||||
Some(RateLimiter::new(&rl_cfg))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
eprintln!("Replaying {}...\n", args.input);
|
||||
|
||||
let file = std::fs::File::open(&args.input)
|
||||
.with_context(|| format!("opening {}", args.input))?;
|
||||
let reader = std::io::BufReader::new(file);
|
||||
|
||||
let mut stats = ReplayStats {
|
||||
total: 0,
|
||||
skipped: 0,
|
||||
ddos_blocked: 0,
|
||||
rate_limited: 0,
|
||||
allowed: 0,
|
||||
ddos_blocked_ips: FxHashMap::default(),
|
||||
rate_limited_ips: FxHashMap::default(),
|
||||
};
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||
Ok(e) => e,
|
||||
Err(_) => {
|
||||
stats.skipped += 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if entry.fields.method.is_empty() {
|
||||
stats.skipped += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
stats.total += 1;
|
||||
|
||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||
let ip: IpAddr = match ip_str.parse() {
|
||||
Ok(ip) => ip,
|
||||
Err(_) => {
|
||||
stats.skipped += 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// DDoS check
|
||||
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
|
||||
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
|
||||
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
|
||||
let ddos_action = detector.check(
|
||||
ip,
|
||||
&entry.fields.method,
|
||||
&entry.fields.path,
|
||||
&entry.fields.host,
|
||||
&entry.fields.user_agent,
|
||||
entry.fields.content_length,
|
||||
has_cookies,
|
||||
has_referer,
|
||||
has_accept_language,
|
||||
);
|
||||
|
||||
if ddos_action == DDoSAction::Block {
|
||||
stats.ddos_blocked += 1;
|
||||
*stats.ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Rate limit check
|
||||
if let Some(limiter) = &rate_limiter {
|
||||
// Audit logs don't have auth headers, so all traffic is keyed by IP
|
||||
let rl_key = RateLimitKey::Ip(ip);
|
||||
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
|
||||
stats.rate_limited += 1;
|
||||
*stats.rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
stats.allowed += 1;
|
||||
}
|
||||
|
||||
// Report
|
||||
let total = stats.total;
|
||||
eprintln!("═══ Replay Results ═══════════════════════════════════════");
|
||||
eprintln!(" Total requests: {total}");
|
||||
eprintln!(" Skipped (parse): {}", stats.skipped);
|
||||
eprintln!(" Allowed: {} ({:.1}%)", stats.allowed, pct(stats.allowed, total));
|
||||
eprintln!(" DDoS blocked: {} ({:.1}%)", stats.ddos_blocked, pct(stats.ddos_blocked, total));
|
||||
if rate_limiter.is_some() {
|
||||
eprintln!(" Rate limited: {} ({:.1}%)", stats.rate_limited, pct(stats.rate_limited, total));
|
||||
}
|
||||
|
||||
if !stats.ddos_blocked_ips.is_empty() {
|
||||
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
|
||||
let mut sorted: Vec<_> = stats.ddos_blocked_ips.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
for (ip, count) in sorted.iter().take(20) {
|
||||
eprintln!(" {:<40} {} reqs blocked", ip, count);
|
||||
}
|
||||
}
|
||||
|
||||
if !stats.rate_limited_ips.is_empty() {
|
||||
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
|
||||
let mut sorted: Vec<_> = stats.rate_limited_ips.iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
||||
for (ip, count) in sorted.iter().take(20) {
|
||||
eprintln!(" {:<40} {} reqs limited", ip, count);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for false positives: IPs that were blocked but had 2xx statuses in the original logs
|
||||
eprintln!("\n── False positive check ──────────────────────────────────");
|
||||
check_false_positives(&args.input, &stats)?;
|
||||
|
||||
eprintln!("══════════════════════════════════════════════════════════");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Re-scan the log to find blocked IPs that had mostly 2xx responses originally
|
||||
/// (i.e. they were legitimate traffic that the model would incorrectly block).
|
||||
fn check_false_positives(input: &str, stats: &ReplayStats) -> Result<()> {
|
||||
let blocked_ips: rustc_hash::FxHashSet<&str> = stats
|
||||
.ddos_blocked_ips
|
||||
.keys()
|
||||
.chain(stats.rate_limited_ips.keys())
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
|
||||
if blocked_ips.is_empty() {
|
||||
eprintln!(" No blocked IPs — nothing to check.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Collect original status codes for blocked IPs
|
||||
let file = std::fs::File::open(input)?;
|
||||
let reader = std::io::BufReader::new(file);
|
||||
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let entry: AuditLog = match serde_json::from_str(&line) {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
||||
if blocked_ips.contains(ip_str.as_str()) {
|
||||
ip_statuses
|
||||
.entry(ip_str)
|
||||
.or_default()
|
||||
.push(entry.fields.status);
|
||||
}
|
||||
}
|
||||
|
||||
let mut suspects = Vec::new();
|
||||
for (ip, statuses) in &ip_statuses {
|
||||
let total = statuses.len();
|
||||
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
|
||||
let ok_pct = (ok_count as f64 / total as f64) * 100.0;
|
||||
// If >60% of original responses were 2xx/3xx, this might be a false positive
|
||||
if ok_pct > 60.0 {
|
||||
let blocked = stats
|
||||
.ddos_blocked_ips
|
||||
.get(ip)
|
||||
.copied()
|
||||
.unwrap_or(0)
|
||||
+ stats
|
||||
.rate_limited_ips
|
||||
.get(ip)
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
suspects.push((ip.clone(), total, ok_pct, blocked));
|
||||
}
|
||||
}
|
||||
|
||||
if suspects.is_empty() {
|
||||
eprintln!(" No likely false positives found.");
|
||||
} else {
|
||||
suspects.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap());
|
||||
eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses:", suspects.len());
|
||||
for (ip, total, ok_pct, blocked) in suspects.iter().take(15) {
|
||||
eprintln!(
|
||||
" {:<40} {}/{} reqs were 2xx/3xx ({:.0}%), {} blocked",
|
||||
ip, ((*ok_pct / 100.0) * *total as f64) as u64, total, ok_pct, blocked,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn default_rate_limit_config() -> RateLimitConfig {
|
||||
RateLimitConfig {
|
||||
enabled: true,
|
||||
bypass_cidrs: vec![
|
||||
"10.0.0.0/8".into(),
|
||||
"172.16.0.0/12".into(),
|
||||
"192.168.0.0/16".into(),
|
||||
"100.64.0.0/10".into(),
|
||||
"fd00::/8".into(),
|
||||
],
|
||||
eviction_interval_secs: 300,
|
||||
stale_after_secs: 600,
|
||||
authenticated: crate::config::BucketConfig {
|
||||
burst: 200,
|
||||
rate: 50.0,
|
||||
},
|
||||
unauthenticated: crate::config::BucketConfig {
|
||||
burst: 60,
|
||||
rate: 15.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Percentage of `n` out of `total`; returns 0.0 when `total` is zero
/// (avoids NaN in the replay report).
fn pct(n: u64, total: u64) -> f64 {
    match total {
        0 => 0.0,
        t => (n as f64 / t as f64) * 100.0,
    }
}
|
||||
298
src/ddos/train.rs
Normal file
298
src/ddos/train.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
use crate::ddos::audit_log::AuditLog;
|
||||
use crate::ddos::audit_log;
|
||||
use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES};
|
||||
use crate::ddos::model::{SerializedModel, TrafficLabel};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use serde::Deserialize;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::io::BufRead;
|
||||
|
||||
/// Thresholds for heuristic auto-labeling of training traffic.
///
/// Deserialized from a TOML file. Every field has a serde default, so a
/// partial (or empty) file is valid — only the thresholds you want to
/// override need to be present.
#[derive(Deserialize)]
pub struct HeuristicThresholds {
    /// Requests/second above which an IP is labeled attack
    #[serde(default = "default_rate_threshold")]
    pub request_rate: f64,
    /// Path repetition ratio above which an IP is labeled attack
    #[serde(default = "default_repetition_threshold")]
    pub path_repetition: f64,
    /// Error rate above which an IP is labeled attack
    #[serde(default = "default_error_threshold")]
    pub error_rate: f64,
    /// Suspicious path ratio above which an IP is labeled attack
    #[serde(default = "default_suspicious_path_threshold")]
    pub suspicious_path_ratio: f64,
    /// Cookie ratio below which (combined with high unique paths) labels attack
    #[serde(default = "default_no_cookies_threshold")]
    pub no_cookies_threshold: f64,
    /// Unique path count above which no-cookie traffic is labeled attack
    #[serde(default = "default_no_cookies_path_count")]
    pub no_cookies_path_count: f64,
    /// Minimum events to consider an IP for labeling
    #[serde(default = "default_min_events")]
    pub min_events: usize,
}
|
||||
|
||||
// Serde default providers for `HeuristicThresholds`.

/// Default request-rate threshold: 10 req/s.
fn default_rate_threshold() -> f64 {
    10.0
}

/// Default path-repetition ratio threshold.
fn default_repetition_threshold() -> f64 {
    0.9
}

/// Default error-rate threshold.
fn default_error_threshold() -> f64 {
    0.7
}

/// Default suspicious-path ratio threshold.
fn default_suspicious_path_threshold() -> f64 {
    0.3
}

/// Default cookie-ratio floor for the no-cookie rule.
fn default_no_cookies_threshold() -> f64 {
    0.05
}

/// Default unique-path count for the no-cookie rule.
fn default_no_cookies_path_count() -> f64 {
    20.0
}

/// Default minimum number of events per IP.
fn default_min_events() -> usize {
    10
}
|
||||
|
||||
/// Inputs for the `train` subcommand, already parsed from the CLI.
pub struct TrainArgs {
    // Path to the JSONL audit-log file to train from.
    pub input: String,
    // Path where the serialized (bincode) model is written.
    pub output: String,
    // Optional file of known-attack IPs, one per line (paired with normal_ips).
    pub attack_ips: Option<String>,
    // Optional file of known-normal IPs, one per line (paired with attack_ips).
    pub normal_ips: Option<String>,
    // Optional TOML file of HeuristicThresholds for auto-labeling.
    pub heuristics: Option<String>,
    // Stored in the model as the KNN neighbor count.
    pub k: usize,
    // Decision threshold stored in the model — presumably the attack-vote
    // fraction at classification time; confirm against model.rs.
    pub threshold: f64,
    // Sliding-window length (seconds) used for feature extraction.
    pub window_secs: u64,
    // Minimum events per IP before it contributes training windows.
    pub min_events: usize,
}
|
||||
|
||||
fn fx_hash(s: &str) -> u64 {
|
||||
let mut h = rustc_hash::FxHasher::default();
|
||||
s.hash(&mut h);
|
||||
h.finish()
|
||||
}
|
||||
|
||||
/// Parse an ISO 8601 timestamp ("2026-03-07T17:41:40.705326Z") into seconds
/// on a monotonic, approximate scale. Returns 0.0 for anything malformed.
///
/// We only need *relative ordering and spacing* within one log file, so
/// months are treated as 31 days and years as 12 such months. That is not
/// calendar-accurate, but it is strictly order-preserving for valid dates —
/// unlike the naive `day * 86400 + time`, which ignored month and year and
/// made Jan 31 sort *after* Feb 1.
fn parse_timestamp(ts: &str) -> f64 {
    let parts: Vec<&str> = ts.split('T').collect();
    if parts.len() != 2 {
        return 0.0;
    }
    let date_parts: Vec<&str> = parts[0].split('-').collect();
    let time_str = parts[1].trim_end_matches('Z');
    let time_parts: Vec<&str> = time_str.split(':').collect();
    if date_parts.len() != 3 || time_parts.len() != 3 {
        return 0.0;
    }
    let year: f64 = date_parts[0].parse().unwrap_or(0.0);
    let month: f64 = date_parts[1].parse().unwrap_or(0.0);
    let day: f64 = date_parts[2].parse().unwrap_or(0.0);
    let hour: f64 = time_parts[0].parse().unwrap_or(0.0);
    let min: f64 = time_parts[1].parse().unwrap_or(0.0);
    // Fractional seconds parse directly as f64 ("40.705326").
    let sec: f64 = time_parts[2].parse().unwrap_or(0.0);
    // 31-day months / 372-day years: monotone upper bound on real calendars,
    // so ordering across month and year boundaries is preserved.
    let days = year * 372.0 + month * 31.0 + day;
    days * 86400.0 + hour * 3600.0 + min * 60.0 + sec
}
|
||||
|
||||
|
||||
/// Train a KNN DDoS-classification model from a JSONL audit-log file.
///
/// Pipeline: parse log lines → aggregate per-IP columnar state → extract one
/// feature vector per time window → label IPs (explicit attack/normal IP
/// lists, or heuristic thresholds from TOML) → normalize → serialize the
/// model with bincode to `args.output`.
///
/// Errors if the input or labeling files cannot be read/parsed, if no
/// labeled points result, or if the model cannot be serialized/written.
pub fn run(args: TrainArgs) -> Result<()> {
    eprintln!("Parsing logs from {}...", args.input);

    // Parse logs into per-IP state
    let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
    let file = std::fs::File::open(&args.input)
        .with_context(|| format!("opening {}", args.input))?;
    let reader = std::io::BufReader::new(file);

    let mut total_lines = 0u64;
    let mut parse_errors = 0u64;

    for line in reader.lines() {
        let line = line?;
        total_lines += 1;
        // Malformed lines are counted but otherwise skipped.
        let entry: AuditLog = match serde_json::from_str(&line) {
            Ok(e) => e,
            Err(_) => {
                parse_errors += 1;
                continue;
            }
        };

        // Skip non-audit entries
        if entry.fields.method.is_empty() {
            continue;
        }

        let ip = audit_log::strip_port(&entry.fields.client_ip).to_string();
        let ts = parse_timestamp(&entry.timestamp);

        // Append this request's attributes to the IP's columnar state.
        let state = ip_states.entry(ip).or_insert_with(LogIpState::new);
        state.timestamps.push(ts);
        state.methods.push(method_to_u8(&entry.fields.method));
        state.path_hashes.push(fx_hash(&entry.fields.path));
        state.host_hashes.push(fx_hash(&entry.fields.host));
        state
            .user_agent_hashes
            .push(fx_hash(&entry.fields.user_agent));
        state.statuses.push(entry.fields.status);
        // Clamp 64-bit log values into the u32 columns.
        state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
        state
            .content_lengths
            .push(entry.fields.content_length.min(u32::MAX as u64) as u32);
        state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
        // "-" is the logger's placeholder for an absent header.
        state.has_referer.push(
            entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
        );
        state.has_accept_language.push(
            entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
        );
        state.suspicious_paths.push(
            crate::ddos::features::is_suspicious_path(&entry.fields.path),
        );
    }

    eprintln!(
        "Parsed {} lines ({} errors), {} unique IPs",
        total_lines,
        parse_errors,
        ip_states.len()
    );

    // Extract feature vectors per IP (using sliding windows)
    let window_secs = args.window_secs as f64;
    let mut ip_features: FxHashMap<String, Vec<FeatureVector>> = FxHashMap::default();

    for (ip, state) in &ip_states {
        let n = state.timestamps.len();
        // Too few events → too little signal; exclude the IP entirely.
        if n < args.min_events {
            continue;
        }
        // Extract one feature vector per window
        let mut features = Vec::new();
        let mut start = 0;
        for end in 1..n {
            let span = state.timestamps[end] - state.timestamps[start];
            // Close the window once it spans window_secs, or at end of data.
            if span >= window_secs || end == n - 1 {
                let fv = state.extract_features_for_window(start, end + 1, window_secs);
                features.push(fv);
                start = end + 1;
            }
        }
        if !features.is_empty() {
            ip_features.insert(ip.clone(), features);
        }
    }

    // Label IPs
    let mut ip_labels: FxHashMap<String, TrafficLabel> = FxHashMap::default();

    if let (Some(attack_file), Some(normal_file)) = (&args.attack_ips, &args.normal_ips) {
        // IP list mode: explicit ground truth, one IP per line. IPs in
        // neither list stay unlabeled and are excluded from training.
        let attack_ips: FxHashSet<String> = std::fs::read_to_string(attack_file)
            .context("reading attack IPs file")?
            .lines()
            .map(|l| l.trim().to_string())
            .filter(|l| !l.is_empty())
            .collect();
        let normal_ips: FxHashSet<String> = std::fs::read_to_string(normal_file)
            .context("reading normal IPs file")?
            .lines()
            .map(|l| l.trim().to_string())
            .filter(|l| !l.is_empty())
            .collect();

        for ip in ip_features.keys() {
            if attack_ips.contains(ip) {
                ip_labels.insert(ip.clone(), TrafficLabel::Attack);
            } else if normal_ips.contains(ip) {
                ip_labels.insert(ip.clone(), TrafficLabel::Normal);
            }
        }
    } else if let Some(heuristics_file) = &args.heuristics {
        // Heuristic auto-labeling
        let heuristics_str = std::fs::read_to_string(heuristics_file)
            .context("reading heuristics file")?;
        let thresholds: HeuristicThresholds =
            toml::from_str(&heuristics_str).context("parsing heuristics TOML")?;

        for (ip, features) in &ip_features {
            // Label from the per-IP *average* feature vector.
            let avg = average_features(features);
            // Indices must match the FeatureVector layout in features.rs.
            let is_attack = avg[0] > thresholds.request_rate // request_rate
                || avg[7] > thresholds.path_repetition // path_repetition
                || avg[3] > thresholds.error_rate // error_rate
                || avg[13] > thresholds.suspicious_path_ratio // suspicious_path_ratio
                || (avg[10] < thresholds.no_cookies_threshold // no cookies + high unique paths
                    && avg[1] > thresholds.no_cookies_path_count);
            ip_labels.insert(
                ip.clone(),
                if is_attack {
                    TrafficLabel::Attack
                } else {
                    TrafficLabel::Normal
                },
            );
        }
    } else {
        bail!("Must provide either --attack-ips + --normal-ips, or --heuristics for labeling");
    }

    // Build training dataset
    let mut all_points: Vec<FeatureVector> = Vec::new();
    let mut all_labels: Vec<TrafficLabel> = Vec::new();

    for (ip, features) in &ip_features {
        if let Some(&label) = ip_labels.get(ip) {
            // Every window of a labeled IP becomes one training point.
            for fv in features {
                all_points.push(*fv);
                all_labels.push(label);
            }
        }
    }

    if all_points.is_empty() {
        bail!("No labeled data points found. Check your IP lists or heuristic thresholds.");
    }

    let attack_count = all_labels
        .iter()
        .filter(|&&l| l == TrafficLabel::Attack)
        .count();
    let normal_count = all_labels.len() - attack_count;
    eprintln!(
        "Training with {} points ({} attack, {} normal)",
        all_points.len(),
        attack_count,
        normal_count
    );

    // Normalize
    let norm_params = NormParams::from_data(&all_points);
    let normalized: Vec<FeatureVector> = all_points
        .iter()
        .map(|v| norm_params.normalize(v))
        .collect();

    // Serialize
    let model = SerializedModel {
        points: normalized,
        labels: all_labels,
        norm_params,
        k: args.k,
        threshold: args.threshold,
    };

    let encoded = bincode::serialize(&model).context("serializing model")?;
    std::fs::write(&args.output, &encoded)
        .with_context(|| format!("writing model to {}", args.output))?;

    eprintln!(
        "Model saved to {} ({} bytes, {} points)",
        args.output,
        encoded.len(),
        model.points.len()
    );

    Ok(())
}
|
||||
|
||||
fn average_features(features: &[FeatureVector]) -> FeatureVector {
|
||||
let n = features.len() as f64;
|
||||
let mut avg = [0.0; NUM_FEATURES];
|
||||
for fv in features {
|
||||
for i in 0..NUM_FEATURES {
|
||||
avg[i] += fv[i];
|
||||
}
|
||||
}
|
||||
for v in &mut avg {
|
||||
*v /= n;
|
||||
}
|
||||
avg
|
||||
}
|
||||
Reference in New Issue
Block a user