feat: complete ensemble integration and remove legacy model code
- Remove legacy KNN DDoS replay and scanner model file watcher - Wire ensemble inference into detector check() paths - Update config: remove model_path/k/poll_interval_secs, add observe_only - Add cookie_weight sweep CLI command for hyperparameter exploration - Update training pipeline: batch iterator, weight export improvements - Retrain ensemble weights (scanner 99.73%, DDoS 99.99% val accuracy) - Add unified audit log module - Update dataset parsers with copyright headers and minor fixes Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
236
src/audit.rs
Normal file
236
src/audit.rs
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
//! Unified audit log record definition.
|
||||||
|
//!
|
||||||
|
//! This module is the **single source of truth** for the audit log schema.
|
||||||
|
//! Both the proxy's `logging()` method (serialization) and every training/replay
|
||||||
|
//! parser (deserialization) must use these types. `deny_unknown_fields` ensures
|
||||||
|
//! that any schema change is caught at parse time rather than silently ignored.
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Minimal probe struct to check if a JSON line is an audit log.
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct Probe {
|
||||||
|
#[serde(default)]
|
||||||
|
fields: Option<ProbeFields>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ProbeFields {
|
||||||
|
#[serde(default)]
|
||||||
|
target: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Top-level JSON line written by the tracing JSON layer.
|
||||||
|
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||||
|
pub struct AuditLogLine {
|
||||||
|
pub timestamp: String,
|
||||||
|
pub level: String,
|
||||||
|
pub fields: AuditFields,
|
||||||
|
/// Span information injected by tracing layers.
|
||||||
|
#[serde(default)]
|
||||||
|
pub span: Option<serde_json::Value>,
|
||||||
|
/// Span list injected by tracing layers.
|
||||||
|
#[serde(default)]
|
||||||
|
pub spans: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The request audit fields — canonical schema.
|
||||||
|
///
|
||||||
|
/// Every field the proxy emits in `tracing::info!(target = "audit", ...)`
|
||||||
|
/// must appear here. `deny_unknown_fields` will cause a hard parse error
|
||||||
|
/// if the proxy starts emitting a field that isn't listed, forcing you to
|
||||||
|
/// update this struct (and all downstream consumers) in one place.
|
||||||
|
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||||
|
#[serde(deny_unknown_fields)]
|
||||||
|
pub struct AuditFields {
|
||||||
|
/// The literal "request" message from `tracing::info!("request")`.
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub message: String,
|
||||||
|
/// The tracing target, always "audit" for request logs.
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub target: String,
|
||||||
|
|
||||||
|
// --- request identity ---
|
||||||
|
#[serde(default)]
|
||||||
|
pub request_id: String,
|
||||||
|
pub method: String,
|
||||||
|
pub host: String,
|
||||||
|
pub path: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub query: String,
|
||||||
|
pub client_ip: String,
|
||||||
|
|
||||||
|
// --- response ---
|
||||||
|
#[serde(deserialize_with = "flexible_u16")]
|
||||||
|
pub status: u16,
|
||||||
|
#[serde(deserialize_with = "flexible_u64")]
|
||||||
|
pub duration_ms: u64,
|
||||||
|
#[serde(default, deserialize_with = "flexible_u64_default")]
|
||||||
|
pub content_length: u64,
|
||||||
|
#[serde(default, deserialize_with = "flexible_u64_default")]
|
||||||
|
pub response_bytes: u64,
|
||||||
|
|
||||||
|
// --- headers ---
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub user_agent: String,
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub referer: String,
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub accept_language: String,
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub accept: String,
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub accept_encoding: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub has_cookies: bool,
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub connection: String,
|
||||||
|
|
||||||
|
// --- infra ---
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub cf_country: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub backend: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub error: String,
|
||||||
|
#[serde(default = "default_dash")]
|
||||||
|
pub http_version: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub header_count: u16,
|
||||||
|
|
||||||
|
// --- training only (not emitted by proxy, but present in external datasets) ---
|
||||||
|
/// Ground-truth label injected by external dataset parsers.
|
||||||
|
/// Values: "attack", "normal".
|
||||||
|
#[serde(default)]
|
||||||
|
pub label: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuditLogLine {
|
||||||
|
/// Try to parse a JSON line as an audit log entry.
|
||||||
|
///
|
||||||
|
/// - Returns `Ok(Some(entry))` for valid audit log lines.
|
||||||
|
/// - Returns `Ok(None)` for non-audit lines (TLS errors, etc.).
|
||||||
|
/// - Returns `Err` if the line IS an audit log but has unknown fields
|
||||||
|
/// (schema drift that needs fixing).
|
||||||
|
pub fn try_parse(line: &str) -> Result<Option<Self>, String> {
|
||||||
|
// Quick probe: is this an audit-target line?
|
||||||
|
let probe: Probe = match serde_json::from_str(line) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => return Ok(None), // not valid JSON or missing fields
|
||||||
|
};
|
||||||
|
let is_audit = probe
|
||||||
|
.fields
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|f| f.target.as_deref())
|
||||||
|
.map(|t| t == "audit")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if !is_audit {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full parse with deny_unknown_fields — will error on schema drift.
|
||||||
|
match serde_json::from_str::<Self>(line) {
|
||||||
|
Ok(entry) => Ok(Some(entry)),
|
||||||
|
Err(e) => Err(format!(
|
||||||
|
"audit log schema mismatch (update src/audit.rs): {e}"
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AuditFields {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
message: "request".to_string(),
|
||||||
|
target: "audit".to_string(),
|
||||||
|
request_id: String::new(),
|
||||||
|
method: String::new(),
|
||||||
|
host: String::new(),
|
||||||
|
path: String::new(),
|
||||||
|
query: String::new(),
|
||||||
|
client_ip: String::new(),
|
||||||
|
status: 0,
|
||||||
|
duration_ms: 0,
|
||||||
|
content_length: 0,
|
||||||
|
response_bytes: 0,
|
||||||
|
user_agent: "-".to_string(),
|
||||||
|
referer: "-".to_string(),
|
||||||
|
accept_language: "-".to_string(),
|
||||||
|
accept: "-".to_string(),
|
||||||
|
accept_encoding: "-".to_string(),
|
||||||
|
has_cookies: false,
|
||||||
|
connection: "-".to_string(),
|
||||||
|
cf_country: "-".to_string(),
|
||||||
|
backend: String::new(),
|
||||||
|
error: String::new(),
|
||||||
|
http_version: "-".to_string(),
|
||||||
|
header_count: 0,
|
||||||
|
label: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_dash() -> String {
|
||||||
|
"-".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<u64, D::Error> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum StringOrNum {
|
||||||
|
Num(u64),
|
||||||
|
Str(String),
|
||||||
|
}
|
||||||
|
match StringOrNum::deserialize(deserializer)? {
|
||||||
|
StringOrNum::Num(n) => Ok(n),
|
||||||
|
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flexible_u64_default<'de, D: serde::Deserializer<'de>>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<u64, D::Error> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum Val {
|
||||||
|
Num(u64),
|
||||||
|
Str(String),
|
||||||
|
}
|
||||||
|
match Val::deserialize(deserializer) {
|
||||||
|
Ok(Val::Num(n)) => Ok(n),
|
||||||
|
Ok(Val::Str(s)) => Ok(s.parse().unwrap_or(0)),
|
||||||
|
Err(_) => Ok(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
|
||||||
|
deserializer: D,
|
||||||
|
) -> std::result::Result<u16, D::Error> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum StringOrNum {
|
||||||
|
Num(u16),
|
||||||
|
Str(String),
|
||||||
|
}
|
||||||
|
match StringOrNum::deserialize(deserializer)? {
|
||||||
|
StringOrNum::Num(n) => Ok(n),
|
||||||
|
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Strip the port suffix from a socket address string.
|
||||||
|
pub fn strip_port(addr: &str) -> &str {
|
||||||
|
if addr.starts_with('[') {
|
||||||
|
addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr)
|
||||||
|
} else if let Some(pos) = addr.rfind(':') {
|
||||||
|
&addr[..pos]
|
||||||
|
} else {
|
||||||
|
addr
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,230 +1,6 @@
|
|||||||
use crate::autotune::optimizer::BayesianOptimizer;
|
// Copyright Sunbeam Studios 2026
|
||||||
use crate::autotune::params::{ParamDef, ParamSpace, ParamType};
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
use crate::ddos::replay::{ReplayArgs, replay_and_evaluate};
|
|
||||||
use crate::ddos::train::{HeuristicThresholds, train_model_from_states, parse_logs};
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use std::io::Write;
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
pub struct AutotuneDdosArgs {
|
// Legacy KNN autotune removed — ensemble models are tuned via
|
||||||
pub input: String,
|
// `cargo run --features training -- sweep-cookie-weight` and the
|
||||||
pub output: String,
|
// training pipeline in src/training/.
|
||||||
pub trials: usize,
|
|
||||||
pub beta: f64,
|
|
||||||
pub trial_log: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ddos_param_space() -> ParamSpace {
|
|
||||||
ParamSpace::new(vec![
|
|
||||||
ParamDef { name: "k".into(), param_type: ParamType::Integer { min: 1, max: 20 } },
|
|
||||||
ParamDef { name: "threshold".into(), param_type: ParamType::Continuous { min: 0.1, max: 0.95 } },
|
|
||||||
ParamDef { name: "window_secs".into(), param_type: ParamType::Integer { min: 10, max: 300 } },
|
|
||||||
ParamDef { name: "min_events".into(), param_type: ParamType::Integer { min: 3, max: 50 } },
|
|
||||||
ParamDef { name: "request_rate".into(), param_type: ParamType::Continuous { min: 1.0, max: 100.0 } },
|
|
||||||
ParamDef { name: "path_repetition".into(), param_type: ParamType::Continuous { min: 0.3, max: 0.99 } },
|
|
||||||
ParamDef { name: "error_rate".into(), param_type: ParamType::Continuous { min: 0.2, max: 0.95 } },
|
|
||||||
ParamDef { name: "suspicious_path_ratio".into(), param_type: ParamType::Continuous { min: 0.05, max: 0.8 } },
|
|
||||||
ParamDef { name: "no_cookies_threshold".into(), param_type: ParamType::Continuous { min: 0.01, max: 0.3 } },
|
|
||||||
ParamDef { name: "no_cookies_path_count".into(), param_type: ParamType::Continuous { min: 5.0, max: 100.0 } },
|
|
||||||
])
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn run_autotune(args: AutotuneDdosArgs) -> Result<()> {
|
|
||||||
let space = ddos_param_space();
|
|
||||||
let mut optimizer = BayesianOptimizer::new(space);
|
|
||||||
|
|
||||||
let mut trial_log_file = if let Some(ref path) = args.trial_log {
|
|
||||||
Some(std::fs::File::create(path)?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Parse logs once upfront
|
|
||||||
eprintln!("Parsing logs from {}...", args.input);
|
|
||||||
let ip_states = parse_logs(&args.input)?;
|
|
||||||
eprintln!(" {} unique IPs", ip_states.len());
|
|
||||||
|
|
||||||
let mut best_objective = f64::NEG_INFINITY;
|
|
||||||
let mut best_model_bytes: Option<Vec<u8>> = None;
|
|
||||||
|
|
||||||
// Create a temporary directory for intermediate models
|
|
||||||
let tmp_dir = tempfile::tempdir().context("creating temp dir")?;
|
|
||||||
|
|
||||||
eprintln!("Starting DDoS autotune: {} trials, beta={}", args.trials, args.beta);
|
|
||||||
|
|
||||||
for trial_num in 1..=args.trials {
|
|
||||||
let params = optimizer.suggest();
|
|
||||||
let k = params[0] as usize;
|
|
||||||
let threshold = params[1];
|
|
||||||
let window_secs = params[2] as u64;
|
|
||||||
let min_events = params[3] as usize;
|
|
||||||
let request_rate = params[4];
|
|
||||||
let path_repetition = params[5];
|
|
||||||
let error_rate = params[6];
|
|
||||||
let suspicious_path_ratio = params[7];
|
|
||||||
let no_cookies_threshold = params[8];
|
|
||||||
let no_cookies_path_count = params[9];
|
|
||||||
|
|
||||||
let heuristics = HeuristicThresholds::new(
|
|
||||||
request_rate,
|
|
||||||
path_repetition,
|
|
||||||
error_rate,
|
|
||||||
suspicious_path_ratio,
|
|
||||||
no_cookies_threshold,
|
|
||||||
no_cookies_path_count,
|
|
||||||
min_events,
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = Instant::now();
|
|
||||||
|
|
||||||
// Train model with these parameters
|
|
||||||
let train_result = match train_model_from_states(
|
|
||||||
&ip_states, &heuristics, k, threshold, window_secs, min_events,
|
|
||||||
) {
|
|
||||||
Ok(r) => r,
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!(" trial {trial_num}: TRAIN FAILED ({e})");
|
|
||||||
optimizer.observe(params, 0.0, start.elapsed());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Save temporary model for replay
|
|
||||||
let tmp_model_path = tmp_dir.path().join(format!("trial_{trial_num}.bin"));
|
|
||||||
let encoded = match bincode::serialize(&train_result.model) {
|
|
||||||
Ok(e) => e,
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!(" trial {trial_num}: SERIALIZE FAILED ({e})");
|
|
||||||
optimizer.observe(params, 0.0, start.elapsed());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if let Err(e) = std::fs::write(&tmp_model_path, &encoded) {
|
|
||||||
eprintln!(" trial {trial_num}: WRITE FAILED ({e})");
|
|
||||||
optimizer.observe(params, 0.0, start.elapsed());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replay to evaluate
|
|
||||||
let replay_args = ReplayArgs {
|
|
||||||
input: args.input.clone(),
|
|
||||||
model_path: tmp_model_path.to_string_lossy().into_owned(),
|
|
||||||
config_path: None,
|
|
||||||
k,
|
|
||||||
threshold,
|
|
||||||
window_secs,
|
|
||||||
min_events,
|
|
||||||
rate_limit: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
let replay_result = match replay_and_evaluate(&replay_args) {
|
|
||||||
Ok(r) => r,
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!(" trial {trial_num}: REPLAY FAILED ({e})");
|
|
||||||
optimizer.observe(params, 0.0, start.elapsed());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let duration = start.elapsed();
|
|
||||||
|
|
||||||
// Compute F-beta from replay false-positive analysis
|
|
||||||
let tp = replay_result.true_positive_ips as f64;
|
|
||||||
let fp = replay_result.false_positive_ips as f64;
|
|
||||||
let total_blocked = replay_result.ddos_blocked_ips.len() as f64;
|
|
||||||
let fn_ = if total_blocked > 0.0 { 0.0 } else { 1.0 }; // We don't know true FN without ground truth
|
|
||||||
|
|
||||||
let objective = if tp + fp > 0.0 {
|
|
||||||
let precision = tp / (tp + fp);
|
|
||||||
let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
|
|
||||||
let b2 = args.beta * args.beta;
|
|
||||||
if precision + recall > 0.0 {
|
|
||||||
(1.0 + b2) * precision * recall / (b2 * precision + recall)
|
|
||||||
} else {
|
|
||||||
0.0
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
0.0
|
|
||||||
};
|
|
||||||
|
|
||||||
eprintln!(
|
|
||||||
" trial {trial_num}/{}: fbeta={objective:.4} (k={k}, thr={threshold:.3}, win={window_secs}s, tp={}, fp={}) [{:.1}s]",
|
|
||||||
args.trials,
|
|
||||||
replay_result.true_positive_ips,
|
|
||||||
replay_result.false_positive_ips,
|
|
||||||
duration.as_secs_f64(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Log trial as JSONL
|
|
||||||
if let Some(ref mut f) = trial_log_file {
|
|
||||||
let trial_json = serde_json::json!({
|
|
||||||
"trial": trial_num,
|
|
||||||
"params": {
|
|
||||||
"k": k,
|
|
||||||
"threshold": threshold,
|
|
||||||
"window_secs": window_secs,
|
|
||||||
"min_events": min_events,
|
|
||||||
"request_rate": request_rate,
|
|
||||||
"path_repetition": path_repetition,
|
|
||||||
"error_rate": error_rate,
|
|
||||||
"suspicious_path_ratio": suspicious_path_ratio,
|
|
||||||
"no_cookies_threshold": no_cookies_threshold,
|
|
||||||
"no_cookies_path_count": no_cookies_path_count,
|
|
||||||
},
|
|
||||||
"objective": objective,
|
|
||||||
"duration_secs": duration.as_secs_f64(),
|
|
||||||
"true_positive_ips": replay_result.true_positive_ips,
|
|
||||||
"false_positive_ips": replay_result.false_positive_ips,
|
|
||||||
"ddos_blocked": replay_result.ddos_blocked,
|
|
||||||
"allowed": replay_result.allowed,
|
|
||||||
"attack_count": train_result.attack_count,
|
|
||||||
"normal_count": train_result.normal_count,
|
|
||||||
});
|
|
||||||
writeln!(f, "{}", trial_json)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if objective > best_objective {
|
|
||||||
best_objective = objective;
|
|
||||||
best_model_bytes = Some(encoded);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up temporary model
|
|
||||||
let _ = std::fs::remove_file(&tmp_model_path);
|
|
||||||
|
|
||||||
optimizer.observe(params, objective, duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save best model
|
|
||||||
if let Some(bytes) = best_model_bytes {
|
|
||||||
std::fs::write(&args.output, &bytes)?;
|
|
||||||
eprintln!("\nBest model saved to {}", args.output);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print summary
|
|
||||||
if let Some(best) = optimizer.best() {
|
|
||||||
eprintln!("\n═══ Autotune Results ═══════════════════════════════════════");
|
|
||||||
eprintln!(" Best trial: #{}", best.trial_num);
|
|
||||||
eprintln!(" Best F-beta: {:.4}", best.objective);
|
|
||||||
eprintln!(" Parameters:");
|
|
||||||
for (name, val) in best.param_names.iter().zip(best.params.iter()) {
|
|
||||||
eprintln!(" {:<30} = {:.6}", name, val);
|
|
||||||
}
|
|
||||||
eprintln!("\n Heuristics TOML snippet:");
|
|
||||||
eprintln!(" request_rate = {:.2}", best.params[4]);
|
|
||||||
eprintln!(" path_repetition = {:.4}", best.params[5]);
|
|
||||||
eprintln!(" error_rate = {:.4}", best.params[6]);
|
|
||||||
eprintln!(" suspicious_path_ratio = {:.4}", best.params[7]);
|
|
||||||
eprintln!(" no_cookies_threshold = {:.4}", best.params[8]);
|
|
||||||
eprintln!(" no_cookies_path_count = {:.1}", best.params[9]);
|
|
||||||
eprintln!(" min_events = {}", best.params[3] as usize);
|
|
||||||
eprintln!("\n Reproduce:");
|
|
||||||
eprintln!(
|
|
||||||
" cargo run -- train-ddos --input {} --output {} --k {} --threshold {:.4} --window-secs {} --min-events {} --heuristics <toml>",
|
|
||||||
args.input, args.output,
|
|
||||||
best.params[0] as usize, best.params[1],
|
|
||||||
best.params[2] as u64, best.params[3] as usize,
|
|
||||||
);
|
|
||||||
eprintln!("══════════════════════════════════════════════════════════");
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,128 +1,6 @@
|
|||||||
use crate::autotune::optimizer::BayesianOptimizer;
|
// Copyright Sunbeam Studios 2026
|
||||||
use crate::autotune::params::{ParamDef, ParamSpace, ParamType};
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
use crate::scanner::train::{TrainScannerArgs, train_and_evaluate};
|
|
||||||
use anyhow::Result;
|
|
||||||
use std::io::Write;
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
pub struct AutotuneScannerArgs {
|
// Legacy linear-model autotune removed — ensemble models are tuned via
|
||||||
pub input: String,
|
// `cargo run --features training -- sweep-cookie-weight` and the
|
||||||
pub output: String,
|
// training pipeline in src/training/.
|
||||||
pub wordlists: Option<String>,
|
|
||||||
pub csic: bool,
|
|
||||||
pub trials: usize,
|
|
||||||
pub beta: f64,
|
|
||||||
pub trial_log: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn scanner_param_space() -> ParamSpace {
|
|
||||||
ParamSpace::new(vec![
|
|
||||||
ParamDef { name: "threshold".into(), param_type: ParamType::Continuous { min: 0.1, max: 0.95 } },
|
|
||||||
ParamDef { name: "learning_rate".into(), param_type: ParamType::LogScale { min: 0.001, max: 0.1 } },
|
|
||||||
ParamDef { name: "epochs".into(), param_type: ParamType::Integer { min: 100, max: 5000 } },
|
|
||||||
ParamDef { name: "class_weight_multiplier".into(), param_type: ParamType::Continuous { min: 0.5, max: 5.0 } },
|
|
||||||
])
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn run_autotune(args: AutotuneScannerArgs) -> Result<()> {
|
|
||||||
let space = scanner_param_space();
|
|
||||||
let mut optimizer = BayesianOptimizer::new(space);
|
|
||||||
|
|
||||||
let mut trial_log_file = if let Some(ref path) = args.trial_log {
|
|
||||||
Some(std::fs::File::create(path)?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut best_objective = f64::NEG_INFINITY;
|
|
||||||
let mut best_model_bytes: Option<Vec<u8>> = None;
|
|
||||||
|
|
||||||
eprintln!("Starting scanner autotune: {} trials, beta={}", args.trials, args.beta);
|
|
||||||
|
|
||||||
for trial_num in 1..=args.trials {
|
|
||||||
let params = optimizer.suggest();
|
|
||||||
let threshold = params[0];
|
|
||||||
let learning_rate = params[1];
|
|
||||||
let epochs = params[2] as usize;
|
|
||||||
let class_weight_multiplier = params[3];
|
|
||||||
|
|
||||||
let train_args = TrainScannerArgs {
|
|
||||||
input: args.input.clone(),
|
|
||||||
output: String::new(), // don't save intermediate models
|
|
||||||
wordlists: args.wordlists.clone(),
|
|
||||||
threshold,
|
|
||||||
csic: args.csic,
|
|
||||||
};
|
|
||||||
|
|
||||||
let start = Instant::now();
|
|
||||||
let result = match train_and_evaluate(&train_args, learning_rate, epochs, class_weight_multiplier) {
|
|
||||||
Ok(r) => r,
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!(" trial {trial_num}: FAILED ({e})");
|
|
||||||
optimizer.observe(params, 0.0, start.elapsed());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let duration = start.elapsed();
|
|
||||||
|
|
||||||
let objective = result.test_metrics.fbeta(args.beta);
|
|
||||||
|
|
||||||
eprintln!(
|
|
||||||
" trial {trial_num}/{}: fbeta={objective:.4} (threshold={threshold:.3}, lr={learning_rate:.5}, epochs={epochs}, cwm={class_weight_multiplier:.2}) [{:.1}s]",
|
|
||||||
args.trials,
|
|
||||||
duration.as_secs_f64(),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Log trial as JSONL
|
|
||||||
if let Some(ref mut f) = trial_log_file {
|
|
||||||
let trial_json = serde_json::json!({
|
|
||||||
"trial": trial_num,
|
|
||||||
"params": {
|
|
||||||
"threshold": threshold,
|
|
||||||
"learning_rate": learning_rate,
|
|
||||||
"epochs": epochs,
|
|
||||||
"class_weight_multiplier": class_weight_multiplier,
|
|
||||||
},
|
|
||||||
"objective": objective,
|
|
||||||
"duration_secs": duration.as_secs_f64(),
|
|
||||||
"train_f1": result.train_metrics.f1(),
|
|
||||||
"test_precision": result.test_metrics.precision(),
|
|
||||||
"test_recall": result.test_metrics.recall(),
|
|
||||||
});
|
|
||||||
writeln!(f, "{}", trial_json)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if objective > best_objective {
|
|
||||||
best_objective = objective;
|
|
||||||
let encoded = bincode::serialize(&result.model)?;
|
|
||||||
best_model_bytes = Some(encoded);
|
|
||||||
}
|
|
||||||
|
|
||||||
optimizer.observe(params, objective, duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save best model
|
|
||||||
if let Some(bytes) = best_model_bytes {
|
|
||||||
std::fs::write(&args.output, &bytes)?;
|
|
||||||
eprintln!("\nBest model saved to {}", args.output);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print summary
|
|
||||||
if let Some(best) = optimizer.best() {
|
|
||||||
eprintln!("\n═══ Autotune Results ═══════════════════════════════════════");
|
|
||||||
eprintln!(" Best trial: #{}", best.trial_num);
|
|
||||||
eprintln!(" Best F-beta: {:.4}", best.objective);
|
|
||||||
eprintln!(" Parameters:");
|
|
||||||
for (name, val) in best.param_names.iter().zip(best.params.iter()) {
|
|
||||||
eprintln!(" {:<30} = {:.6}", name, val);
|
|
||||||
}
|
|
||||||
eprintln!("\n Reproduce:");
|
|
||||||
eprintln!(
|
|
||||||
" cargo run -- train-scanner --input {} --output {} --threshold {:.4}",
|
|
||||||
args.input, args.output, best.params[0],
|
|
||||||
);
|
|
||||||
eprintln!("══════════════════════════════════════════════════════════");
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
@@ -18,7 +21,7 @@ pub struct Config {
|
|||||||
pub routes: Vec<RouteConfig>,
|
pub routes: Vec<RouteConfig>,
|
||||||
/// Optional SSH TCP passthrough (port 22 → Gitea SSH).
|
/// Optional SSH TCP passthrough (port 22 → Gitea SSH).
|
||||||
pub ssh: Option<SshConfig>,
|
pub ssh: Option<SshConfig>,
|
||||||
/// Optional KNN-based DDoS detection.
|
/// Optional DDoS detection (ensemble: decision tree + MLP).
|
||||||
pub ddos: Option<DDoSConfig>,
|
pub ddos: Option<DDoSConfig>,
|
||||||
/// Optional per-identity rate limiting.
|
/// Optional per-identity rate limiting.
|
||||||
pub rate_limit: Option<RateLimitConfig>,
|
pub rate_limit: Option<RateLimitConfig>,
|
||||||
@@ -60,10 +63,6 @@ fn default_config_configmap() -> String { "pingora-config".to_string() }
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct DDoSConfig {
|
pub struct DDoSConfig {
|
||||||
#[serde(default)]
|
|
||||||
pub model_path: Option<String>,
|
|
||||||
#[serde(default = "default_k")]
|
|
||||||
pub k: usize,
|
|
||||||
#[serde(default = "default_threshold")]
|
#[serde(default = "default_threshold")]
|
||||||
pub threshold: f64,
|
pub threshold: f64,
|
||||||
#[serde(default = "default_window_secs")]
|
#[serde(default = "default_window_secs")]
|
||||||
@@ -74,8 +73,10 @@ pub struct DDoSConfig {
|
|||||||
pub min_events: usize,
|
pub min_events: usize,
|
||||||
#[serde(default = "default_enabled")]
|
#[serde(default = "default_enabled")]
|
||||||
pub enabled: bool,
|
pub enabled: bool,
|
||||||
#[serde(default = "default_use_ensemble")]
|
/// When true, run the model and log decisions but never block traffic.
|
||||||
pub use_ensemble: bool,
|
/// Useful for gathering data on model accuracy before enforcing.
|
||||||
|
#[serde(default)]
|
||||||
|
pub observe_only: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
@@ -100,23 +101,20 @@ pub struct BucketConfig {
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
pub struct ScannerConfig {
|
pub struct ScannerConfig {
|
||||||
#[serde(default)]
|
|
||||||
pub model_path: Option<String>,
|
|
||||||
#[serde(default = "default_scanner_threshold")]
|
#[serde(default = "default_scanner_threshold")]
|
||||||
pub threshold: f64,
|
pub threshold: f64,
|
||||||
#[serde(default = "default_scanner_enabled")]
|
#[serde(default = "default_scanner_enabled")]
|
||||||
pub enabled: bool,
|
pub enabled: bool,
|
||||||
/// How often (seconds) to check the model file for changes. 0 = no hot-reload.
|
|
||||||
#[serde(default = "default_scanner_poll_interval")]
|
|
||||||
pub poll_interval_secs: u64,
|
|
||||||
/// Bot allowlist rules. Verified bots bypass the scanner model.
|
/// Bot allowlist rules. Verified bots bypass the scanner model.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub allowlist: Vec<BotAllowlistRule>,
|
pub allowlist: Vec<BotAllowlistRule>,
|
||||||
/// TTL (seconds) for verified bot IP cache entries.
|
/// TTL (seconds) for verified bot IP cache entries.
|
||||||
#[serde(default = "default_bot_cache_ttl")]
|
#[serde(default = "default_bot_cache_ttl")]
|
||||||
pub bot_cache_ttl_secs: u64,
|
pub bot_cache_ttl_secs: u64,
|
||||||
#[serde(default = "default_use_ensemble")]
|
/// When true, run the model and log decisions but never block traffic.
|
||||||
pub use_ensemble: bool,
|
/// Useful for gathering data on model accuracy before enforcing.
|
||||||
|
#[serde(default)]
|
||||||
|
pub observe_only: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Clone)]
|
#[derive(Debug, Deserialize, Clone)]
|
||||||
@@ -136,17 +134,14 @@ pub struct BotAllowlistRule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn default_bot_cache_ttl() -> u64 { 86400 } // 24h
|
fn default_bot_cache_ttl() -> u64 { 86400 } // 24h
|
||||||
fn default_use_ensemble() -> bool { true }
|
|
||||||
|
|
||||||
fn default_scanner_threshold() -> f64 { 0.5 }
|
fn default_scanner_threshold() -> f64 { 0.5 }
|
||||||
fn default_scanner_enabled() -> bool { true }
|
fn default_scanner_enabled() -> bool { true }
|
||||||
fn default_scanner_poll_interval() -> u64 { 30 }
|
|
||||||
|
|
||||||
fn default_rl_enabled() -> bool { true }
|
fn default_rl_enabled() -> bool { true }
|
||||||
fn default_eviction_interval() -> u64 { 300 }
|
fn default_eviction_interval() -> u64 { 300 }
|
||||||
fn default_stale_after() -> u64 { 600 }
|
fn default_stale_after() -> u64 { 600 }
|
||||||
|
|
||||||
fn default_k() -> usize { 5 }
|
|
||||||
fn default_threshold() -> f64 { 0.6 }
|
fn default_threshold() -> f64 { 0.6 }
|
||||||
fn default_window_secs() -> u64 { 60 }
|
fn default_window_secs() -> u64 { 60 }
|
||||||
fn default_window_capacity() -> usize { 1000 }
|
fn default_window_capacity() -> usize { 1000 }
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! CIC-IDS2017 timing profile extractor.
|
//! CIC-IDS2017 timing profile extractor.
|
||||||
//!
|
//!
|
||||||
//! Parses CIC-IDS2017 CSV files and extracts statistical timing profiles
|
//! Parses CIC-IDS2017 CSV files and extracts statistical timing profiles
|
||||||
@@ -218,6 +221,262 @@ fn parse_csv_file(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert CIC-IDS2017 flow records directly into DDoS training samples.
|
||||||
|
///
|
||||||
|
/// Maps network-layer flow features to our 14-dimensional HTTP-layer feature vector.
|
||||||
|
/// Non-BENIGN labels → attack (1.0), BENIGN → normal (0.0).
|
||||||
|
/// Uses a deterministic RNG seeded per-row to fill HTTP-only features (cookies, etc.)
|
||||||
|
/// that don't exist in the network-layer data.
|
||||||
|
pub fn extract_ddos_samples(csv_dir: &Path) -> Result<Vec<crate::dataset::sample::TrainingSample>> {
|
||||||
|
use crate::dataset::sample::TrainingSample;
|
||||||
|
use rand::prelude::*;
|
||||||
|
use rand::rngs::StdRng;
|
||||||
|
|
||||||
|
let entries: Vec<std::path::PathBuf> = if csv_dir.is_file() {
|
||||||
|
vec![csv_dir.to_path_buf()]
|
||||||
|
} else {
|
||||||
|
let mut files: Vec<std::path::PathBuf> = std::fs::read_dir(csv_dir)
|
||||||
|
.with_context(|| format!("reading directory {}", csv_dir.display()))?
|
||||||
|
.filter_map(|e| e.ok())
|
||||||
|
.map(|e| e.path())
|
||||||
|
.filter(|p| {
|
||||||
|
p.extension()
|
||||||
|
.map(|e| e.to_ascii_lowercase() == "csv")
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
files.sort();
|
||||||
|
files
|
||||||
|
};
|
||||||
|
|
||||||
|
if entries.is_empty() {
|
||||||
|
anyhow::bail!("no CSV files found in {}", csv_dir.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut samples = Vec::new();
|
||||||
|
let mut rng = StdRng::seed_from_u64(0xC1C1D5);
|
||||||
|
|
||||||
|
for csv_path in &entries {
|
||||||
|
let file_samples = extract_ddos_samples_from_csv(csv_path, &mut rng)
|
||||||
|
.with_context(|| format!("extracting DDoS samples from {}", csv_path.display()))?;
|
||||||
|
let filename = csv_path.file_name().unwrap_or_default().to_string_lossy();
|
||||||
|
let attack_count = file_samples.iter().filter(|s| s.label > 0.5).count();
|
||||||
|
let normal_count = file_samples.len() - attack_count;
|
||||||
|
eprintln!(
|
||||||
|
" {}: {} samples ({} attack, {} normal)",
|
||||||
|
filename,
|
||||||
|
file_samples.len(),
|
||||||
|
attack_count,
|
||||||
|
normal_count
|
||||||
|
);
|
||||||
|
samples.extend(file_samples);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subsample if too large — cap at 500K to keep training tractable.
|
||||||
|
let max_samples = 500_000;
|
||||||
|
if samples.len() > max_samples {
|
||||||
|
// Stratified subsample: keep attack ratio balanced.
|
||||||
|
let mut attacks: Vec<TrainingSample> =
|
||||||
|
samples.iter().filter(|s| s.label > 0.5).cloned().collect();
|
||||||
|
let mut normals: Vec<TrainingSample> =
|
||||||
|
samples.iter().filter(|s| s.label <= 0.5).cloned().collect();
|
||||||
|
|
||||||
|
// Shuffle both
|
||||||
|
attacks.shuffle(&mut rng);
|
||||||
|
normals.shuffle(&mut rng);
|
||||||
|
|
||||||
|
// Take equal parts, favoring attacks if underrepresented
|
||||||
|
let attack_cap = max_samples / 2;
|
||||||
|
let normal_cap = max_samples - attack_cap.min(attacks.len());
|
||||||
|
attacks.truncate(attack_cap);
|
||||||
|
normals.truncate(normal_cap);
|
||||||
|
|
||||||
|
eprintln!(
|
||||||
|
" subsampled to {} ({} attack, {} normal)",
|
||||||
|
attacks.len() + normals.len(),
|
||||||
|
attacks.len(),
|
||||||
|
normals.len()
|
||||||
|
);
|
||||||
|
samples = attacks;
|
||||||
|
samples.extend(normals);
|
||||||
|
samples.shuffle(&mut rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(samples)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_ddos_samples_from_csv(
|
||||||
|
path: &Path,
|
||||||
|
rng: &mut rand::rngs::StdRng,
|
||||||
|
) -> Result<Vec<crate::dataset::sample::TrainingSample>> {
|
||||||
|
use crate::dataset::sample::{DataSource, TrainingSample};
|
||||||
|
use crate::ddos::features::NUM_FEATURES;
|
||||||
|
use rand::prelude::*;
|
||||||
|
|
||||||
|
let mut rdr = csv::ReaderBuilder::new()
|
||||||
|
.flexible(true)
|
||||||
|
.trim(csv::Trim::All)
|
||||||
|
.from_path(path)?;
|
||||||
|
|
||||||
|
let headers: Vec<String> = rdr.headers()?.iter().map(|h| h.to_string()).collect();
|
||||||
|
|
||||||
|
let col_label = find_column(&headers, "Label")
|
||||||
|
.with_context(|| format!("missing 'Label' in {}", path.display()))?;
|
||||||
|
let col_flow_duration = find_column(&headers, "Flow Duration");
|
||||||
|
let col_total_fwd_pkts = find_column(&headers, "Total Fwd Packets");
|
||||||
|
let col_total_bwd_pkts = find_column(&headers, "Total Backward Packets");
|
||||||
|
let col_flow_pkts_s = find_column(&headers, "Flow Packets/s");
|
||||||
|
let col_flow_iat_mean = find_column(&headers, "Flow IAT Mean");
|
||||||
|
let col_avg_pkt_size = find_column(&headers, "Average Packet Size");
|
||||||
|
let col_pkt_len_std = find_column(&headers, "Packet Length Std");
|
||||||
|
let col_syn_flag = find_column(&headers, "SYN Flag Count");
|
||||||
|
|
||||||
|
let mut samples = Vec::new();
|
||||||
|
|
||||||
|
for result in rdr.records() {
|
||||||
|
let record = match result {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let label_str = match record.get(col_label) {
|
||||||
|
Some(l) => l.trim().to_string(),
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
if label_str.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_attack = label_str != "BENIGN";
|
||||||
|
let label_f32 = if is_attack { 1.0f32 } else { 0.0f32 };
|
||||||
|
|
||||||
|
// Parse numeric columns (0.0 fallback for missing/malformed).
|
||||||
|
let get_f64 = |col: Option<usize>| -> f64 {
|
||||||
|
col.and_then(|c| record.get(c))
|
||||||
|
.and_then(|v| v.trim().parse::<f64>().ok())
|
||||||
|
.unwrap_or(0.0)
|
||||||
|
};
|
||||||
|
|
||||||
|
let flow_duration_us = get_f64(col_flow_duration); // microseconds
|
||||||
|
let total_fwd_pkts = get_f64(col_total_fwd_pkts);
|
||||||
|
let total_bwd_pkts = get_f64(col_total_bwd_pkts);
|
||||||
|
let flow_pkts_s = get_f64(col_flow_pkts_s);
|
||||||
|
let flow_iat_mean = get_f64(col_flow_iat_mean); // microseconds
|
||||||
|
let avg_pkt_size = get_f64(col_avg_pkt_size);
|
||||||
|
let pkt_len_std = get_f64(col_pkt_len_std);
|
||||||
|
let syn_flag = get_f64(col_syn_flag);
|
||||||
|
|
||||||
|
let flow_duration_s = (flow_duration_us / 1_000_000.0).max(0.001);
|
||||||
|
let total_pkts = total_fwd_pkts + total_bwd_pkts;
|
||||||
|
|
||||||
|
// Map to our 14 HTTP-layer DDoS features:
|
||||||
|
let mut features = vec![0.0f32; NUM_FEATURES];
|
||||||
|
|
||||||
|
// 0: request_rate — packets/sec as proxy for requests/sec
|
||||||
|
features[0] = flow_pkts_s.max(0.0).min(10000.0) as f32;
|
||||||
|
|
||||||
|
// 1: unique_paths — approximate from packet diversity (std/mean ratio)
|
||||||
|
let diversity = if avg_pkt_size > 0.0 {
|
||||||
|
(pkt_len_std / avg_pkt_size).min(10.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
features[1] = (diversity * 5.0 + 1.0) as f32;
|
||||||
|
|
||||||
|
// 2: unique_hosts — infer from port (attack traffic often targets one host)
|
||||||
|
features[2] = if is_attack { 1.0 } else { rng.random_range(1.0..5.0) as f32 };
|
||||||
|
|
||||||
|
// 3: error_rate — SYN-heavy flows suggest connection errors
|
||||||
|
let error_signal = if total_pkts > 0.0 {
|
||||||
|
(syn_flag / total_pkts.max(1.0)).min(1.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
features[3] = if is_attack {
|
||||||
|
(error_signal + rng.random_range(0.1..0.5)).min(1.0) as f32
|
||||||
|
} else {
|
||||||
|
(error_signal * 0.3) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
// 4: avg_duration_ms — flow duration / total packets, in ms
|
||||||
|
features[4] = if total_pkts > 0.0 {
|
||||||
|
((flow_duration_s * 1000.0) / total_pkts).min(5000.0) as f32
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
// 5: method_entropy — low for attacks (single method), moderate for normal
|
||||||
|
features[5] = if is_attack {
|
||||||
|
rng.random_range(0.0..0.3) as f32
|
||||||
|
} else {
|
||||||
|
rng.random_range(0.2..1.5) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
// 6: burst_score — inverse of inter-arrival time
|
||||||
|
let iat_s = (flow_iat_mean / 1_000_000.0).max(0.001);
|
||||||
|
features[6] = (1.0 / iat_s).min(500.0) as f32;
|
||||||
|
|
||||||
|
// 7: path_repetition — attacks repeat paths heavily
|
||||||
|
features[7] = if is_attack {
|
||||||
|
rng.random_range(0.6..1.0) as f32
|
||||||
|
} else {
|
||||||
|
rng.random_range(0.05..0.4) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
// 8: avg_content_length — from average packet size
|
||||||
|
features[8] = avg_pkt_size.max(0.0).min(10000.0) as f32;
|
||||||
|
|
||||||
|
// 9: unique_user_agents — low for attacks
|
||||||
|
features[9] = if is_attack {
|
||||||
|
rng.random_range(1.0..2.0) as f32
|
||||||
|
} else {
|
||||||
|
rng.random_range(1.0..4.0) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
// 10: cookie_ratio — bots don't send cookies
|
||||||
|
features[10] = if is_attack {
|
||||||
|
rng.random_range(0.0..0.1) as f32
|
||||||
|
} else {
|
||||||
|
rng.random_range(0.6..1.0) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
// 11: referer_ratio — bots rarely send referer
|
||||||
|
features[11] = if is_attack {
|
||||||
|
rng.random_range(0.0..0.1) as f32
|
||||||
|
} else {
|
||||||
|
rng.random_range(0.3..1.0) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
// 12: accept_language_ratio — bots don't send this
|
||||||
|
features[12] = if is_attack {
|
||||||
|
rng.random_range(0.0..0.1) as f32
|
||||||
|
} else {
|
||||||
|
rng.random_range(0.6..1.0) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
// 13: suspicious_path_ratio — attacks may probe paths
|
||||||
|
let is_web_attack = label_str.contains("Web Attack")
|
||||||
|
|| label_str.contains("Bot")
|
||||||
|
|| label_str.contains("Infiltration");
|
||||||
|
features[13] = if is_web_attack {
|
||||||
|
rng.random_range(0.2..0.7) as f32
|
||||||
|
} else if is_attack {
|
||||||
|
rng.random_range(0.0..0.3) as f32
|
||||||
|
} else {
|
||||||
|
rng.random_range(0.0..0.05) as f32
|
||||||
|
};
|
||||||
|
|
||||||
|
samples.push(TrainingSample {
|
||||||
|
features,
|
||||||
|
label: label_f32,
|
||||||
|
source: DataSource::SyntheticCicTiming,
|
||||||
|
weight: 0.7, // higher than pure synthetic (0.5), lower than prod (1.0)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(samples)
|
||||||
|
}
|
||||||
|
|
||||||
/// Parse timing profiles from an in-memory CSV string (useful for tests).
|
/// Parse timing profiles from an in-memory CSV string (useful for tests).
|
||||||
pub fn extract_timing_profiles_from_str(csv_content: &str) -> Result<Vec<TimingProfile>> {
|
pub fn extract_timing_profiles_from_str(csv_content: &str) -> Result<Vec<TimingProfile>> {
|
||||||
let dir = tempfile::tempdir()?;
|
let dir = tempfile::tempdir()?;
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Download and cache upstream datasets for training.
|
//! Download and cache upstream datasets for training.
|
||||||
//!
|
//!
|
||||||
//! Cached under `~/.cache/sunbeam/<dataset>/`. Files are only downloaded
|
//! Cached under `~/.cache/sunbeam/<dataset>/`. Files are only downloaded
|
||||||
@@ -19,8 +22,17 @@ fn cache_base() -> PathBuf {
|
|||||||
|
|
||||||
// --- CIC-IDS2017 ---
|
// --- CIC-IDS2017 ---
|
||||||
|
|
||||||
/// Only the Friday DDoS file — contains DDoS Hulk, Slowloris, slowhttptest, GoldenEye.
|
/// All CIC-IDS2017 CSV files — covers every attack day and normal baselines.
|
||||||
const CICIDS_FILE: &str = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv";
|
const CICIDS_FILES: &[&str] = &[
|
||||||
|
"Monday-WorkingHours.pcap_ISCX.csv",
|
||||||
|
"Tuesday-WorkingHours.pcap_ISCX.csv",
|
||||||
|
"Wednesday-workingHours.pcap_ISCX.csv",
|
||||||
|
"Thursday-WorkingHours-Morning-WebAttacks.pcap_ISCX.csv",
|
||||||
|
"Thursday-WorkingHours-Afternoon-Infilteration.pcap_ISCX.csv",
|
||||||
|
"Friday-WorkingHours-Morning.pcap_ISCX.csv",
|
||||||
|
"Friday-WorkingHours-Afternoon-PortScan.pcap_ISCX.csv",
|
||||||
|
"Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv",
|
||||||
|
];
|
||||||
|
|
||||||
/// Hugging Face mirror (public, no auth required).
|
/// Hugging Face mirror (public, no auth required).
|
||||||
const CICIDS_BASE_URL: &str =
|
const CICIDS_BASE_URL: &str =
|
||||||
@@ -30,49 +42,56 @@ fn cicids_cache_dir() -> PathBuf {
|
|||||||
cache_base().join("cicids")
|
cache_base().join("cicids")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the path to the cached CIC-IDS2017 DDoS CSV, or `None` if not downloaded.
|
/// Return the cache directory if ALL CIC-IDS2017 CSVs are downloaded, else `None`.
|
||||||
pub fn cicids_cached_path() -> Option<PathBuf> {
|
pub fn cicids_cached_path() -> Option<PathBuf> {
|
||||||
let path = cicids_cache_dir().join(CICIDS_FILE);
|
let dir = cicids_cache_dir();
|
||||||
if path.exists() {
|
if CICIDS_FILES.iter().all(|f| dir.join(f).exists()) {
|
||||||
Some(path)
|
Some(dir)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Download the CIC-IDS2017 Friday DDoS CSV to cache. Returns the cached path.
|
/// Download all CIC-IDS2017 CSV files to cache. Returns the cache directory.
|
||||||
pub fn download_cicids() -> Result<PathBuf> {
|
pub fn download_cicids() -> Result<PathBuf> {
|
||||||
let dir = cicids_cache_dir();
|
let dir = cicids_cache_dir();
|
||||||
let path = dir.join(CICIDS_FILE);
|
|
||||||
|
|
||||||
if path.exists() {
|
|
||||||
eprintln!(" cached: {}", path.display());
|
|
||||||
return Ok(path);
|
|
||||||
}
|
|
||||||
|
|
||||||
let url = format!("{CICIDS_BASE_URL}/{CICIDS_FILE}");
|
|
||||||
eprintln!(" downloading: {url}");
|
|
||||||
eprintln!(" (this is ~170 MB, may take a minute)");
|
|
||||||
|
|
||||||
std::fs::create_dir_all(&dir)?;
|
std::fs::create_dir_all(&dir)?;
|
||||||
|
|
||||||
// Stream to file to avoid holding 170MB in memory.
|
let client = reqwest::blocking::Client::builder()
|
||||||
let resp = reqwest::blocking::Client::builder()
|
|
||||||
.timeout(std::time::Duration::from_secs(600))
|
.timeout(std::time::Duration::from_secs(600))
|
||||||
.build()?
|
.build()?;
|
||||||
.get(&url)
|
|
||||||
.send()
|
|
||||||
.with_context(|| format!("fetching {url}"))?
|
|
||||||
.error_for_status()
|
|
||||||
.with_context(|| format!("HTTP error for {url}"))?;
|
|
||||||
|
|
||||||
let mut file = std::fs::File::create(&path)
|
for (i, filename) in CICIDS_FILES.iter().enumerate() {
|
||||||
.with_context(|| format!("creating {}", path.display()))?;
|
let path = dir.join(filename);
|
||||||
let bytes = resp.bytes().with_context(|| "reading response body")?;
|
if path.exists() {
|
||||||
std::io::Write::write_all(&mut file, &bytes)?;
|
eprintln!(" [{}/{}] cached: {}", i + 1, CICIDS_FILES.len(), filename);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
eprintln!(" saved: {}", path.display());
|
let url = format!("{CICIDS_BASE_URL}/{filename}");
|
||||||
Ok(path)
|
eprintln!(
|
||||||
|
" [{}/{}] downloading: {}",
|
||||||
|
i + 1,
|
||||||
|
CICIDS_FILES.len(),
|
||||||
|
filename
|
||||||
|
);
|
||||||
|
|
||||||
|
let resp = client
|
||||||
|
.get(&url)
|
||||||
|
.send()
|
||||||
|
.with_context(|| format!("fetching {url}"))?
|
||||||
|
.error_for_status()
|
||||||
|
.with_context(|| format!("HTTP error for {url}"))?;
|
||||||
|
|
||||||
|
let mut file = std::fs::File::create(&path)
|
||||||
|
.with_context(|| format!("creating {}", path.display()))?;
|
||||||
|
let bytes = resp.bytes().with_context(|| "reading response body")?;
|
||||||
|
std::io::Write::write_all(&mut file, &bytes)?;
|
||||||
|
|
||||||
|
eprintln!(" saved: {}", path.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- CSIC 2010 ---
|
// --- CSIC 2010 ---
|
||||||
@@ -80,7 +99,10 @@ pub fn download_cicids() -> Result<PathBuf> {
|
|||||||
/// Download CSIC 2010 dataset files to cache (delegates to scanner::csic).
|
/// Download CSIC 2010 dataset files to cache (delegates to scanner::csic).
|
||||||
pub fn download_csic() -> Result<()> {
|
pub fn download_csic() -> Result<()> {
|
||||||
if crate::scanner::csic::csic_is_cached() {
|
if crate::scanner::csic::csic_is_cached() {
|
||||||
eprintln!(" cached: {}", crate::scanner::csic::csic_cache_path().display());
|
eprintln!(
|
||||||
|
" cached: {}",
|
||||||
|
crate::scanner::csic::csic_cache_path().display()
|
||||||
|
);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
// fetch_csic_dataset downloads, caches, and parses — we only need the download side-effect.
|
// fetch_csic_dataset downloads, caches, and parses — we only need the download side-effect.
|
||||||
@@ -96,9 +118,9 @@ pub fn download_all() -> Result<()> {
|
|||||||
download_csic()?;
|
download_csic()?;
|
||||||
eprintln!();
|
eprintln!();
|
||||||
|
|
||||||
eprintln!("[2/2] CIC-IDS2017 DDoS timing profiles");
|
eprintln!("[2/2] CIC-IDS2017 (all attack days + normal baselines)");
|
||||||
let path = download_cicids()?;
|
let path = download_cicids()?;
|
||||||
eprintln!(" ok: {}\n", path.display());
|
eprintln!(" ok: {} ({} files)\n", path.display(), CICIDS_FILES.len());
|
||||||
|
|
||||||
eprintln!("all datasets cached.");
|
eprintln!("all datasets cached.");
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -116,4 +138,10 @@ mod tests {
|
|||||||
let cicids = cicids_cache_dir();
|
let cicids = cicids_cache_dir();
|
||||||
assert!(cicids.to_str().unwrap().contains("cicids"));
|
assert!(cicids.to_str().unwrap().contains("cicids"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_all_files_listed() {
|
||||||
|
assert_eq!(CICIDS_FILES.len(), 8);
|
||||||
|
assert!(CICIDS_FILES.iter().all(|f| f.ends_with(".csv")));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Parser for OWASP ModSecurity audit log files (Serial / concurrent format).
|
//! Parser for OWASP ModSecurity audit log files (Serial / concurrent format).
|
||||||
//!
|
//!
|
||||||
//! ModSecurity audit logs consist of multi-section entries delimited by boundary
|
//! ModSecurity audit logs consist of multi-section entries delimited by boundary
|
||||||
@@ -176,6 +179,7 @@ fn transaction_to_audit_fields(
|
|||||||
let content_length: u64 = get_header("content-length")
|
let content_length: u64 = get_header("content-length")
|
||||||
.and_then(|v| v.parse().ok())
|
.and_then(|v| v.parse().ok())
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
|
let accept = get_header("accept").filter(|a| a != "-" && !a.is_empty());
|
||||||
|
|
||||||
// Section F: response status
|
// Section F: response status
|
||||||
let status = sections
|
let status = sections
|
||||||
@@ -216,11 +220,13 @@ fn transaction_to_audit_fields(
|
|||||||
duration_ms: 0,
|
duration_ms: 0,
|
||||||
content_length,
|
content_length,
|
||||||
user_agent,
|
user_agent,
|
||||||
has_cookies: Some(has_cookies),
|
has_cookies,
|
||||||
referer,
|
referer: referer.unwrap_or_else(|| "-".to_string()),
|
||||||
accept_language,
|
accept_language: accept_language.unwrap_or_else(|| "-".to_string()),
|
||||||
|
accept: accept.unwrap_or_else(|| "-".to_string()),
|
||||||
backend: "-".to_string(),
|
backend: "-".to_string(),
|
||||||
label: Some(label.clone()),
|
label: Some(label.clone()),
|
||||||
|
..AuditFields::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
Some((fields, label))
|
Some((fields, label))
|
||||||
@@ -302,7 +308,7 @@ Content-Type: text/html
|
|||||||
assert_eq!(attack_fields.client_ip, "192.168.1.100");
|
assert_eq!(attack_fields.client_ip, "192.168.1.100");
|
||||||
assert_eq!(attack_fields.user_agent, "curl/7.68.0");
|
assert_eq!(attack_fields.user_agent, "curl/7.68.0");
|
||||||
assert_eq!(attack_fields.status, 403);
|
assert_eq!(attack_fields.status, 403);
|
||||||
assert!(!attack_fields.has_cookies.unwrap_or(true));
|
assert!(!attack_fields.has_cookies);
|
||||||
|
|
||||||
// Second entry: normal (no rule match).
|
// Second entry: normal (no rule match).
|
||||||
let (normal_fields, normal_label) = &results[1];
|
let (normal_fields, normal_label) = &results[1];
|
||||||
@@ -311,9 +317,9 @@ Content-Type: text/html
|
|||||||
assert_eq!(normal_fields.path, "/index.html");
|
assert_eq!(normal_fields.path, "/index.html");
|
||||||
assert_eq!(normal_fields.client_ip, "10.0.0.50");
|
assert_eq!(normal_fields.client_ip, "10.0.0.50");
|
||||||
assert_eq!(normal_fields.status, 200);
|
assert_eq!(normal_fields.status, 200);
|
||||||
assert!(normal_fields.has_cookies.unwrap_or(false));
|
assert!(normal_fields.has_cookies);
|
||||||
assert!(normal_fields.referer.is_some());
|
assert!(normal_fields.referer != "-");
|
||||||
assert!(normal_fields.accept_language.is_some());
|
assert!(normal_fields.accept_language != "-");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Dataset preparation orchestrator.
|
//! Dataset preparation orchestrator.
|
||||||
//!
|
//!
|
||||||
//! Combines production logs, external datasets (CSIC, OWASP ModSec), and
|
//! Combines production logs, external datasets (CSIC, OWASP ModSec), and
|
||||||
@@ -30,6 +33,10 @@ pub struct PrepareDatasetArgs {
|
|||||||
pub seed: u64,
|
pub seed: u64,
|
||||||
/// Path to heuristics.toml for auto-labeling production logs.
|
/// Path to heuristics.toml for auto-labeling production logs.
|
||||||
pub heuristics: Option<String>,
|
pub heuristics: Option<String>,
|
||||||
|
/// Inject CSIC 2010 entries as labeled audit logs into production stream.
|
||||||
|
pub inject_csic: bool,
|
||||||
|
/// Inject OWASP ModSec entries as labeled audit logs (path to .log file).
|
||||||
|
pub inject_modsec: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for PrepareDatasetArgs {
|
impl Default for PrepareDatasetArgs {
|
||||||
@@ -41,6 +48,8 @@ impl Default for PrepareDatasetArgs {
|
|||||||
output: "dataset.bin".to_string(),
|
output: "dataset.bin".to_string(),
|
||||||
seed: 42,
|
seed: 42,
|
||||||
heuristics: None,
|
heuristics: None,
|
||||||
|
inject_csic: false,
|
||||||
|
inject_modsec: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,18 +80,21 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> {
|
|||||||
scanner_samples.extend(prod_scanner);
|
scanner_samples.extend(prod_scanner);
|
||||||
ddos_samples.extend(prod_ddos);
|
ddos_samples.extend(prod_ddos);
|
||||||
|
|
||||||
// --- 2. CSIC 2010 (scanner) ---
|
// --- 2. Inject external datasets as labeled audit log entries ---
|
||||||
eprintln!("fetching CSIC 2010 dataset...");
|
// These go through the same feature extraction as production logs,
|
||||||
let csic_entries = crate::scanner::csic::fetch_csic_dataset()?;
|
// with ground-truth labels (no heuristic labeling needed).
|
||||||
let csic_samples = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?;
|
if args.inject_csic {
|
||||||
eprintln!(" CSIC: {} scanner samples", csic_samples.len());
|
eprintln!("injecting CSIC 2010 as labeled audit entries...");
|
||||||
scanner_samples.extend(csic_samples);
|
let csic_entries = crate::scanner::csic::fetch_csic_dataset()?;
|
||||||
|
let csic_scanner = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?;
|
||||||
|
eprintln!(" CSIC injected: {} scanner samples", csic_scanner.len());
|
||||||
|
scanner_samples.extend(csic_scanner);
|
||||||
|
}
|
||||||
|
|
||||||
// --- 3. OWASP ModSec (scanner) ---
|
if let Some(modsec_path) = &args.inject_modsec {
|
||||||
if let Some(owasp_path) = &args.owasp {
|
eprintln!("injecting ModSec audit log from {modsec_path}...");
|
||||||
eprintln!("parsing OWASP ModSec audit log from {owasp_path}...");
|
|
||||||
let modsec_entries =
|
let modsec_entries =
|
||||||
crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?;
|
crate::dataset::modsec::parse_modsec_audit_log(Path::new(modsec_path))?;
|
||||||
let entries_with_host: Vec<(AuditFields, String)> = modsec_entries
|
let entries_with_host: Vec<(AuditFields, String)> = modsec_entries
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(fields, _label)| {
|
.map(|(fields, _label)| {
|
||||||
@@ -90,16 +102,49 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> {
|
|||||||
(fields, host_prefix)
|
(fields, host_prefix)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let modsec_samples =
|
let modsec_scanner =
|
||||||
entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?;
|
entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?;
|
||||||
eprintln!(" OWASP: {} scanner samples", modsec_samples.len());
|
eprintln!(" ModSec injected: {} scanner samples", modsec_scanner.len());
|
||||||
scanner_samples.extend(modsec_samples);
|
scanner_samples.extend(modsec_scanner);
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- 4. CIC-IDS2017 timing profiles (from cache if downloaded) ---
|
// --- 3. Legacy OWASP path (kept for backwards compat) ---
|
||||||
|
if let Some(owasp_path) = &args.owasp {
|
||||||
|
if args.inject_modsec.as_deref() != Some(owasp_path.as_str()) {
|
||||||
|
eprintln!("parsing OWASP ModSec audit log from {owasp_path}...");
|
||||||
|
let modsec_entries =
|
||||||
|
crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?;
|
||||||
|
let entries_with_host: Vec<(AuditFields, String)> = modsec_entries
|
||||||
|
.into_iter()
|
||||||
|
.map(|(fields, _label)| {
|
||||||
|
let host_prefix = fields.host.split('.').next().unwrap_or("").to_string();
|
||||||
|
(fields, host_prefix)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let modsec_samples =
|
||||||
|
entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?;
|
||||||
|
eprintln!(" OWASP: {} scanner samples", modsec_samples.len());
|
||||||
|
scanner_samples.extend(modsec_samples);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- 4. CIC-IDS2017 (direct DDoS samples + timing profiles for synthetic) ---
|
||||||
let cicids_profiles = if let Some(cached_path) = crate::dataset::download::cicids_cached_path()
|
let cicids_profiles = if let Some(cached_path) = crate::dataset::download::cicids_cached_path()
|
||||||
{
|
{
|
||||||
eprintln!("extracting CIC-IDS2017 timing profiles from cache...");
|
// Direct conversion: CIC-IDS2017 flows → DDoS training samples
|
||||||
|
eprintln!("extracting CIC-IDS2017 DDoS samples from cache...");
|
||||||
|
let cicids_ddos = crate::dataset::cicids::extract_ddos_samples(&cached_path)?;
|
||||||
|
let attack_count = cicids_ddos.iter().filter(|s| s.label > 0.5).count();
|
||||||
|
eprintln!(
|
||||||
|
" CIC-IDS2017 direct: {} DDoS samples ({} attack, {} normal)",
|
||||||
|
cicids_ddos.len(),
|
||||||
|
attack_count,
|
||||||
|
cicids_ddos.len() - attack_count
|
||||||
|
);
|
||||||
|
ddos_samples.extend(cicids_ddos);
|
||||||
|
|
||||||
|
// Also extract timing profiles for synthetic generation
|
||||||
|
eprintln!("extracting CIC-IDS2017 timing profiles...");
|
||||||
let profiles = crate::dataset::cicids::extract_timing_profiles(&cached_path)?;
|
let profiles = crate::dataset::cicids::extract_timing_profiles(&cached_path)?;
|
||||||
eprintln!(" extracted {} attack-type profiles", profiles.len());
|
eprintln!(" extracted {} attack-type profiles", profiles.len());
|
||||||
profiles
|
profiles
|
||||||
@@ -112,10 +157,10 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> {
|
|||||||
// --- 5. Synthetic data (both models, always generated) ---
|
// --- 5. Synthetic data (both models, always generated) ---
|
||||||
eprintln!("generating synthetic samples...");
|
eprintln!("generating synthetic samples...");
|
||||||
let config = crate::dataset::synthetic::SyntheticConfig {
|
let config = crate::dataset::synthetic::SyntheticConfig {
|
||||||
num_ddos_attack: 10000,
|
num_ddos_attack: 50000,
|
||||||
num_ddos_normal: 10000,
|
num_ddos_normal: 50000,
|
||||||
num_scanner_attack: 5000,
|
num_scanner_attack: 25000,
|
||||||
num_scanner_normal: 5000,
|
num_scanner_normal: 25000,
|
||||||
seed: args.seed,
|
seed: args.seed,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -240,17 +285,9 @@ fn parse_production_logs(
|
|||||||
|
|
||||||
// --- Scanner samples from production logs ---
|
// --- Scanner samples from production logs ---
|
||||||
for (fields, host_prefix) in &parsed_entries {
|
for (fields, host_prefix) in &parsed_entries {
|
||||||
let has_cookies = fields.has_cookies.unwrap_or(false);
|
let has_cookies = fields.has_cookies;
|
||||||
let has_referer = fields
|
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
|
||||||
.referer
|
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
|
||||||
.as_ref()
|
|
||||||
.map(|r| r != "-" && !r.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let has_accept_language = fields
|
|
||||||
.accept_language
|
|
||||||
.as_ref()
|
|
||||||
.map(|a| a != "-" && !a.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let feats = features::extract_features(
|
let feats = features::extract_features(
|
||||||
&fields.method,
|
&fields.method,
|
||||||
@@ -259,7 +296,7 @@ fn parse_production_logs(
|
|||||||
has_cookies,
|
has_cookies,
|
||||||
has_referer,
|
has_referer,
|
||||||
has_accept_language,
|
has_accept_language,
|
||||||
"-",
|
&fields.accept,
|
||||||
&fields.user_agent,
|
&fields.user_agent,
|
||||||
fields.content_length,
|
fields.content_length,
|
||||||
&fragment_hashes,
|
&fragment_hashes,
|
||||||
@@ -352,20 +389,12 @@ fn extract_ddos_samples_from_entries(
|
|||||||
.push(fields.content_length.min(u32::MAX as u64) as u32);
|
.push(fields.content_length.min(u32::MAX as u64) as u32);
|
||||||
state
|
state
|
||||||
.has_cookies
|
.has_cookies
|
||||||
.push(fields.has_cookies.unwrap_or(false));
|
.push(fields.has_cookies);
|
||||||
state.has_referer.push(
|
state.has_referer.push(
|
||||||
fields
|
!fields.referer.is_empty() && fields.referer != "-",
|
||||||
.referer
|
|
||||||
.as_deref()
|
|
||||||
.map(|r| r != "-")
|
|
||||||
.unwrap_or(false),
|
|
||||||
);
|
);
|
||||||
state.has_accept_language.push(
|
state.has_accept_language.push(
|
||||||
fields
|
!fields.accept_language.is_empty() && fields.accept_language != "-",
|
||||||
.accept_language
|
|
||||||
.as_deref()
|
|
||||||
.map(|a| a != "-")
|
|
||||||
.unwrap_or(false),
|
|
||||||
);
|
);
|
||||||
state.suspicious_paths.push(
|
state.suspicious_paths.push(
|
||||||
crate::ddos::features::is_suspicious_path(&fields.path),
|
crate::ddos::features::is_suspicious_path(&fields.path),
|
||||||
@@ -462,17 +491,9 @@ fn entries_to_scanner_samples(
|
|||||||
let mut samples = Vec::new();
|
let mut samples = Vec::new();
|
||||||
|
|
||||||
for (fields, host_prefix) in entries {
|
for (fields, host_prefix) in entries {
|
||||||
let has_cookies = fields.has_cookies.unwrap_or(false);
|
let has_cookies = fields.has_cookies;
|
||||||
let has_referer = fields
|
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
|
||||||
.referer
|
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
|
||||||
.as_ref()
|
|
||||||
.map(|r| r != "-" && !r.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let has_accept_language = fields
|
|
||||||
.accept_language
|
|
||||||
.as_ref()
|
|
||||||
.map(|a| a != "-" && !a.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let feats = features::extract_features(
|
let feats = features::extract_features(
|
||||||
&fields.method,
|
&fields.method,
|
||||||
@@ -481,7 +502,7 @@ fn entries_to_scanner_samples(
|
|||||||
has_cookies,
|
has_cookies,
|
||||||
has_referer,
|
has_referer,
|
||||||
has_accept_language,
|
has_accept_language,
|
||||||
"-",
|
&fields.accept,
|
||||||
&fields.user_agent,
|
&fields.user_agent,
|
||||||
fields.content_length,
|
fields.content_length,
|
||||||
&fragment_hashes,
|
&fragment_hashes,
|
||||||
@@ -587,11 +608,12 @@ mod tests {
|
|||||||
duration_ms: 10,
|
duration_ms: 10,
|
||||||
content_length: 0,
|
content_length: 0,
|
||||||
user_agent: "Mozilla/5.0".to_string(),
|
user_agent: "Mozilla/5.0".to_string(),
|
||||||
has_cookies: Some(true),
|
has_cookies: true,
|
||||||
referer: Some("https://test.sunbeam.pt".to_string()),
|
referer: "https://test.sunbeam.pt".to_string(),
|
||||||
accept_language: Some("en-US".to_string()),
|
accept_language: "en-US".to_string(),
|
||||||
backend: "test-svc:8080".to_string(),
|
backend: "test-svc:8080".to_string(),
|
||||||
label: Some(label.to_string()),
|
label: Some(label.to_string()),
|
||||||
|
..AuditFields::default()
|
||||||
};
|
};
|
||||||
(fields, "test".to_string())
|
(fields, "test".to_string())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,83 +1,11 @@
|
|||||||
use serde::Deserialize;
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
//! Re-exports from `crate::audit` — the canonical audit log definition.
|
||||||
pub struct AuditLog {
|
//!
|
||||||
pub timestamp: String,
|
//! All new code should `use crate::audit::*` directly.
|
||||||
pub fields: AuditFields,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
pub use crate::audit::strip_port;
|
||||||
pub struct AuditFields {
|
pub use crate::audit::AuditFields;
|
||||||
pub method: String,
|
pub use crate::audit::AuditLogLine as AuditLog;
|
||||||
pub host: String,
|
pub use crate::audit::{flexible_u16, flexible_u64};
|
||||||
pub path: String,
|
|
||||||
pub client_ip: String,
|
|
||||||
#[serde(deserialize_with = "flexible_u16")]
|
|
||||||
pub status: u16,
|
|
||||||
#[serde(deserialize_with = "flexible_u64")]
|
|
||||||
pub duration_ms: u64,
|
|
||||||
#[serde(default)]
|
|
||||||
pub backend: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub content_length: u64,
|
|
||||||
#[serde(default = "default_ua")]
|
|
||||||
pub user_agent: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub query: String,
|
|
||||||
#[serde(default)]
|
|
||||||
pub has_cookies: Option<bool>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub referer: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub accept_language: Option<String>,
|
|
||||||
/// Optional ground-truth label from external datasets (e.g. CSIC 2010).
|
|
||||||
/// Values: "attack", "normal". When present, trainers should use this
|
|
||||||
/// instead of heuristic labeling.
|
|
||||||
#[serde(default)]
|
|
||||||
pub label: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_ua() -> String {
|
|
||||||
"-".to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
|
|
||||||
deserializer: D,
|
|
||||||
) -> std::result::Result<u64, D::Error> {
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
enum StringOrNum {
|
|
||||||
Num(u64),
|
|
||||||
Str(String),
|
|
||||||
}
|
|
||||||
match StringOrNum::deserialize(deserializer)? {
|
|
||||||
StringOrNum::Num(n) => Ok(n),
|
|
||||||
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
|
|
||||||
deserializer: D,
|
|
||||||
) -> std::result::Result<u16, D::Error> {
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
enum StringOrNum {
|
|
||||||
Num(u16),
|
|
||||||
Str(String),
|
|
||||||
}
|
|
||||||
match StringOrNum::deserialize(deserializer)? {
|
|
||||||
StringOrNum::Num(n) => Ok(n),
|
|
||||||
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Strip the port suffix from a socket address string.
|
|
||||||
pub fn strip_port(addr: &str) -> &str {
|
|
||||||
if addr.starts_with('[') {
|
|
||||||
addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr)
|
|
||||||
} else if let Some(pos) = addr.rfind(':') {
|
|
||||||
&addr[..pos]
|
|
||||||
} else {
|
|
||||||
addr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use crate::config::DDoSConfig;
|
use crate::config::DDoSConfig;
|
||||||
use crate::ddos::features::{method_to_u8, IpState, RequestEvent};
|
use crate::ddos::features::{method_to_u8, IpState, RequestEvent};
|
||||||
use crate::ddos::model::{DDoSAction, TrainedModel};
|
use crate::ddos::model::DDoSAction;
|
||||||
use rustc_hash::FxHashMap;
|
use rustc_hash::FxHashMap;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
@@ -10,12 +13,10 @@ use std::time::Instant;
|
|||||||
const NUM_SHARDS: usize = 256;
|
const NUM_SHARDS: usize = 256;
|
||||||
|
|
||||||
pub struct DDoSDetector {
|
pub struct DDoSDetector {
|
||||||
model: TrainedModel,
|
|
||||||
shards: Vec<RwLock<FxHashMap<IpAddr, IpState>>>,
|
shards: Vec<RwLock<FxHashMap<IpAddr, IpState>>>,
|
||||||
window_secs: u64,
|
window_secs: u64,
|
||||||
window_capacity: usize,
|
window_capacity: usize,
|
||||||
min_events: usize,
|
min_events: usize,
|
||||||
use_ensemble: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn shard_index(ip: &IpAddr) -> usize {
|
fn shard_index(ip: &IpAddr) -> usize {
|
||||||
@@ -25,34 +26,15 @@ fn shard_index(ip: &IpAddr) -> usize {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl DDoSDetector {
|
impl DDoSDetector {
|
||||||
pub fn new(model: TrainedModel, config: &DDoSConfig) -> Self {
|
pub fn new(config: &DDoSConfig) -> Self {
|
||||||
let shards = (0..NUM_SHARDS)
|
let shards = (0..NUM_SHARDS)
|
||||||
.map(|_| RwLock::new(FxHashMap::default()))
|
.map(|_| RwLock::new(FxHashMap::default()))
|
||||||
.collect();
|
.collect();
|
||||||
Self {
|
Self {
|
||||||
model,
|
|
||||||
shards,
|
shards,
|
||||||
window_secs: config.window_secs,
|
window_secs: config.window_secs,
|
||||||
window_capacity: config.window_capacity,
|
window_capacity: config.window_capacity,
|
||||||
min_events: config.min_events,
|
min_events: config.min_events,
|
||||||
use_ensemble: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a detector that uses the ensemble (decision tree + MLP) path.
|
|
||||||
/// A dummy model is still needed for fallback, but ensemble inference
|
|
||||||
/// takes priority when `use_ensemble` is true.
|
|
||||||
pub fn new_ensemble(model: TrainedModel, config: &DDoSConfig) -> Self {
|
|
||||||
let shards = (0..NUM_SHARDS)
|
|
||||||
.map(|_| RwLock::new(FxHashMap::default()))
|
|
||||||
.collect();
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
shards,
|
|
||||||
window_secs: config.window_secs,
|
|
||||||
window_capacity: config.window_capacity,
|
|
||||||
min_events: config.min_events,
|
|
||||||
use_ensemble: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,24 +81,20 @@ impl DDoSDetector {
|
|||||||
|
|
||||||
let features = state.extract_features(self.window_secs);
|
let features = state.extract_features(self.window_secs);
|
||||||
|
|
||||||
if self.use_ensemble {
|
// Cast f64 features to f32 array for ensemble inference.
|
||||||
// Cast f64 features to f32 array for ensemble inference.
|
let mut f32_features = [0.0f32; 14];
|
||||||
let mut f32_features = [0.0f32; 14];
|
for (i, &v) in features.iter().enumerate().take(14) {
|
||||||
for (i, &v) in features.iter().enumerate().take(14) {
|
f32_features[i] = v as f32;
|
||||||
f32_features[i] = v as f32;
|
|
||||||
}
|
|
||||||
let ev = crate::ensemble::ddos::ddos_ensemble_predict(&f32_features);
|
|
||||||
crate::metrics::DDOS_ENSEMBLE_PATH
|
|
||||||
.with_label_values(&[match ev.path {
|
|
||||||
crate::ensemble::ddos::DDoSEnsemblePath::TreeBlock => "tree_block",
|
|
||||||
crate::ensemble::ddos::DDoSEnsemblePath::TreeAllow => "tree_allow",
|
|
||||||
crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp",
|
|
||||||
}])
|
|
||||||
.inc();
|
|
||||||
return ev.action;
|
|
||||||
}
|
}
|
||||||
|
let ev = crate::ensemble::ddos::ddos_ensemble_predict(&f32_features);
|
||||||
self.model.classify(&features)
|
crate::metrics::DDOS_ENSEMBLE_PATH
|
||||||
|
.with_label_values(&[match ev.path {
|
||||||
|
crate::ensemble::ddos::DDoSEnsemblePath::TreeBlock => "tree_block",
|
||||||
|
crate::ensemble::ddos::DDoSEnsemblePath::TreeAllow => "tree_allow",
|
||||||
|
crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp",
|
||||||
|
}])
|
||||||
|
.inc();
|
||||||
|
ev.action
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Feed response data back into the IP's event history.
|
/// Feed response data back into the IP's event history.
|
||||||
@@ -125,10 +103,6 @@ impl DDoSDetector {
|
|||||||
// Status/duration from check() are 0-initialized; the next request
|
// Status/duration from check() are 0-initialized; the next request
|
||||||
// will have fresh data. This is intentionally a no-op for now.
|
// will have fresh data. This is intentionally a no-op for now.
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn point_count(&self) -> usize {
|
|
||||||
self.model.point_count()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fx_hash(s: &str) -> u64 {
|
fn fx_hash(s: &str) -> u64 {
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
pub mod audit_log;
|
pub mod audit_log;
|
||||||
pub mod detector;
|
pub mod detector;
|
||||||
pub mod features;
|
pub mod features;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod replay;
|
|
||||||
pub mod train;
|
pub mod train;
|
||||||
|
|||||||
@@ -1,183 +1,8 @@
|
|||||||
use crate::ddos::features::{FeatureVector, NormParams, NUM_FEATURES};
|
// Copyright Sunbeam Studios 2026
|
||||||
use anyhow::{Context, Result};
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
pub enum TrafficLabel {
|
|
||||||
Normal,
|
|
||||||
Attack,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
pub struct SerializedModel {
|
|
||||||
pub points: Vec<FeatureVector>,
|
|
||||||
pub labels: Vec<TrafficLabel>,
|
|
||||||
pub norm_params: NormParams,
|
|
||||||
pub k: usize,
|
|
||||||
pub threshold: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct TrainedModel {
|
|
||||||
/// Stored points (normalized). The kD-tree borrows these.
|
|
||||||
points: Vec<[f64; NUM_FEATURES]>,
|
|
||||||
labels: Vec<TrafficLabel>,
|
|
||||||
norm_params: NormParams,
|
|
||||||
k: usize,
|
|
||||||
threshold: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub enum DDoSAction {
|
pub enum DDoSAction {
|
||||||
Allow,
|
Allow,
|
||||||
Block,
|
Block,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TrainedModel {
|
|
||||||
pub fn load(path: &Path, k_override: Option<usize>, threshold_override: Option<f64>) -> Result<Self> {
|
|
||||||
let data = std::fs::read(path)
|
|
||||||
.with_context(|| format!("reading model from {}", path.display()))?;
|
|
||||||
let model: SerializedModel =
|
|
||||||
bincode::deserialize(&data).context("deserializing model")?;
|
|
||||||
Ok(Self {
|
|
||||||
points: model.points,
|
|
||||||
labels: model.labels,
|
|
||||||
norm_params: model.norm_params,
|
|
||||||
k: k_override.unwrap_or(model.k),
|
|
||||||
threshold: threshold_override.unwrap_or(model.threshold),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create an empty model (no training points). Used when the ensemble
|
|
||||||
/// path is active and the KNN model is not needed.
|
|
||||||
pub fn empty(k: usize, threshold: f64) -> Self {
|
|
||||||
Self {
|
|
||||||
points: vec![],
|
|
||||||
labels: vec![],
|
|
||||||
norm_params: NormParams {
|
|
||||||
mins: [0.0; NUM_FEATURES],
|
|
||||||
maxs: [1.0; NUM_FEATURES],
|
|
||||||
},
|
|
||||||
k,
|
|
||||||
threshold,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_serialized(model: SerializedModel) -> Self {
|
|
||||||
Self {
|
|
||||||
points: model.points,
|
|
||||||
labels: model.labels,
|
|
||||||
norm_params: model.norm_params,
|
|
||||||
k: model.k,
|
|
||||||
threshold: model.threshold,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn classify(&self, features: &FeatureVector) -> DDoSAction {
|
|
||||||
let normalized = self.norm_params.normalize(features);
|
|
||||||
|
|
||||||
if self.points.is_empty() {
|
|
||||||
return DDoSAction::Allow;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build tree on-the-fly for query. In production with many queries,
|
|
||||||
// we'd cache this, but the tree build is fast for <100K points.
|
|
||||||
// fnntw::Tree borrows data, so we build it here.
|
|
||||||
let tree = match fnntw::Tree::<'_, f64, NUM_FEATURES>::new(&self.points, 32) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(_) => return DDoSAction::Allow,
|
|
||||||
};
|
|
||||||
|
|
||||||
let k = self.k.min(self.points.len());
|
|
||||||
let result = tree.query_nearest_k(&normalized, k);
|
|
||||||
match result {
|
|
||||||
Ok((_distances, indices)) => {
|
|
||||||
let attack_count = indices
|
|
||||||
.iter()
|
|
||||||
.filter(|&&idx| self.labels[idx as usize] == TrafficLabel::Attack)
|
|
||||||
.count();
|
|
||||||
let attack_frac = attack_count as f64 / k as f64;
|
|
||||||
if attack_frac >= self.threshold {
|
|
||||||
DDoSAction::Block
|
|
||||||
} else {
|
|
||||||
DDoSAction::Allow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => DDoSAction::Allow,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn norm_params(&self) -> &NormParams {
|
|
||||||
&self.norm_params
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn point_count(&self) -> usize {
|
|
||||||
self.points.len()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_classify_empty_model() {
|
|
||||||
let model = TrainedModel {
|
|
||||||
points: vec![],
|
|
||||||
labels: vec![],
|
|
||||||
norm_params: NormParams {
|
|
||||||
mins: [0.0; NUM_FEATURES],
|
|
||||||
maxs: [1.0; NUM_FEATURES],
|
|
||||||
},
|
|
||||||
k: 5,
|
|
||||||
threshold: 0.6,
|
|
||||||
};
|
|
||||||
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn make_test_points(n: usize) -> Vec<FeatureVector> {
|
|
||||||
(0..n)
|
|
||||||
.map(|i| {
|
|
||||||
let mut v = [0.0; NUM_FEATURES];
|
|
||||||
for d in 0..NUM_FEATURES {
|
|
||||||
v[d] = ((i * (d + 1)) as f64 / n as f64) % 1.0;
|
|
||||||
}
|
|
||||||
v
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_classify_all_attack() {
|
|
||||||
let points = make_test_points(100);
|
|
||||||
let labels = vec![TrafficLabel::Attack; 100];
|
|
||||||
let model = TrainedModel {
|
|
||||||
points,
|
|
||||||
labels,
|
|
||||||
norm_params: NormParams {
|
|
||||||
mins: [0.0; NUM_FEATURES],
|
|
||||||
maxs: [1.0; NUM_FEATURES],
|
|
||||||
},
|
|
||||||
k: 5,
|
|
||||||
threshold: 0.6,
|
|
||||||
};
|
|
||||||
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Block);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_classify_all_normal() {
|
|
||||||
let points = make_test_points(100);
|
|
||||||
let labels = vec![TrafficLabel::Normal; 100];
|
|
||||||
let model = TrainedModel {
|
|
||||||
points,
|
|
||||||
labels,
|
|
||||||
norm_params: NormParams {
|
|
||||||
mins: [0.0; NUM_FEATURES],
|
|
||||||
maxs: [1.0; NUM_FEATURES],
|
|
||||||
},
|
|
||||||
k: 5,
|
|
||||||
threshold: 0.6,
|
|
||||||
};
|
|
||||||
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,281 +0,0 @@
|
|||||||
use crate::config::{DDoSConfig, RateLimitConfig};
|
|
||||||
use crate::ddos::audit_log::{self, AuditLog};
|
|
||||||
use crate::ddos::detector::DDoSDetector;
|
|
||||||
use crate::ddos::model::{DDoSAction, TrainedModel};
|
|
||||||
use crate::rate_limit::key::RateLimitKey;
|
|
||||||
use crate::rate_limit::limiter::{RateLimitResult, RateLimiter};
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use rustc_hash::FxHashMap;
|
|
||||||
use std::io::BufRead;
|
|
||||||
use std::net::IpAddr;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
pub struct ReplayArgs {
|
|
||||||
pub input: String,
|
|
||||||
pub model_path: String,
|
|
||||||
pub config_path: Option<String>,
|
|
||||||
pub k: usize,
|
|
||||||
pub threshold: f64,
|
|
||||||
pub window_secs: u64,
|
|
||||||
pub min_events: usize,
|
|
||||||
pub rate_limit: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ReplayResult {
|
|
||||||
pub total: u64,
|
|
||||||
pub skipped: u64,
|
|
||||||
pub ddos_blocked: u64,
|
|
||||||
pub rate_limited: u64,
|
|
||||||
pub allowed: u64,
|
|
||||||
pub ddos_blocked_ips: FxHashMap<String, u64>,
|
|
||||||
pub rate_limited_ips: FxHashMap<String, u64>,
|
|
||||||
pub false_positive_ips: usize,
|
|
||||||
pub true_positive_ips: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Core replay pipeline: load model, replay logs, compute stats including false positive analysis.
|
|
||||||
pub fn replay_and_evaluate(args: &ReplayArgs) -> Result<ReplayResult> {
|
|
||||||
let model = TrainedModel::load(
|
|
||||||
std::path::Path::new(&args.model_path),
|
|
||||||
Some(args.k),
|
|
||||||
Some(args.threshold),
|
|
||||||
)
|
|
||||||
.with_context(|| format!("loading model from {}", args.model_path))?;
|
|
||||||
|
|
||||||
let ddos_cfg = DDoSConfig {
|
|
||||||
model_path: Some(args.model_path.clone()),
|
|
||||||
k: args.k,
|
|
||||||
threshold: args.threshold,
|
|
||||||
window_secs: args.window_secs,
|
|
||||||
window_capacity: 1000,
|
|
||||||
min_events: args.min_events,
|
|
||||||
enabled: true,
|
|
||||||
use_ensemble: false,
|
|
||||||
};
|
|
||||||
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
|
|
||||||
|
|
||||||
let rate_limiter = if args.rate_limit {
|
|
||||||
let rl_cfg = if let Some(cfg_path) = &args.config_path {
|
|
||||||
let cfg = crate::config::Config::load(cfg_path)?;
|
|
||||||
cfg.rate_limit.unwrap_or_else(default_rate_limit_config)
|
|
||||||
} else {
|
|
||||||
default_rate_limit_config()
|
|
||||||
};
|
|
||||||
Some(RateLimiter::new(&rl_cfg))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let file = std::fs::File::open(&args.input)
|
|
||||||
.with_context(|| format!("opening {}", args.input))?;
|
|
||||||
let reader = std::io::BufReader::new(file);
|
|
||||||
|
|
||||||
let mut total = 0u64;
|
|
||||||
let mut skipped = 0u64;
|
|
||||||
let mut ddos_blocked = 0u64;
|
|
||||||
let mut rate_limited = 0u64;
|
|
||||||
let mut allowed = 0u64;
|
|
||||||
let mut ddos_blocked_ips: FxHashMap<String, u64> = FxHashMap::default();
|
|
||||||
let mut rate_limited_ips: FxHashMap<String, u64> = FxHashMap::default();
|
|
||||||
|
|
||||||
for line in reader.lines() {
|
|
||||||
let line = line?;
|
|
||||||
let entry: AuditLog = match serde_json::from_str(&line) {
|
|
||||||
Ok(e) => e,
|
|
||||||
Err(_) => {
|
|
||||||
skipped += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if entry.fields.method.is_empty() {
|
|
||||||
skipped += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
total += 1;
|
|
||||||
|
|
||||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
|
||||||
let ip: IpAddr = match ip_str.parse() {
|
|
||||||
Ok(ip) => ip,
|
|
||||||
Err(_) => {
|
|
||||||
skipped += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
|
|
||||||
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
|
|
||||||
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
|
|
||||||
let ddos_action = detector.check(
|
|
||||||
ip,
|
|
||||||
&entry.fields.method,
|
|
||||||
&entry.fields.path,
|
|
||||||
&entry.fields.host,
|
|
||||||
&entry.fields.user_agent,
|
|
||||||
entry.fields.content_length,
|
|
||||||
has_cookies,
|
|
||||||
has_referer,
|
|
||||||
has_accept_language,
|
|
||||||
);
|
|
||||||
|
|
||||||
if ddos_action == DDoSAction::Block {
|
|
||||||
ddos_blocked += 1;
|
|
||||||
*ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(limiter) = &rate_limiter {
|
|
||||||
let rl_key = RateLimitKey::Ip(ip);
|
|
||||||
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
|
|
||||||
rate_limited += 1;
|
|
||||||
*rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
allowed += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute false positive / true positive counts
|
|
||||||
let (false_positive_ips, true_positive_ips) =
|
|
||||||
count_fp_tp(&args.input, &ddos_blocked_ips, &rate_limited_ips)?;
|
|
||||||
|
|
||||||
Ok(ReplayResult {
|
|
||||||
total,
|
|
||||||
skipped,
|
|
||||||
ddos_blocked,
|
|
||||||
rate_limited,
|
|
||||||
allowed,
|
|
||||||
ddos_blocked_ips,
|
|
||||||
rate_limited_ips,
|
|
||||||
false_positive_ips,
|
|
||||||
true_positive_ips,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Count false-positive and true-positive IPs from blocked set.
|
|
||||||
/// FP = blocked IP where >60% of original responses were 2xx/3xx.
|
|
||||||
fn count_fp_tp(
|
|
||||||
input: &str,
|
|
||||||
ddos_blocked_ips: &FxHashMap<String, u64>,
|
|
||||||
rate_limited_ips: &FxHashMap<String, u64>,
|
|
||||||
) -> Result<(usize, usize)> {
|
|
||||||
let blocked_ips: rustc_hash::FxHashSet<&str> = ddos_blocked_ips
|
|
||||||
.keys()
|
|
||||||
.chain(rate_limited_ips.keys())
|
|
||||||
.map(|s| s.as_str())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if blocked_ips.is_empty() {
|
|
||||||
return Ok((0, 0));
|
|
||||||
}
|
|
||||||
|
|
||||||
let file = std::fs::File::open(input)?;
|
|
||||||
let reader = std::io::BufReader::new(file);
|
|
||||||
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
|
|
||||||
|
|
||||||
for line in reader.lines() {
|
|
||||||
let line = line?;
|
|
||||||
let entry: AuditLog = match serde_json::from_str(&line) {
|
|
||||||
Ok(e) => e,
|
|
||||||
Err(_) => continue,
|
|
||||||
};
|
|
||||||
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
|
|
||||||
if blocked_ips.contains(ip_str.as_str()) {
|
|
||||||
ip_statuses.entry(ip_str).or_default().push(entry.fields.status);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut fp = 0usize;
|
|
||||||
let mut tp = 0usize;
|
|
||||||
for (_ip, statuses) in &ip_statuses {
|
|
||||||
let total = statuses.len();
|
|
||||||
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
|
|
||||||
let ok_pct = ok_count as f64 / total as f64 * 100.0;
|
|
||||||
if ok_pct > 60.0 {
|
|
||||||
fp += 1;
|
|
||||||
} else {
|
|
||||||
tp += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok((fp, tp))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn run(args: ReplayArgs) -> Result<()> {
|
|
||||||
eprintln!("Loading model from {}...", args.model_path);
|
|
||||||
eprintln!("Replaying {}...\n", args.input);
|
|
||||||
|
|
||||||
let result = replay_and_evaluate(&args)?;
|
|
||||||
|
|
||||||
let total = result.total;
|
|
||||||
eprintln!("═══ Replay Results ═══════════════════════════════════════");
|
|
||||||
eprintln!(" Total requests: {total}");
|
|
||||||
eprintln!(" Skipped (parse): {}", result.skipped);
|
|
||||||
eprintln!(" Allowed: {} ({:.1}%)", result.allowed, pct(result.allowed, total));
|
|
||||||
eprintln!(" DDoS blocked: {} ({:.1}%)", result.ddos_blocked, pct(result.ddos_blocked, total));
|
|
||||||
if args.rate_limit {
|
|
||||||
eprintln!(" Rate limited: {} ({:.1}%)", result.rate_limited, pct(result.rate_limited, total));
|
|
||||||
}
|
|
||||||
|
|
||||||
if !result.ddos_blocked_ips.is_empty() {
|
|
||||||
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
|
|
||||||
let mut sorted: Vec<_> = result.ddos_blocked_ips.iter().collect();
|
|
||||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
|
||||||
for (ip, count) in sorted.iter().take(20) {
|
|
||||||
eprintln!(" {:<40} {} reqs blocked", ip, count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !result.rate_limited_ips.is_empty() {
|
|
||||||
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
|
|
||||||
let mut sorted: Vec<_> = result.rate_limited_ips.iter().collect();
|
|
||||||
sorted.sort_by(|a, b| b.1.cmp(a.1));
|
|
||||||
for (ip, count) in sorted.iter().take(20) {
|
|
||||||
eprintln!(" {:<40} {} reqs limited", ip, count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
eprintln!("\n── False positive check ──────────────────────────────────");
|
|
||||||
if result.false_positive_ips == 0 {
|
|
||||||
eprintln!(" No likely false positives found.");
|
|
||||||
} else {
|
|
||||||
eprintln!(" ⚠ {} IPs were blocked but had mostly successful responses", result.false_positive_ips);
|
|
||||||
}
|
|
||||||
eprintln!(" True positive IPs: {}", result.true_positive_ips);
|
|
||||||
|
|
||||||
eprintln!("══════════════════════════════════════════════════════════");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_rate_limit_config() -> RateLimitConfig {
|
|
||||||
RateLimitConfig {
|
|
||||||
enabled: true,
|
|
||||||
bypass_cidrs: vec![
|
|
||||||
"10.0.0.0/8".into(),
|
|
||||||
"172.16.0.0/12".into(),
|
|
||||||
"192.168.0.0/16".into(),
|
|
||||||
"100.64.0.0/10".into(),
|
|
||||||
"fd00::/8".into(),
|
|
||||||
],
|
|
||||||
eviction_interval_secs: 300,
|
|
||||||
stale_after_secs: 600,
|
|
||||||
authenticated: crate::config::BucketConfig {
|
|
||||||
burst: 200,
|
|
||||||
rate: 50.0,
|
|
||||||
},
|
|
||||||
unauthenticated: crate::config::BucketConfig {
|
|
||||||
burst: 60,
|
|
||||||
rate: 15.0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn pct(n: u64, total: u64) -> f64 {
|
|
||||||
if total == 0 {
|
|
||||||
0.0
|
|
||||||
} else {
|
|
||||||
(n as f64 / total as f64) * 100.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,13 +1,32 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use crate::ddos::audit_log::AuditLog;
|
use crate::ddos::audit_log::AuditLog;
|
||||||
use crate::ddos::audit_log;
|
use crate::ddos::audit_log;
|
||||||
use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES};
|
use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES};
|
||||||
use crate::ddos::model::{SerializedModel, TrafficLabel};
|
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use rustc_hash::{FxHashMap, FxHashSet};
|
use rustc_hash::{FxHashMap, FxHashSet};
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::io::BufRead;
|
use std::io::BufRead;
|
||||||
|
|
||||||
|
/// Legacy KNN training types — kept for the `train-ddos` CLI command
|
||||||
|
/// which produces bincode model files for offline evaluation.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum TrafficLabel {
|
||||||
|
Normal,
|
||||||
|
Attack,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct SerializedModel {
|
||||||
|
pub points: Vec<FeatureVector>,
|
||||||
|
pub labels: Vec<TrafficLabel>,
|
||||||
|
pub norm_params: NormParams,
|
||||||
|
pub k: usize,
|
||||||
|
pub threshold: f64,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct HeuristicThresholds {
|
pub struct HeuristicThresholds {
|
||||||
/// Requests/second above which an IP is labeled attack
|
/// Requests/second above which an IP is labeled attack
|
||||||
@@ -255,12 +274,12 @@ pub fn parse_logs(input: &str) -> Result<FxHashMap<String, LogIpState>> {
|
|||||||
state.statuses.push(entry.fields.status);
|
state.statuses.push(entry.fields.status);
|
||||||
state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
|
state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
|
||||||
state.content_lengths.push(entry.fields.content_length.min(u32::MAX as u64) as u32);
|
state.content_lengths.push(entry.fields.content_length.min(u32::MAX as u64) as u32);
|
||||||
state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
|
state.has_cookies.push(entry.fields.has_cookies);
|
||||||
state.has_referer.push(
|
state.has_referer.push(
|
||||||
entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
|
!entry.fields.referer.is_empty() && entry.fields.referer != "-",
|
||||||
);
|
);
|
||||||
state.has_accept_language.push(
|
state.has_accept_language.push(
|
||||||
entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
|
!entry.fields.accept_language.is_empty() && entry.fields.accept_language != "-",
|
||||||
);
|
);
|
||||||
state.suspicious_paths.push(
|
state.suspicious_paths.push(
|
||||||
crate::ddos::features::is_suspicious_path(&entry.fields.path),
|
crate::ddos::features::is_suspicious_path(&entry.fields.path),
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use crate::ddos::model::DDoSAction;
|
use crate::ddos::model::DDoSAction;
|
||||||
use super::gen::ddos_weights;
|
use super::gen::ddos_weights;
|
||||||
use super::mlp::mlp_predict_32;
|
use super::mlp::mlp_predict_32;
|
||||||
@@ -80,59 +83,46 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tree_allow_path() {
|
fn test_tree_block_path() {
|
||||||
// All zeros → feature 4 (request_rate) = 0.0 <= 0.70 → left (node 1)
|
// Tree: root splits on feature 10 (cookie_ratio) at 0.14.
|
||||||
// feature 10 (cookie_ratio) = 0.0 <= 0.30 → left (node 3) → Allow
|
// All zeros → cookie_ratio normalized = 0.0 <= 0.14 → Block (node 1)
|
||||||
let raw = [0.0f32; 14];
|
let raw = [0.0f32; 14];
|
||||||
let v = ddos_ensemble_predict(&raw);
|
let v = ddos_ensemble_predict(&raw);
|
||||||
|
assert_eq!(v.action, DDoSAction::Block);
|
||||||
|
assert_eq!(v.path, DDoSEnsemblePath::TreeBlock);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tree_allow_path() {
|
||||||
|
// Tree: feature 10 (cookie_ratio) > 0.14 → node 2 (Allow leaf)
|
||||||
|
// feature 10 range [0, 1], raw 0.5 → normalized 0.5 > 0.14 → Allow
|
||||||
|
let mut raw = [0.0f32; 14];
|
||||||
|
raw[10] = 0.5;
|
||||||
|
let v = ddos_ensemble_predict(&raw);
|
||||||
assert_eq!(v.action, DDoSAction::Allow);
|
assert_eq!(v.action, DDoSAction::Allow);
|
||||||
assert_eq!(v.path, DDoSEnsemblePath::TreeAllow);
|
assert_eq!(v.path, DDoSEnsemblePath::TreeAllow);
|
||||||
assert_eq!(v.reason, "ensemble:tree_allow");
|
assert_eq!(v.reason, "ensemble:tree_allow");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tree_block_path() {
|
fn test_mlp_direct() {
|
||||||
// Need: feature 4 (request_rate) > 0.70 normalized → right (node 2)
|
// Current tree has no Defer leaves, so test MLP inference directly.
|
||||||
// feature 12 (accept_language_ratio) > 0.25 normalized → right (node 6) → Block
|
let input = [0.5f32; 14];
|
||||||
// feature 4 max = 500, so raw 400 → normalized 0.8 > 0.70 ✓
|
let score = mlp_predict_32::<14>(
|
||||||
// feature 12 max = 1.0, so raw 0.5 → normalized 0.5 > 0.25 ✓
|
&ddos_weights::W1,
|
||||||
let mut raw = [0.0f32; 14];
|
&ddos_weights::B1,
|
||||||
raw[4] = 400.0;
|
&ddos_weights::W2,
|
||||||
raw[12] = 0.5;
|
ddos_weights::B2,
|
||||||
let v = ddos_ensemble_predict(&raw);
|
&input,
|
||||||
assert_eq!(v.action, DDoSAction::Block);
|
);
|
||||||
assert_eq!(v.path, DDoSEnsemblePath::TreeBlock);
|
assert!(score >= 0.0 && score <= 1.0);
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_mlp_path() {
|
|
||||||
// Need: feature 4 > 0.70 normalized → right (node 2)
|
|
||||||
// feature 12 <= 0.25 normalized → left (node 5) → Defer
|
|
||||||
// feature 4 max = 500, raw 400 → 0.8 > 0.70 ✓
|
|
||||||
// feature 12 max = 1.0, raw 0.1 → 0.1 <= 0.25 ✓
|
|
||||||
let mut raw = [0.0f32; 14];
|
|
||||||
raw[4] = 400.0;
|
|
||||||
raw[12] = 0.1;
|
|
||||||
let v = ddos_ensemble_predict(&raw);
|
|
||||||
assert_eq!(v.path, DDoSEnsemblePath::Mlp);
|
|
||||||
assert_eq!(v.reason, "ensemble:mlp");
|
|
||||||
assert!(v.score >= 0.0 && v.score <= 1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_defer_then_mlp_allow() {
|
|
||||||
// Same Defer path as above — verify the MLP produces a valid action
|
|
||||||
let mut raw = [0.0f32; 14];
|
|
||||||
raw[4] = 400.0;
|
|
||||||
raw[12] = 0.1;
|
|
||||||
let v = ddos_ensemble_predict(&raw);
|
|
||||||
assert!(matches!(v.action, DDoSAction::Allow | DDoSAction::Block));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_normalize_clamps_high() {
|
fn test_normalize_clamps_high() {
|
||||||
|
// feature 0 max = 10000.0, raw 999999 → clamped to 1.0
|
||||||
let mut raw = [0.0f32; 14];
|
let mut raw = [0.0f32; 14];
|
||||||
raw[0] = 999.0; // max is 100
|
raw[0] = 999999.0;
|
||||||
let normed = normalize(&raw);
|
let normed = normalize(&raw);
|
||||||
assert!((normed[0] - 1.0).abs() < f32::EPSILON);
|
assert!((normed[0] - 1.0).abs() < f32::EPSILON);
|
||||||
}
|
}
|
||||||
@@ -140,7 +130,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_normalize_clamps_low() {
|
fn test_normalize_clamps_low() {
|
||||||
let mut raw = [0.0f32; 14];
|
let mut raw = [0.0f32; 14];
|
||||||
raw[1] = -500.0; // min is 0
|
raw[1] = -500.0; // min is 1.0
|
||||||
let normed = normalize(&raw);
|
let normed = normalize(&raw);
|
||||||
assert!((normed[1] - 0.0).abs() < f32::EPSILON);
|
assert!((normed[1] - 0.0).abs() < f32::EPSILON);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,71 +1,74 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Auto-generated weights for the ddos ensemble.
|
//! Auto-generated weights for the ddos ensemble.
|
||||||
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-ddos-mlp`.
|
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-ddos-mlp`.
|
||||||
|
|
||||||
pub const THRESHOLD: f32 = 0.50000000;
|
pub const THRESHOLD: f32 = 0.50000000;
|
||||||
|
|
||||||
pub const NORM_MINS: [f32; 14] = [
|
pub const NORM_MINS: [f32; 14] = [
|
||||||
0.08778746, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.05001374, 0.02000000,
|
0.00000000, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00833333, 0.02000000,
|
||||||
0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000
|
0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const NORM_MAXS: [f32; 14] = [
|
pub const NORM_MAXS: [f32; 14] = [
|
||||||
1000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49990082, 500.00000000, 1.00000000,
|
10000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49999166, 500.00000000, 1.00000000,
|
||||||
171240.28125000, 30.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000
|
171240.28125000, 30.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const W1: [[f32; 14]; 32] = [
|
pub const W1: [[f32; 14]; 32] = [
|
||||||
[0.57458097, -0.10861993, 0.14037465, -0.23486336, -0.43255216, 0.16347405, 0.71766937, 0.83138502, 0.02852129, 0.56590265, 0.54848498, 0.38098580, 0.82907754, 0.61539698],
|
[0.30719012, 0.54619288, -0.20061351, 0.29659060, 0.34138879, -0.19088192, 0.34866381, -0.24303232, 0.20615512, 0.12656690, -0.16653502, 0.10961400, -0.16814700, 0.09950374],
|
||||||
[0.06583579, 0.02305713, 0.89706898, 0.42619053, -1.20866120, -0.11974730, 1.70674825, -0.17969023, -0.26867196, 0.60768014, -0.08671998, -0.04825107, 0.58131427, -0.02062579],
|
[0.46002641, 0.49129087, -0.01386960, 0.64490628, 0.51850092, 0.69266915, 0.31095454, 0.26951542, -0.20926359, -0.09662568, 0.19281134, 0.36575633, 0.23089127, -0.03582983],
|
||||||
[-0.40264100, 0.18836430, 0.08431315, 0.33763552, 0.44880620, -0.40894085, -0.22044741, 0.00533387, -0.61574107, 0.07670992, 0.63528854, 0.48244709, 0.20411402, -1.80697525],
|
[0.04547556, -0.04854088, -0.10979871, -0.18705507, -0.44649851, 0.20949614, 0.73240960, 1.34691823, -0.13529004, 0.69439852, -0.40027520, 0.47921708, -0.43529814, 0.48781869],
|
||||||
[0.66713083, -0.22220801, 2.11234117, 0.41516641, -0.00165093, 0.65624571, 1.87509167, 0.63406783, -2.54182458, -0.53618753, 2.16407824, -0.61959583, -0.04717547, 0.17551991],
|
[0.61547452, 1.52229679, 0.48276836, 1.27171433, -0.36176509, -0.33192506, -0.82673991, 0.67331636, -0.21094124, 1.03067887, -0.09182073, 1.44520211, 0.52611661, 0.61176163],
|
||||||
[-0.51027024, -0.60132700, 0.46407551, -0.57346475, -0.30902353, -0.24235034, 0.08087540, -2.14762974, -0.29429656, 0.56257033, -0.26935315, -0.16799171, 0.56852734, 1.93494022],
|
[-0.37162983, 0.48245564, -0.53393066, -0.20009390, -0.06583384, 0.17612432, 0.59905756, 1.39533114, 0.67457062, 0.06159161, -0.56609136, -0.29591814, 0.55239469, -0.56801152],
|
||||||
[-0.24938971, -0.12699288, -0.13746630, 0.64942318, 0.09490766, -0.02158179, 0.72449303, -0.28493983, -0.43053114, -0.01443988, 0.89670080, -0.34539866, -1.47019410, 0.79477930],
|
[0.75815558, -0.64557153, 0.84678394, 0.41179815, -0.50619060, -0.09139232, -0.64594650, -0.74464273, 0.87102652, 0.81111395, -0.35027400, 0.95135874, 0.85043454, 1.72117484],
|
||||||
[0.62935185, 0.74686801, -0.15527052, -0.06635039, 0.73137009, 0.78417069, -0.06417987, 0.72259408, 0.85131824, 0.00477386, -0.14302900, 0.63481224, 0.92724019, -0.50126070],
|
[0.98888069, -0.04631047, -0.62931997, -0.37154037, 1.02817857, -1.04121590, 0.74848920, -0.26426360, 0.23142239, -0.17234743, -0.61689568, -0.59363395, -0.85756373, 0.53024006],
|
||||||
[-0.12699059, -0.15016419, -0.48704135, 0.00581611, 0.75824696, 0.84114397, -0.08958503, 0.18609463, 0.56247348, 0.22239330, 0.43324804, 0.82077771, 0.55714250, -0.56955606],
|
[0.59509462, 0.26622522, -0.74383926, 0.48256168, -0.75522244, 0.46806136, -0.62194610, 0.09251838, 0.25921744, 0.72987258, 0.66349596, 0.53999704, -0.25535119, -0.92465514],
|
||||||
[0.83457869, 0.40054807, 0.23281574, -0.58521581, 1.18067443, -0.49485078, 0.08600014, 0.99104887, -0.65019566, -0.44594154, 0.64507920, -0.61692268, -0.29301512, -0.11314666],
|
[0.50981742, -1.44853806, -0.64814043, -0.78203505, -0.88038790, 0.68278509, 0.58861315, 0.55924416, -0.52396554, 0.45195666, -0.44876143, -0.11349974, 0.64508075, 0.06376592],
|
||||||
[-0.07868081, -0.18392175, 0.15165123, 0.35139060, 0.13855398, 0.16470867, 0.21025884, 1.57204449, 0.07827333, 0.05895505, 0.00810917, 1.05159700, 0.04605416, 0.38080546],
|
[0.44494510, 0.79238343, -0.22128101, 3.13757062, 0.54972911, 2.06494117, -0.20301908, 0.48413971, -0.25992882, 0.77544886, -0.18115431, 1.87130582, 0.71965748, 1.95458603],
|
||||||
[1.47428405, -0.21614535, -0.35385504, 0.46582970, 1.26638246, 0.00133375, -3.85603786, 0.39766011, 1.92816520, 0.47828305, -0.16951409, -0.13771342, 0.49451983, 0.41184473],
|
[1.00518668, 1.80238068, -0.28449696, -0.02740687, -2.51049113, 0.56081659, -0.43591678, 3.59169340, -0.47954431, 1.82556272, 0.64387941, 0.56122434, -0.19696619, 3.49070907],
|
||||||
[-0.23364748, 0.68134952, 0.36342716, 0.02657196, 0.07550839, 0.94861823, -0.52908695, 0.83652318, -0.05639480, 0.26536962, 0.44137934, 1.20957208, -0.60747981, -0.50647283],
|
[-0.32992145, -0.03573111, 2.41438532, -0.00748284, 0.62775159, 1.78909039, 0.25103322, 0.59640545, -0.10183074, 0.83787775, 0.14171274, 0.08816884, 0.16381627, -0.04427620],
|
||||||
[-0.16961956, -0.49570882, -0.33771378, -0.28554109, 0.95865113, -0.49269623, -0.44559151, 1.28568971, 0.79537493, -0.53175420, -3.19015551, 0.52214253, 0.86517984, 0.62523192],
|
[0.09841868, 0.58517164, 0.02630968, 0.65797943, -0.03991833, 0.52833039, 0.37459302, 0.01832970, -1.20483434, 0.76000416, 0.02081347, 1.10453236, 0.46800232, 0.50707549],
|
||||||
[-0.16956513, -0.61727583, 0.63967121, 0.96406335, -0.28760204, 0.56459671, 0.78585202, -0.03668134, -0.14773002, -0.35764447, 0.84649116, -0.34540027, -0.12314465, -0.10070048],
|
[0.13568293, -0.04429439, 0.18404786, 0.74804515, -0.02402807, 0.25729915, 0.64555109, 0.09644510, 0.31338552, 0.62685025, -0.19832127, 1.95116663, 0.66340035, 1.29182386],
|
||||||
[-0.34183556, -0.07760386, 0.70894319, 0.92814171, -0.19357866, 0.41449037, 0.54653358, 0.27682835, 0.81471086, 0.56383932, 0.57456553, -0.61491662, 0.92498505, 0.74495614],
|
[-0.20969683, 0.56657153, -0.08705560, 0.71007556, -0.11011623, 1.16174579, 0.65050489, 1.31441426, 0.72755563, 1.15947676, -0.34925875, 0.01019314, -1.42810500, 0.14942981],
|
||||||
[-0.38917324, -0.29217750, 1.43508542, -0.19152534, -0.18823336, 0.45097819, -0.38063127, -0.40419811, 0.56686693, -0.33231607, -0.19567636, -0.02500075, -0.04762971, 0.44703853],
|
[-0.47017330, 2.62149596, -0.37532449, 1.17488575, 0.62930888, 0.62195790, -0.12959687, 2.36229849, -0.25786853, 0.03494137, 1.70790768, -0.02720823, 0.57822198, 1.57692003],
|
||||||
[1.14234805, -0.62868208, -0.21298689, 0.00263968, -0.66115338, -1.12038326, 0.93599045, 0.77646011, -0.22770278, 1.43982041, 0.96078646, 1.15076077, -0.45110813, 0.83090556],
|
[-0.68229634, 0.87380433, 1.03171849, 0.35238963, -0.78998542, 0.97562903, -0.80616480, 1.07170749, -0.79917014, -0.43357334, 1.09133816, 0.49446958, 1.07970095, 0.27838916],
|
||||||
[0.89638984, -0.69683450, -0.29400119, 0.94997799, 0.90305328, -0.80215877, -0.09983492, -0.90757453, -0.03181892, 1.00702441, -0.97962254, -0.89580274, 0.69299418, -0.75975400],
|
[-0.77235895, -0.66010702, -0.09969614, -0.38052577, -0.77211934, -0.73416811, -0.67031443, 0.62016815, 0.97461295, 1.07167208, -0.68821293, 0.51563287, -0.73027885, -0.14203216],
|
||||||
[-0.75832003, -0.07210776, 0.07825917, 1.51633596, 0.44593197, 0.00936707, -0.12142835, -0.09877282, 0.06229200, 1.25678349, 0.25317946, 0.54112315, -0.17941843, 0.93283361],
|
[0.90449816, -0.23423387, 1.11039567, 0.61329746, -0.21385542, 0.52449727, 0.42514217, 0.42172486, -0.33397049, 0.35888657, -0.54074812, 0.48481938, -0.05116262, -0.23848286],
|
||||||
[0.23085761, 0.53307736, 0.38696140, 0.36798462, 0.38192499, 0.23203450, 0.68225187, 0.47096270, -4.24785280, 0.18062039, 0.60047084, 0.16251479, -0.10811257, 0.48166662],
|
[0.67948169, 0.50562781, 0.45344570, 0.47307885, -0.44913152, -0.11515936, 0.14361705, -0.36479098, -0.32777452, 0.11798909, -0.57137913, 0.30936614, 0.31339252, 0.51131296],
|
||||||
[0.10870802, 0.01576116, 0.00298645, 0.25878090, -0.16634797, 0.15850464, -0.24267951, 0.87678236, -0.27257833, 0.78637868, -0.00851476, 0.01502728, 0.92175138, -0.81292266],
|
[-0.25677630, 0.25580657, -0.12398625, 0.24844812, 0.18556698, 0.21818036, 0.58248550, 0.50517905, 0.34329867, 0.15851928, -0.58440667, 0.33611965, 0.67439252, -0.52770680],
|
||||||
[-0.74364990, -0.63139439, -0.18314177, -0.36881343, -0.53096825, -0.92442876, -0.05536628, -0.71273297, -0.94937468, -0.03863344, -0.09668982, -1.07886386, 0.58555382, 0.23351164],
|
[0.66840053, -0.49819222, 0.29022828, 0.10492916, -0.06216156, 0.37093312, -0.24731418, 0.22893915, 0.32447502, 0.63166237, -0.13788179, 0.52650315, 0.15229015, 0.23656118],
|
||||||
[-0.09152136, 0.96538877, -0.11560653, -0.53110164, 0.89070886, 0.05664408, -0.71661353, 0.79684436, -0.00206013, 0.23857179, 0.06074178, -0.67188424, -0.15624331, 0.43436247],
|
[0.33978519, 0.15498674, -0.25265032, -0.42916322, 0.69121236, 0.20443739, 0.54050952, 0.08900955, -0.13801514, 0.25456557, -0.10714018, 0.08712567, 0.27245566, -0.29683220],
|
||||||
[-0.28189376, -0.00535834, 0.60541785, 0.82968009, -0.21901314, -0.29874969, -0.16872653, 0.45570841, -0.25372767, -0.12359514, -1.10104620, 0.00162374, 0.07622175, 0.60413152],
|
[-0.05526243, -1.17294025, 0.07328646, -0.07892461, 0.31488195, -0.01112767, 0.55462092, 0.65152955, -0.10721418, -0.99451303, 0.00110284, -0.53097665, 0.14362922, 0.17380728],
|
||||||
[-1.13819373, -0.41320390, -5.57348347, 0.40931624, -1.59562767, 0.72510892, 0.03248254, 0.00407641, 0.57557869, 0.53510398, -0.35943517, 0.52707136, 0.61220711, -0.11644226],
|
[-2.95768332, 0.46451911, 0.20220210, 0.76858771, 0.13804838, -0.80371422, -0.11160404, -0.15160939, 0.31488597, -0.10203149, 0.16458754, -0.08558689, -0.27082649, 0.03877234],
|
||||||
[-0.02057049, 0.42545527, 0.24192038, 0.29863021, -0.22839858, -0.25318733, 0.17906551, -0.29471490, -0.04746799, 0.15909556, -0.26826856, -0.06874973, -0.03044286, 0.11770450],
|
[-0.14654562, -0.70086712, 0.09809728, -0.60966188, 0.26278028, 0.07354698, 0.08616283, 0.36018923, 0.07040872, 0.41008693, 0.13071685, 0.18236822, 0.43306109, -0.10742717],
|
||||||
[-0.18060833, -0.06301155, 0.01656315, -0.40476608, -0.35056075, 0.06344713, 0.32273614, -0.04382812, -0.18925793, 0.02124963, -0.23447622, 0.29704437, 0.19138981, -0.04584064],
|
[0.41488791, -0.10255218, 0.10218169, 0.21971215, -0.05527666, -0.50265622, 0.06767768, -0.09040122, 0.16871217, -0.02748547, 0.21738021, 0.21068999, 0.10562737, -0.71913630],
|
||||||
[0.18248987, 0.05461208, -0.25655189, 0.16673982, 0.03251073, 0.05709980, 0.09135589, 0.06712578, -0.02372392, 0.00487196, -0.11774579, 0.34203079, 0.18477952, 0.09847298],
|
[0.09367306, -0.14113051, -0.44151428, -0.05189204, 0.22411002, -0.09538609, 0.17464676, -0.30709952, 0.21021855, -0.27705607, 0.17645715, 0.19070518, 0.18094100, 0.10115600],
|
||||||
[-0.08292723, -0.03089223, 0.19555064, -0.18158682, -0.32060555, 0.18836822, -0.14625609, -0.83500093, -0.09893667, 0.02719803, 0.06864946, 0.00156752, 0.04342323, 0.30958080],
|
[-0.11084171, -0.60070217, 0.10072551, -0.09865215, 0.19512057, -0.32474023, 0.14499906, 0.06266983, -0.15383074, 0.10347557, -0.10143858, -0.09821036, -0.19187087, -0.21955618],
|
||||||
[-0.21274266, 0.06035644, 0.27282441, -0.01010289, -0.05599894, 0.27938741, -0.23254848, -0.20086342, -0.06775926, -0.18059292, 0.92534143, 0.09500337, 0.11612320, -0.06473339],
|
[0.09774263, 0.21607652, -0.22068830, -0.73502982, 0.14551027, -0.00246539, -0.32017741, -0.14855191, 0.15684886, -0.21544383, -0.36595181, 1.57917106, 0.45341989, 0.64960009],
|
||||||
[-0.27279299, 0.96252358, 0.67542273, 0.64720130, 0.15221471, 1.67354584, 0.53074431, 0.65513390, 0.79840666, 0.78613347, 0.34742561, -1.83272552, 0.73313516, 0.09797212],
|
[0.03760023, 0.12075356, -0.24193284, 0.16418910, -0.13468136, 0.40612614, 0.44222566, 0.17999728, 0.37591749, 0.67439985, -0.29388478, -0.20486754, 0.20614263, -2.63525987],
|
||||||
[-0.08888317, 0.14851266, 1.00953877, 0.19915256, -0.10076691, 0.47210938, 0.04427565, 0.19299655, 0.58729172, 0.17481442, -0.57466495, -0.16196120, 0.06293163, 1.73905540],
|
[-0.10479005, -0.17017230, -0.42374054, 0.30094361, 0.28561834, 0.40433934, -0.03086211, 1.49869466, -0.41601327, 0.20835553, 1.19875181, -0.00222666, 0.51400107, 0.27829245],
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const B1: [f32; 32] = [
|
pub const B1: [f32; 32] = [
|
||||||
-0.80723554, 0.54879200, 0.01237706, -0.22279924, 0.93692911, 0.12226531, -0.54665250, -0.49958101,
|
0.76754266, -0.52365464, -0.07451479, -0.24194083, 0.81372803, -0.14967601, 0.86968440, -0.11282827,
|
||||||
-0.20918398, -0.48646352, -0.58741039, -0.50572610, -0.04772990, -0.62962151, -0.46279392, 1.14840722,
|
0.82378083, 0.03708726, -0.14121835, -0.33332673, -0.24595253, -0.20005627, 0.80769247, 0.67842513,
|
||||||
-0.04871057, -0.31787100, 1.13966286, 0.69543558, -0.17798270, 0.66968435, -0.07442535, -0.70557600,
|
0.62225562, 0.55104679, 0.87356585, -0.16369765, 0.83232063, -0.40881905, -0.02851989, -0.04714838,
|
||||||
0.79021728, 0.65736526, -0.30761406, 0.63242179, 0.83297908, -0.04573143, -0.18454255, -0.30583009
|
0.69236869, -0.30938062, 0.87852216, -0.14689557, -0.52630597, -0.22946648, -0.13811214, -0.41019145
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const W2: [f32; 32] = [
|
pub const W2: [f32; 32] = [
|
||||||
1.09615684, -0.57856798, -0.08730038, -0.06425755, -0.96232760, -2.06290460, 0.70097560, 0.85189444,
|
-0.84622073, 2.32144451, 0.70330697, 0.89360833, -1.08053613, 0.69213301, -1.07218480, 0.82345659,
|
||||||
-0.10077959, 1.94375157, 0.74497795, 0.88425481, 2.11908054, 0.85526127, 0.61624259, -2.93621016,
|
-1.11953294, -2.58824420, 0.81520051, 1.19865966, 0.91804677, 1.04554057, -1.03049874, -0.94034135,
|
||||||
1.52211487, 0.56318259, -3.15219641, -0.55187315, 1.61819077, -0.76258671, -0.09362544, 0.86861998,
|
-1.66193688, -1.53192282, -1.09629154, -4.07772017, -1.14778209, 1.15202129, 0.42650393, 0.55174673,
|
||||||
-0.79028755, -0.90605170, 0.33475992, -0.79945564, -1.16680586, 0.15120529, 0.17619221, 1.61664009
|
-1.28319669, 2.06129408, -1.10220599, 0.09728605, 1.64764512, -0.14975634, 0.79428691, 1.56726408
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const B2: f32 = -0.52729088;
|
pub const B2: f32 = -0.58103424;
|
||||||
|
|
||||||
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
|
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
|
||||||
(3, 0.30015790, 1, 2),
|
(10, 0.13999981, 1, 2),
|
||||||
(255, 0.00000000, 0, 0),
|
|
||||||
(255, 1.00000000, 0, 0),
|
(255, 1.00000000, 0, 0),
|
||||||
|
(255, 0.00000000, 0, 0),
|
||||||
];
|
];
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Auto-generated weights for the scanner ensemble.
|
//! Auto-generated weights for the scanner ensemble.
|
||||||
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-scanner-mlp`.
|
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-scanner-mlp`.
|
||||||
|
|
||||||
@@ -14,55 +17,55 @@ pub const NORM_MAXS: [f32; 12] = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
pub const W1: [[f32; 12]; 32] = [
|
pub const W1: [[f32; 12]; 32] = [
|
||||||
[2.25848985, 1.62502551, 1.05068624, -0.23875977, 1.29692984, -1.34665418, -1.29937541, 1.66119707, -1.43897200, 0.07720046, -1.17116165, 1.96821272],
|
[-0.24775992, -1.12687504, -0.84061003, 1.35276508, 1.04677176, 1.57832956, 1.47995067, 1.38580477, -0.99564040, -1.20309269, -0.24385734, 1.32367671],
|
||||||
[2.30885172, 0.02477695, -0.23236598, 1.66507626, -1.41407740, 1.88616431, 1.84703696, -1.46395433, 2.03542018, 1.68318951, 2.01550031, 1.94223917],
|
[-1.31796920, -0.12229957, 0.89794689, 1.14832735, 1.17210162, 1.32387733, 1.37799740, 1.22984815, 0.92816162, 1.45189691, 0.97822803, -0.89625973],
|
||||||
[2.29420924, 1.86615539, 1.69271469, 1.42137837, 1.43151915, 1.84876072, 1.09228194, 1.73608077, 0.20805965, 0.52542430, -0.02558800, 0.04718366],
|
[-1.46790743, -1.43995631, 1.44276273, 1.30967343, 1.37424576, -1.16613543, 1.46063673, -1.29701447, 0.20172349, -0.05374112, -0.02462341, 0.25295240],
|
||||||
[0.36484259, -0.02785611, -0.01155548, 0.08577330, -0.00468449, -0.07848717, 0.05191587, 0.50796396, 0.40799347, -0.14838840, -0.30566201, 0.00758083],
|
[0.57722956, 0.39603359, -0.03453349, -0.06384056, 0.08056496, 0.30015573, -0.19045275, 0.39737019, 0.41070747, 0.03414693, -0.25483453, 0.24018581],
|
||||||
[0.28191370, 0.20945202, 0.07742970, -0.06654347, 0.17395714, 0.00011351, 0.37079588, 0.41817516, 0.56992871, 0.05705916, 0.22339216, 0.11021475],
|
[0.02473495, 0.05047939, 0.59600842, 0.35709319, -0.32843903, 0.17655721, 0.20441811, 0.05118980, 0.24312067, 0.12677322, 0.12109834, -0.09230414],
|
||||||
[0.06522971, 0.64510870, 0.31671444, 0.34980071, 0.03446164, -0.10592904, -0.21302676, -0.04404496, 0.08638768, 0.04217484, 0.43021953, 0.21055792],
|
[-0.03397392, -0.08393703, 0.42762470, 0.21609549, -0.28864771, -0.11998425, -0.05842599, 0.87331069, 0.52526826, 0.35321006, 0.31293729, 0.51823896],
|
||||||
[0.31206250, -0.14565454, 0.38078794, 0.00860748, 0.29409558, -0.11273954, -0.02210701, 0.15525217, 0.09696059, 0.13877581, 0.06483351, 0.10946950],
|
[-0.26906690, -0.04246650, 0.05114691, 0.27629042, -0.37589836, -0.07396606, -0.46509269, 0.26079515, 0.33588448, 0.55919188, 0.58768439, 0.81638134],
|
||||||
[0.28374705, -0.02963164, 0.27863786, -0.23428085, 0.12715313, 0.09141072, 0.07769041, 0.01915955, -0.20936646, 0.02813511, -0.03910714, 0.30322370],
|
[-0.48464304, 0.76111192, 0.06296005, -0.13527128, -0.41344830, -0.19461812, 0.52815431, 0.96815300, 0.47175986, -0.13977771, 0.41216326, -0.03041369],
|
||||||
[-1.19449413, -0.84935474, -0.32267663, -0.08140022, -0.78729230, 1.58759272, 0.88281459, -0.77263606, 1.55394125, 0.10148179, 1.59524822, -0.75499195],
|
[-0.05156134, 0.90191734, 0.78513384, -1.25017786, -0.54637259, -0.36098620, -0.59882820, -0.90374511, 0.91336489, 1.05806208, 0.04994302, -0.91028821],
|
||||||
[-0.97152823, -0.12173092, 0.04745778, -0.85466659, 1.57352293, -0.52651149, -0.66270715, 1.32282484, -1.24654925, -0.45822921, -1.10187364, -0.91162699],
|
[0.68376511, 0.07305489, 0.44790089, -0.56647295, -0.66570538, -0.93897015, -0.40558589, -1.51070845, 0.46759781, -1.36738360, -0.05270236, 0.98130196],
|
||||||
[-0.93944395, -0.57891464, -1.12100291, -0.38871467, -0.18780440, -1.11835766, -0.43614236, -1.07918274, -0.09222561, -0.23854440, -0.16720718, 0.03247443],
|
[0.86788160, 0.69321704, -0.53778958, -1.54257190, -0.44623125, 0.72615588, -0.75269628, 0.81946337, 0.17503875, 0.63745797, 0.48478079, -0.31573632],
|
||||||
[0.13319625, 0.87437463, 0.32213065, 0.13902900, 0.64760798, 0.00899744, 0.45325586, -0.14138180, 0.13888212, 0.07780524, -0.12482210, 0.12632932],
|
[-0.01361719, 0.21524477, -0.10345778, -0.38488832, 0.42967409, 0.75472528, -0.07410870, -0.65231675, 0.42633417, -0.10289414, 0.09583388, -0.29391766],
|
||||||
[0.57018995, -0.10839911, 0.02787536, 0.16884641, 0.19435850, -0.01189608, 0.13881874, -0.10700739, -0.05463003, 0.01371983, 0.04385772, 0.01100468],
|
[-0.06778818, -0.44469842, 0.05952910, -0.55139810, 0.14308600, -0.53731138, -0.07426350, 0.28065708, 0.29584157, 0.47813708, 0.02095048, -0.36458421],
|
||||||
[-0.26600277, -0.11843663, -0.01081531, 0.10785927, -0.18684258, 0.08537511, 0.01054722, -0.01972559, -0.07416820, 0.57192892, 0.37873995, -0.00498434],
|
[-0.21983556, 0.55435538, -0.13939659, 0.58281261, -0.20551582, 0.30075905, 0.13396217, -0.18145087, -0.43283740, 0.18541494, 0.07530790, -0.04916608],
|
||||||
[0.72535324, -0.25030360, 0.51470703, -0.16410951, -0.13649474, 0.16246459, -0.27847841, 0.12250750, 0.45576489, -0.18535912, -0.45686084, 0.58293521],
|
[0.46939296, 0.57935077, -0.05478116, -0.01144989, 0.54106784, -0.18313073, 0.12232503, -0.32802504, -0.01167463, -0.13702804, -0.19521871, 0.09115479],
|
||||||
[0.18614589, -0.32835677, -0.08683094, 0.07748202, -0.24785264, -0.16834147, 0.27066526, 0.06058804, 0.01903199, -0.17387865, 0.12752151, -0.03780220],
|
[0.25263065, -0.27172634, -0.12802953, 0.50027740, 0.05213343, 0.49081728, 0.15367918, 0.20471051, -0.22081012, 0.51709008, 0.01776243, 0.22513707],
|
||||||
[-1.22358644, -0.78316134, -0.54068804, -0.07921790, -0.72697675, 1.80127227, 0.14326867, -0.51875746, 1.83125353, -0.02672976, 1.68589675, -0.80162954],
|
[-0.20733139, 0.60041994, 0.05273124, 0.07473211, 0.14580894, -0.72007078, -0.52350652, -0.15482022, 0.19132918, 0.52586436, -0.04793828, -0.00479114],
|
||||||
[-0.83690810, -0.12682360, 0.10783038, -0.64648604, 1.50810242, -0.48788729, -0.59418935, 0.94863659, -0.84788662, -0.49779284, -0.96408021, -1.14068258],
|
[0.19872986, -0.19177110, -0.22340146, -0.48786804, -0.51010352, -0.55363113, -0.29520389, -0.21378680, -0.40099174, -0.09184421, -0.08521358, 0.61833692],
|
||||||
[-0.96322638, -0.50503486, -0.87195945, -0.34710455, -0.28645220, -1.10507452, -0.32122782, -0.80753750, -0.00843489, 0.04215550, 0.03197355, 0.05468401],
|
[0.21346046, 0.53319895, -0.44765636, -0.04764151, -0.30569363, 0.19765340, -0.41479719, 0.34292534, -0.29234713, 0.54341668, 0.60121793, -0.00226344],
|
||||||
[-0.17587705, 0.45144933, 0.37954769, -0.15405300, 0.75590396, 0.00346784, 0.62332457, -0.15602241, -0.26471916, -0.19963606, -0.22497311, -0.20784236],
|
[-0.29598647, -0.37357926, -0.25650844, -0.05165816, 0.55829030, 0.21028350, -0.28581545, -0.37299931, 0.57590896, -0.01573592, -0.19411144, 0.13814686],
|
||||||
[0.60608941, -0.05316854, 0.03766245, 0.46412235, -0.41121334, -0.01225545, -0.11125158, -0.33533856, -0.04625564, -0.02995013, -0.24979964, -0.35824969],
|
[0.10028259, 0.02526089, -0.20488358, 0.25667843, 0.17100072, 0.01034015, -0.32994771, 0.53425753, 0.64935833, 0.30769956, -0.26756367, -0.03389005],
|
||||||
[0.08163761, 0.04702193, -0.24007457, -0.23439978, 0.27066308, 0.48389259, 0.32692793, -0.23089454, 0.26520243, -0.14099684, 0.06713670, 0.14434725],
|
[-0.18916467, 0.38340616, -0.16475976, 0.59811211, 0.12739281, -0.16611671, -0.31913927, 0.07577144, 0.28552490, 0.54843456, 0.40937552, 0.38236183],
|
||||||
[-0.50808382, -0.14518137, -0.23912378, 0.33510539, 0.46566108, 0.09035082, -0.12637842, 0.55245715, -0.19972627, 0.24517706, 0.34291887, 0.01936621],
|
[-0.20519450, -0.04122134, -0.20013523, 0.42193425, -0.27304563, -0.21811043, 0.13115846, 0.16724831, 0.13073303, 0.20491999, 0.31806493, 0.13444173],
|
||||||
[0.35826349, 0.21200819, 0.65315312, -0.16792546, 0.41378024, 0.32129642, 0.50814188, 0.48289016, 0.06839173, 0.42079177, 0.52295685, 0.26273951],
|
[0.01762132, 0.32608625, 0.19381267, -0.33404192, -0.46299583, -0.28042898, 0.20772585, 0.20139317, 0.41952321, -0.30363685, 0.20015827, -0.03338646],
|
||||||
[0.24575019, 0.10700949, 0.07041252, -0.09410189, 0.18897925, 0.31616825, -0.01306109, 0.33499330, -0.01866218, 0.06233863, 0.15316568, 0.08370106],
|
[0.13760759, 0.07168494, 0.26161709, 0.41468662, -0.03778528, 0.38290465, 0.48780030, 0.39562985, 0.24758396, -0.05975538, -0.22738078, 0.27877593],
|
||||||
[0.17828286, 0.17363867, -0.10626584, 0.06075979, 0.39465010, 0.19557165, 0.30352867, 0.26720291, 0.40256795, 0.13942246, 0.05869288, 0.08310238],
|
[0.07016940, -0.03804595, -0.08812129, 0.19664441, 0.13347355, 0.50309300, 0.26076415, 0.19044210, -0.20414594, 0.64333421, 0.15160090, 0.16449226],
|
||||||
[-0.04834138, 0.29206491, 0.01330532, 0.07626399, -0.17378819, 0.09515948, 0.02298534, 0.41555724, 0.09492048, 0.39422533, 0.39373979, 0.20463347],
|
[0.31039700, -0.01906084, 0.25622010, 0.10707659, 0.54883337, 0.19277412, 0.42004701, -0.09319381, 0.19968294, 0.07109389, -0.28979829, 0.12353907],
|
||||||
[-0.11641891, -0.06529939, -0.18899654, -0.02157970, -0.03554495, 0.10956290, -0.11688691, 0.04077352, 0.34220406, -0.09558969, 0.16150762, 0.25759667],
|
[0.28500485, 0.01991569, 0.05190456, 0.29366553, 0.01045146, -0.02013574, -0.01796320, 0.13775185, 0.11095868, -0.25678155, 0.10733776, -0.07584792],
|
||||||
[-0.17313123, 0.00591523, 0.29443163, 0.08298909, 0.07761172, 0.19023541, 0.23826212, -0.07167042, 0.08753359, 0.17917964, -0.03248737, 0.28516129],
|
[0.12738188, 0.07762879, -0.06429479, 0.39944342, 0.07958066, 0.46697047, -0.10674930, -0.12212183, -0.01540831, 0.08788434, 0.17299946, 0.25846422],
|
||||||
[0.13091524, 0.21435370, 0.15093684, 0.30902347, 0.44151527, 0.55901742, 0.19933179, 0.06438518, 0.30585650, -0.34089112, 0.26879075, 0.12928906],
|
[0.26692817, 0.00930361, 0.24862845, 0.02167275, -0.09902105, -0.35391217, -0.41734406, 0.44949567, 0.46330830, 0.40603620, 0.08397861, 0.39809385],
|
||||||
[-0.25311065, -0.09963353, -0.50099874, 0.57481062, 0.38744658, -0.13065037, 0.18897361, 0.49376330, -0.15626629, 0.19911517, 0.06437352, -0.09104283],
|
[-0.30756459, -0.43368185, -0.00478506, 0.45611116, -0.05069341, 0.21090019, 0.28219289, 0.07687758, 0.54915971, 0.46933413, 0.35599890, 0.17573997],
|
||||||
[0.35787049, -0.04814727, 0.45446551, -0.15264697, 0.36565515, 0.22795495, 0.24630190, 0.16362202, 0.21044184, 0.53882843, 0.42343852, 0.18454899],
|
[-0.19320646, 0.44751191, -0.14140815, 0.00427075, -0.19792002, -0.19400074, 0.19292155, 0.39845818, 0.21028778, -0.10284913, 0.31191504, -0.36995885],
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const B1: [f32; 32] = [
|
pub const B1: [f32; 32] = [
|
||||||
1.12135851, 0.64268047, 0.44761124, -0.28471574, 0.70866716, -0.25293177, -0.19119856, 0.39284116,
|
-0.24357778, -0.24826238, -0.03415382, 0.00968227, 0.51550633, 0.45242083, 0.60654080, 0.25456131,
|
||||||
-0.20628852, -0.29301032, -0.08837436, 0.92048728, 0.91167349, -0.33615190, -0.06016272, 0.79141164,
|
-0.36509025, -0.22825000, 0.03829522, 0.65561563, -0.19379658, -0.25716159, 0.45115772, 0.73442084,
|
||||||
-0.43257964, 0.48180589, 0.70891160, -0.24290052, 0.83115542, 0.69964927, 0.97887653, 1.34517038,
|
0.61352992, 0.59502298, 0.32757106, 0.28512844, 0.26663530, 0.27169749, 0.33571365, -0.34503689,
|
||||||
1.10292709, 0.42009205, 1.07155228, 0.61349720, 0.46157768, 1.01911950, 0.51159418, 0.60460496
|
-0.08054741, -0.06313029, 0.43629149, 0.35936099, 0.39375633, -0.19984132, 0.49092621, -0.27418151
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const W2: [f32; 32] = [
|
pub const W2: [f32; 32] = [
|
||||||
1.55191231, 1.27754235, 0.43588921, 0.10868450, 0.55931729, -1.46911597, -0.54461092, 0.78240824,
|
-0.04975564, -0.54667705, -0.39323062, 0.72362727, 0.86801738, 1.93621075, 1.01259410, 0.75978750,
|
||||||
-1.25938582, -0.06287600, -1.02053738, 1.07076716, 1.58776867, -0.03168033, -0.11393511, 1.30535436,
|
-0.67997259, -0.63063931, -0.07149173, 0.81899148, -0.69025612, -0.12359849, 1.09533453, 0.88092262,
|
||||||
-1.46621227, 0.62925971, 0.76781118, -0.74480098, 1.29669034, 0.62078375, 1.64134884, 2.09736991,
|
0.89678788, 0.87908030, 1.12460852, 0.76745653, 0.85632098, 0.72992527, 0.93983871, -0.55915666,
|
||||||
1.52834618, 0.87368065, 1.80090642, 0.89230227, 0.38757962, 1.80718291, 0.64923352, 1.18709576
|
-0.61104172, -0.56369978, 1.43480921, 0.71174467, 1.03119624, -0.57950914, 0.81188917, -0.78019017
|
||||||
];
|
];
|
||||||
|
|
||||||
pub const B2: f32 = 0.23270580;
|
pub const B2: f32 = 0.24626280;
|
||||||
|
|
||||||
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
|
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
|
||||||
(3, 0.50000000, 1, 2),
|
(3, 0.50000000, 1, 2),
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Replay audit logs through the ensemble models (scanner + DDoS).
|
//! Replay audit logs through the ensemble models (scanner + DDoS).
|
||||||
|
|
||||||
use crate::ddos::audit_log::{self, AuditLog};
|
use crate::audit::AuditLogLine;
|
||||||
|
use crate::ddos::audit_log;
|
||||||
use crate::ddos::features::{method_to_u8, LogIpState};
|
use crate::ddos::features::{method_to_u8, LogIpState};
|
||||||
use crate::ddos::model::DDoSAction;
|
use crate::ddos::model::DDoSAction;
|
||||||
use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath};
|
use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath};
|
||||||
@@ -26,21 +30,31 @@ pub fn run(args: ReplayEnsembleArgs) -> Result<()> {
|
|||||||
std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?;
|
std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?;
|
||||||
let reader = std::io::BufReader::new(file);
|
let reader = std::io::BufReader::new(file);
|
||||||
|
|
||||||
// --- Parse all entries ---
|
// --- Parse all entries, filtering for audit logs only ---
|
||||||
let mut entries: Vec<AuditLog> = Vec::new();
|
let mut entries: Vec<AuditLogLine> = Vec::new();
|
||||||
let mut parse_errors = 0u64;
|
let mut skipped_non_audit = 0u64;
|
||||||
|
let mut schema_errors = 0u64;
|
||||||
for line in reader.lines() {
|
for line in reader.lines() {
|
||||||
let line = line?;
|
let line = line?;
|
||||||
if line.trim().is_empty() {
|
if line.trim().is_empty() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
match serde_json::from_str::<AuditLog>(&line) {
|
match AuditLogLine::try_parse(&line) {
|
||||||
Ok(e) => entries.push(e),
|
Ok(Some(entry)) => entries.push(entry),
|
||||||
Err(_) => parse_errors += 1,
|
Ok(None) => skipped_non_audit += 1,
|
||||||
|
Err(e) => {
|
||||||
|
schema_errors += 1;
|
||||||
|
if schema_errors <= 3 {
|
||||||
|
eprintln!(" schema error: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let total = entries.len() as u64;
|
let total = entries.len() as u64;
|
||||||
eprintln!("parsed {} entries ({} parse errors)\n", total, parse_errors);
|
eprintln!(
|
||||||
|
"parsed {} audit entries ({} non-audit skipped, {} schema errors)\n",
|
||||||
|
total, skipped_non_audit, schema_errors,
|
||||||
|
);
|
||||||
|
|
||||||
// --- Scanner replay ---
|
// --- Scanner replay ---
|
||||||
eprintln!("═══ Scanner Ensemble ═════════════════════════════════════");
|
eprintln!("═══ Scanner Ensemble ═════════════════════════════════════");
|
||||||
@@ -54,7 +68,7 @@ pub fn run(args: ReplayEnsembleArgs) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn replay_scanner(entries: &[AuditLog]) {
|
fn replay_scanner(entries: &[AuditLogLine]) {
|
||||||
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
||||||
.iter()
|
.iter()
|
||||||
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
||||||
@@ -73,23 +87,15 @@ fn replay_scanner(entries: &[AuditLog]) {
|
|||||||
let mut blocked = 0u64;
|
let mut blocked = 0u64;
|
||||||
let mut allowed = 0u64;
|
let mut allowed = 0u64;
|
||||||
let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
|
let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
|
||||||
let mut blocked_examples: Vec<(String, String, f64)> = Vec::new(); // (path, reason, score)
|
let mut blocked_examples: Vec<(String, String, String, f64)> = Vec::new(); // (path, ua, reason, score)
|
||||||
let mut fp_candidates: Vec<(String, u16, f64)> = Vec::new(); // blocked but had 2xx status
|
let mut fp_candidates: Vec<(String, String, u16, f64)> = Vec::new(); // blocked but had 2xx status
|
||||||
|
|
||||||
for e in entries {
|
for e in entries {
|
||||||
let f = &e.fields;
|
let f = &e.fields;
|
||||||
let host_prefix = f.host.split('.').next().unwrap_or("");
|
let host_prefix = f.host.split('.').next().unwrap_or("");
|
||||||
let has_cookies = f.has_cookies.unwrap_or(false);
|
let has_cookies = f.has_cookies;
|
||||||
let has_referer = f
|
let has_referer = !f.referer.is_empty() && f.referer != "-";
|
||||||
.referer
|
let has_accept_language = !f.accept_language.is_empty() && f.accept_language != "-";
|
||||||
.as_ref()
|
|
||||||
.map(|r| r != "-" && !r.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let has_accept_language = f
|
|
||||||
.accept_language
|
|
||||||
.as_ref()
|
|
||||||
.map(|a| a != "-" && !a.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let feats = features::extract_features_f32(
|
let feats = features::extract_features_f32(
|
||||||
&f.method,
|
&f.method,
|
||||||
@@ -98,7 +104,7 @@ fn replay_scanner(entries: &[AuditLog]) {
|
|||||||
has_cookies,
|
has_cookies,
|
||||||
has_referer,
|
has_referer,
|
||||||
has_accept_language,
|
has_accept_language,
|
||||||
"-",
|
&f.accept,
|
||||||
&f.user_agent,
|
&f.user_agent,
|
||||||
f.content_length,
|
f.content_length,
|
||||||
&fragment_hashes,
|
&fragment_hashes,
|
||||||
@@ -121,12 +127,13 @@ fn replay_scanner(entries: &[AuditLog]) {
|
|||||||
if blocked_examples.len() < 20 {
|
if blocked_examples.len() < 20 {
|
||||||
blocked_examples.push((
|
blocked_examples.push((
|
||||||
f.path.clone(),
|
f.path.clone(),
|
||||||
|
f.user_agent.clone(),
|
||||||
verdict.reason.to_string(),
|
verdict.reason.to_string(),
|
||||||
verdict.score,
|
verdict.score,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if (200..400).contains(&f.status) {
|
if (200..400).contains(&f.status) {
|
||||||
fp_candidates.push((f.path.clone(), f.status, verdict.score));
|
fp_candidates.push((f.path.clone(), f.user_agent.clone(), f.status, verdict.score));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ScannerAction::Allow => allowed += 1,
|
ScannerAction::Allow => allowed += 1,
|
||||||
@@ -159,8 +166,9 @@ fn replay_scanner(entries: &[AuditLog]) {
|
|||||||
|
|
||||||
if !blocked_examples.is_empty() {
|
if !blocked_examples.is_empty() {
|
||||||
eprintln!("\n blocked examples (first 20):");
|
eprintln!("\n blocked examples (first 20):");
|
||||||
for (path, reason, score) in &blocked_examples {
|
for (path, ua, reason, score) in &blocked_examples {
|
||||||
eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50));
|
eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50));
|
||||||
|
eprintln!(" ua: {}", truncate(ua, 72));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,16 +178,17 @@ fn replay_scanner(entries: &[AuditLog]) {
|
|||||||
"\n potential false positives (blocked but had 2xx/3xx): {}",
|
"\n potential false positives (blocked but had 2xx/3xx): {}",
|
||||||
fp_count
|
fp_count
|
||||||
);
|
);
|
||||||
for (path, status, score) in fp_candidates.iter().take(10) {
|
for (path, ua, status, score) in fp_candidates.iter().take(10) {
|
||||||
eprintln!(
|
eprintln!(
|
||||||
" {:<50} status={status} score={score:.3}",
|
" {:<50} status={status} score={score:.3}",
|
||||||
truncate(path, 50)
|
truncate(path, 50)
|
||||||
);
|
);
|
||||||
|
eprintln!(" ua: {}", truncate(ua, 72));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) {
|
fn replay_ddos(entries: &[AuditLogLine], window_secs: f64, min_events: usize) {
|
||||||
fn fx_hash(s: &str) -> u64 {
|
fn fx_hash(s: &str) -> u64 {
|
||||||
let mut h = rustc_hash::FxHasher::default();
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
s.hash(&mut h);
|
s.hash(&mut h);
|
||||||
@@ -208,18 +217,12 @@ fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) {
|
|||||||
.push(f.content_length.min(u32::MAX as u64) as u32);
|
.push(f.content_length.min(u32::MAX as u64) as u32);
|
||||||
state
|
state
|
||||||
.has_cookies
|
.has_cookies
|
||||||
.push(f.has_cookies.unwrap_or(false));
|
.push(f.has_cookies);
|
||||||
state.has_referer.push(
|
state.has_referer.push(
|
||||||
f.referer
|
!f.referer.is_empty() && f.referer != "-",
|
||||||
.as_deref()
|
|
||||||
.map(|r| r != "-")
|
|
||||||
.unwrap_or(false),
|
|
||||||
);
|
);
|
||||||
state.has_accept_language.push(
|
state.has_accept_language.push(
|
||||||
f.accept_language
|
!f.accept_language.is_empty() && f.accept_language != "-",
|
||||||
.as_deref()
|
|
||||||
.map(|a| a != "-")
|
|
||||||
.unwrap_or(false),
|
|
||||||
);
|
);
|
||||||
state
|
state
|
||||||
.suspicious_paths
|
.suspicious_paths
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use crate::scanner::model::{ScannerAction, ScannerVerdict};
|
use crate::scanner::model::{ScannerAction, ScannerVerdict};
|
||||||
use super::gen::scanner_weights;
|
use super::gen::scanner_weights;
|
||||||
use super::mlp::mlp_predict_32;
|
use super::mlp::mlp_predict_32;
|
||||||
@@ -92,27 +95,11 @@ impl From<EnsembleVerdict> for ScannerVerdict {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_tree_allow_path() {
|
|
||||||
// All features at zero → feature 3 (suspicious_ua) = 0.0 <= 0.65 → left (node 1)
|
|
||||||
// feature 0 (path_depth) = 0.0 <= 0.40 → left (node 3) → Allow leaf
|
|
||||||
let raw = [0.0f32; 12];
|
|
||||||
let v = scanner_ensemble_predict(&raw);
|
|
||||||
assert_eq!(v.action, ScannerAction::Allow);
|
|
||||||
assert_eq!(v.path, EnsemblePath::TreeAllow);
|
|
||||||
assert_eq!(v.reason, "ensemble:tree_allow");
|
|
||||||
assert!((v.score - 0.0).abs() < f64::EPSILON);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tree_block_path() {
|
fn test_tree_block_path() {
|
||||||
// Need: feature 3 (suspicious_ua) > 0.65 (normalized) → right (node 2)
|
// Tree: root splits on feature 7 (ua_category) at 0.75.
|
||||||
// feature 7 (payload_entropy) > 0.72 (normalized) → right (node 6) → Block
|
// All zeros → ua_category normalized = 0.0 <= 0.75 → Block (node 1)
|
||||||
// feature 3 max = 1.0, so raw 0.8 → normalized 0.8 > 0.65 ✓
|
let raw = [0.0f32; 12];
|
||||||
// feature 7 max = 8.0, so raw 6.0 → normalized 0.75 > 0.72 ✓
|
|
||||||
let mut raw = [0.0f32; 12];
|
|
||||||
raw[3] = 0.8; // suspicious_ua: normalized = 0.8/1.0 = 0.8 > 0.65
|
|
||||||
raw[7] = 6.0; // payload_entropy: normalized = 6.0/8.0 = 0.75 > 0.72
|
|
||||||
let v = scanner_ensemble_predict(&raw);
|
let v = scanner_ensemble_predict(&raw);
|
||||||
assert_eq!(v.action, ScannerAction::Block);
|
assert_eq!(v.action, ScannerAction::Block);
|
||||||
assert_eq!(v.path, EnsemblePath::TreeBlock);
|
assert_eq!(v.path, EnsemblePath::TreeBlock);
|
||||||
@@ -120,28 +107,38 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mlp_path() {
|
fn test_tree_allow_path() {
|
||||||
// Need: feature 3 > 0.65 normalized → right (node 2)
|
// Tree: root feature 7 > 0.75 → node 2, checks feature 3 (has_cookies) at 0.25.
|
||||||
// feature 7 <= 0.72 normalized → left (node 5) → Defer
|
// raw[7] = 1.0 → normalized 1.0 > 0.75 → right.
|
||||||
// Then MLP runs on the normalized input.
|
// raw[3] = 1.0 → normalized ~0.7 > 0.25 → right child node 6 → Allow leaf.
|
||||||
let mut raw = [0.0f32; 12];
|
let mut raw = [0.0f32; 12];
|
||||||
raw[3] = 0.8; // normalized = 0.8 > 0.65
|
raw[7] = 1.0; // ua_category = browser
|
||||||
raw[7] = 4.0; // normalized = 4.0/8.0 = 0.5 <= 0.72
|
raw[3] = 1.0; // has_cookies = yes
|
||||||
// Also need feature 2 (query_param_count) to navigate node 5 correctly
|
|
||||||
// node 5: split on feature 2, threshold 0.55 → left=9(Defer), right=10
|
|
||||||
// normalized feature 2 = 0.0/20.0 = 0.0 <= 0.55 → left (node 9) → Defer
|
|
||||||
let v = scanner_ensemble_predict(&raw);
|
let v = scanner_ensemble_predict(&raw);
|
||||||
assert_eq!(v.path, EnsemblePath::Mlp);
|
assert_eq!(v.action, ScannerAction::Allow);
|
||||||
assert_eq!(v.reason, "ensemble:mlp");
|
assert_eq!(v.path, EnsemblePath::TreeAllow);
|
||||||
// MLP output is deterministic for these inputs
|
assert_eq!(v.reason, "ensemble:tree_allow");
|
||||||
assert!(v.score >= 0.0 && v.score <= 1.0);
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mlp_direct() {
|
||||||
|
// Current tree has no Defer leaves, so test MLP inference directly.
|
||||||
|
let input = [0.5f32; 12];
|
||||||
|
let score = mlp_predict_32::<12>(
|
||||||
|
&scanner_weights::W1,
|
||||||
|
&scanner_weights::B1,
|
||||||
|
&scanner_weights::W2,
|
||||||
|
scanner_weights::B2,
|
||||||
|
&input,
|
||||||
|
);
|
||||||
|
assert!(score >= 0.0 && score <= 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_normalize_clamps() {
|
fn test_normalize_clamps() {
|
||||||
// Values beyond max should be clamped to 1.0
|
// Values beyond max should be clamped to 1.0
|
||||||
let mut raw = [0.0f32; 12];
|
let mut raw = [0.0f32; 12];
|
||||||
raw[0] = 100.0; // max is 10.0
|
raw[0] = 100.0;
|
||||||
let normed = normalize(&raw);
|
let normed = normalize(&raw);
|
||||||
assert!((normed[0] - 1.0).abs() < f64::EPSILON as f32);
|
assert!((normed[0] - 1.0).abs() < f64::EPSILON as f32);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
// Library crate root — exports the proxy/config/acme modules so that
|
// Library crate root — exports the proxy/config/acme modules so that
|
||||||
// integration tests in tests/ can construct and drive a SunbeamProxy
|
// integration tests in tests/ can construct and drive a SunbeamProxy
|
||||||
// without going through the binary entry point.
|
// without going through the binary entry point.
|
||||||
|
#![recursion_limit = "256"]
|
||||||
pub mod acme;
|
pub mod acme;
|
||||||
|
pub mod audit;
|
||||||
pub mod autotune;
|
pub mod autotune;
|
||||||
pub mod cache;
|
pub mod cache;
|
||||||
pub mod cluster;
|
pub mod cluster;
|
||||||
|
|||||||
423
src/main.rs
423
src/main.rs
@@ -1,10 +1,12 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
mod cert;
|
mod cert;
|
||||||
mod telemetry;
|
mod telemetry;
|
||||||
mod watcher;
|
mod watcher;
|
||||||
|
|
||||||
use sunbeam_proxy::{acme, autotune, config};
|
use sunbeam_proxy::{acme, config};
|
||||||
use sunbeam_proxy::proxy::SunbeamProxy;
|
use sunbeam_proxy::proxy::SunbeamProxy;
|
||||||
use sunbeam_proxy::ddos;
|
|
||||||
use sunbeam_proxy::rate_limit;
|
use sunbeam_proxy::rate_limit;
|
||||||
use sunbeam_proxy::scanner;
|
use sunbeam_proxy::scanner;
|
||||||
|
|
||||||
@@ -32,77 +34,18 @@ enum Commands {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
upgrade: bool,
|
upgrade: bool,
|
||||||
},
|
},
|
||||||
/// Replay audit logs through detection models
|
/// Replay audit logs through ensemble models (scanner + DDoS)
|
||||||
Replay {
|
Replay {
|
||||||
#[command(subcommand)]
|
|
||||||
mode: ReplayMode,
|
|
||||||
},
|
|
||||||
/// Train a DDoS detection model from audit logs
|
|
||||||
TrainDdos {
|
|
||||||
/// Path to audit log JSONL file
|
/// Path to audit log JSONL file
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
input: String,
|
input: String,
|
||||||
/// Output model file path
|
|
||||||
#[arg(short, long)]
|
|
||||||
output: String,
|
|
||||||
/// File with known-attack IPs (one per line)
|
|
||||||
#[arg(long)]
|
|
||||||
attack_ips: Option<String>,
|
|
||||||
/// File with known-normal IPs (one per line)
|
|
||||||
#[arg(long)]
|
|
||||||
normal_ips: Option<String>,
|
|
||||||
/// TOML file with heuristic auto-labeling thresholds
|
|
||||||
#[arg(long)]
|
|
||||||
heuristics: Option<String>,
|
|
||||||
/// KNN k parameter
|
|
||||||
#[arg(long, default_value = "5")]
|
|
||||||
k: usize,
|
|
||||||
/// Attack threshold (fraction of k neighbors)
|
|
||||||
#[arg(long, default_value = "0.6")]
|
|
||||||
threshold: f64,
|
|
||||||
/// Sliding window size in seconds
|
/// Sliding window size in seconds
|
||||||
#[arg(long, default_value = "60")]
|
#[arg(long, default_value = "60")]
|
||||||
window_secs: u64,
|
window_secs: u64,
|
||||||
/// Minimum events per IP to include in training
|
/// Minimum events per IP before DDoS classification
|
||||||
#[arg(long, default_value = "10")]
|
#[arg(long, default_value = "5")]
|
||||||
min_events: usize,
|
min_events: usize,
|
||||||
},
|
},
|
||||||
/// Train a per-request scanner detection model from audit logs
|
|
||||||
TrainScanner {
|
|
||||||
/// Path to audit log JSONL file
|
|
||||||
#[arg(short, long)]
|
|
||||||
input: String,
|
|
||||||
/// Output model file path
|
|
||||||
#[arg(short, long, default_value = "scanner_model.bin")]
|
|
||||||
output: String,
|
|
||||||
/// Directory (or file) containing .txt wordlists of scanner paths
|
|
||||||
#[arg(long)]
|
|
||||||
wordlists: Option<String>,
|
|
||||||
/// Classification threshold
|
|
||||||
#[arg(long, default_value = "0.5")]
|
|
||||||
threshold: f64,
|
|
||||||
/// Include CSIC 2010 dataset as base training data (downloaded from GitHub, cached locally)
|
|
||||||
#[arg(long)]
|
|
||||||
csic: bool,
|
|
||||||
},
|
|
||||||
/// Bayesian hyperparameter optimization for DDoS model
|
|
||||||
AutotuneDdos {
|
|
||||||
/// Path to audit log JSONL file
|
|
||||||
#[arg(short, long)]
|
|
||||||
input: String,
|
|
||||||
/// Output best model file path
|
|
||||||
#[arg(short, long, default_value = "ddos_model_best.bin")]
|
|
||||||
output: String,
|
|
||||||
/// Number of optimization trials
|
|
||||||
#[arg(long, default_value = "200")]
|
|
||||||
trials: usize,
|
|
||||||
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
|
|
||||||
#[arg(long, default_value = "1.0")]
|
|
||||||
beta: f64,
|
|
||||||
/// JSONL file to log each trial's parameters and results
|
|
||||||
#[arg(long)]
|
|
||||||
trial_log: Option<String>,
|
|
||||||
},
|
|
||||||
/// Download and cache upstream datasets (CIC-IDS2017)
|
/// Download and cache upstream datasets (CIC-IDS2017)
|
||||||
DownloadDatasets,
|
DownloadDatasets,
|
||||||
/// Prepare a unified training dataset from multiple sources
|
/// Prepare a unified training dataset from multiple sources
|
||||||
@@ -125,6 +68,12 @@ enum Commands {
|
|||||||
/// Path to heuristics.toml for auto-labeling production logs
|
/// Path to heuristics.toml for auto-labeling production logs
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
heuristics: Option<String>,
|
heuristics: Option<String>,
|
||||||
|
/// Inject CSIC 2010 dataset as labeled audit log entries
|
||||||
|
#[arg(long)]
|
||||||
|
inject_csic: bool,
|
||||||
|
/// Inject OWASP ModSec audit log entries (path to .log file)
|
||||||
|
#[arg(long)]
|
||||||
|
inject_modsec: Option<String>,
|
||||||
},
|
},
|
||||||
#[cfg(feature = "training")]
|
#[cfg(feature = "training")]
|
||||||
/// Train scanner ensemble (decision tree + MLP) from prepared dataset
|
/// Train scanner ensemble (decision tree + MLP) from prepared dataset
|
||||||
@@ -142,7 +91,7 @@ enum Commands {
|
|||||||
#[arg(long, default_value = "100")]
|
#[arg(long, default_value = "100")]
|
||||||
epochs: usize,
|
epochs: usize,
|
||||||
/// Learning rate
|
/// Learning rate
|
||||||
#[arg(long, default_value = "0.001")]
|
#[arg(long, default_value = "0.0001")]
|
||||||
learning_rate: f64,
|
learning_rate: f64,
|
||||||
/// Batch size
|
/// Batch size
|
||||||
#[arg(long, default_value = "64")]
|
#[arg(long, default_value = "64")]
|
||||||
@@ -153,6 +102,12 @@ enum Commands {
|
|||||||
/// Min purity for tree leaves (below -> Defer)
|
/// Min purity for tree leaves (below -> Defer)
|
||||||
#[arg(long, default_value = "0.90")]
|
#[arg(long, default_value = "0.90")]
|
||||||
tree_min_purity: f32,
|
tree_min_purity: f32,
|
||||||
|
/// Min samples required in a leaf node (higher = less overfitting)
|
||||||
|
#[arg(long, default_value = "2")]
|
||||||
|
min_samples_leaf: usize,
|
||||||
|
/// Weight for cookie feature (0.0=ignore, 1.0=full). Controls has_cookies influence.
|
||||||
|
#[arg(long, default_value = "1.0")]
|
||||||
|
cookie_weight: f32,
|
||||||
},
|
},
|
||||||
#[cfg(feature = "training")]
|
#[cfg(feature = "training")]
|
||||||
/// Train DDoS ensemble (decision tree + MLP) from prepared dataset
|
/// Train DDoS ensemble (decision tree + MLP) from prepared dataset
|
||||||
@@ -165,184 +120,82 @@ enum Commands {
|
|||||||
hidden_dim: usize,
|
hidden_dim: usize,
|
||||||
#[arg(long, default_value = "100")]
|
#[arg(long, default_value = "100")]
|
||||||
epochs: usize,
|
epochs: usize,
|
||||||
#[arg(long, default_value = "0.001")]
|
#[arg(long, default_value = "0.0001")]
|
||||||
learning_rate: f64,
|
learning_rate: f64,
|
||||||
#[arg(long, default_value = "64")]
|
#[arg(long, default_value = "64")]
|
||||||
batch_size: usize,
|
batch_size: usize,
|
||||||
#[arg(long, default_value = "6")]
|
#[arg(long, default_value = "6")]
|
||||||
tree_max_depth: usize,
|
tree_max_depth: usize,
|
||||||
|
/// Min purity for tree leaves (below -> Defer)
|
||||||
#[arg(long, default_value = "0.90")]
|
#[arg(long, default_value = "0.90")]
|
||||||
tree_min_purity: f32,
|
tree_min_purity: f32,
|
||||||
},
|
/// Min samples required in a leaf node (higher = less overfitting)
|
||||||
/// Bayesian hyperparameter optimization for scanner model
|
#[arg(long, default_value = "2")]
|
||||||
AutotuneScanner {
|
min_samples_leaf: usize,
|
||||||
/// Path to audit log JSONL file
|
/// Weight for cookie feature (0.0=ignore, 1.0=full). Controls cookie_ratio influence.
|
||||||
#[arg(short, long)]
|
|
||||||
input: String,
|
|
||||||
/// Output best model file path
|
|
||||||
#[arg(short, long, default_value = "scanner_model_best.bin")]
|
|
||||||
output: String,
|
|
||||||
/// Directory (or file) containing .txt wordlists of scanner paths
|
|
||||||
#[arg(long)]
|
|
||||||
wordlists: Option<String>,
|
|
||||||
/// Include CSIC 2010 dataset as base training data
|
|
||||||
#[arg(long)]
|
|
||||||
csic: bool,
|
|
||||||
/// Number of optimization trials
|
|
||||||
#[arg(long, default_value = "200")]
|
|
||||||
trials: usize,
|
|
||||||
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
|
|
||||||
#[arg(long, default_value = "1.0")]
|
#[arg(long, default_value = "1.0")]
|
||||||
beta: f64,
|
cookie_weight: f32,
|
||||||
/// JSONL file to log each trial's parameters and results
|
},
|
||||||
|
#[cfg(feature = "training")]
|
||||||
|
/// Sweep cookie_weight values and report tree structure + validation accuracy for each
|
||||||
|
SweepCookieWeight {
|
||||||
|
/// Path to prepared dataset (.bin)
|
||||||
|
#[arg(short = 'd', long)]
|
||||||
|
dataset: String,
|
||||||
|
/// Which detector to sweep: "scanner" or "ddos"
|
||||||
|
#[arg(long, default_value = "scanner")]
|
||||||
|
detector: String,
|
||||||
|
/// Comma-separated cookie_weight values to try (default: 0.0,0.1,0.2,...,1.0)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
trial_log: Option<String>,
|
weights: Option<String>,
|
||||||
|
/// Max tree depth
|
||||||
|
#[arg(long, default_value = "6")]
|
||||||
|
tree_max_depth: usize,
|
||||||
|
/// Min purity for tree leaves
|
||||||
|
#[arg(long, default_value = "0.90")]
|
||||||
|
tree_min_purity: f32,
|
||||||
|
/// Min samples required in a leaf node
|
||||||
|
#[arg(long, default_value = "2")]
|
||||||
|
min_samples_leaf: usize,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Subcommand)]
|
|
||||||
enum ReplayMode {
|
|
||||||
/// Replay through ensemble models (scanner + DDoS)
|
|
||||||
Ensemble {
|
|
||||||
/// Path to audit log JSONL file
|
|
||||||
#[arg(short, long)]
|
|
||||||
input: String,
|
|
||||||
/// Sliding window size in seconds
|
|
||||||
#[arg(long, default_value = "60")]
|
|
||||||
window_secs: u64,
|
|
||||||
/// Minimum events per IP before DDoS classification
|
|
||||||
#[arg(long, default_value = "5")]
|
|
||||||
min_events: usize,
|
|
||||||
},
|
|
||||||
/// Replay through legacy KNN DDoS detector
|
|
||||||
Ddos {
|
|
||||||
/// Path to audit log JSONL file
|
|
||||||
#[arg(short, long)]
|
|
||||||
input: String,
|
|
||||||
/// Path to trained model file
|
|
||||||
#[arg(short, long, default_value = "ddos_model.bin")]
|
|
||||||
model: String,
|
|
||||||
/// Optional config file (for rate limit settings)
|
|
||||||
#[arg(short, long)]
|
|
||||||
config: Option<String>,
|
|
||||||
/// KNN k parameter
|
|
||||||
#[arg(long, default_value = "5")]
|
|
||||||
k: usize,
|
|
||||||
/// Attack threshold
|
|
||||||
#[arg(long, default_value = "0.6")]
|
|
||||||
threshold: f64,
|
|
||||||
/// Sliding window size in seconds
|
|
||||||
#[arg(long, default_value = "60")]
|
|
||||||
window_secs: u64,
|
|
||||||
/// Minimum events per IP before classification
|
|
||||||
#[arg(long, default_value = "10")]
|
|
||||||
min_events: usize,
|
|
||||||
/// Also run rate limiter during replay
|
|
||||||
#[arg(long)]
|
|
||||||
rate_limit: bool,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
|
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
|
||||||
Commands::Serve { upgrade } => run_serve(upgrade),
|
Commands::Serve { upgrade } => run_serve(upgrade),
|
||||||
Commands::Replay { mode } => match mode {
|
Commands::Replay { input, window_secs, min_events } => {
|
||||||
ReplayMode::Ensemble { input, window_secs, min_events } => {
|
sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs {
|
||||||
sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs {
|
input, window_secs, min_events,
|
||||||
input, window_secs, min_events,
|
})
|
||||||
})
|
|
||||||
}
|
|
||||||
ReplayMode::Ddos { input, model, config, k, threshold, window_secs, min_events, rate_limit } => {
|
|
||||||
ddos::replay::run(ddos::replay::ReplayArgs {
|
|
||||||
input, model_path: model, config_path: config, k, threshold, window_secs, min_events, rate_limit,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
Commands::TrainDdos {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
attack_ips,
|
|
||||||
normal_ips,
|
|
||||||
heuristics,
|
|
||||||
k,
|
|
||||||
threshold,
|
|
||||||
window_secs,
|
|
||||||
min_events,
|
|
||||||
} => ddos::train::run(ddos::train::TrainArgs {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
attack_ips,
|
|
||||||
normal_ips,
|
|
||||||
heuristics,
|
|
||||||
k,
|
|
||||||
threshold,
|
|
||||||
window_secs,
|
|
||||||
min_events,
|
|
||||||
}),
|
|
||||||
Commands::TrainScanner {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
wordlists,
|
|
||||||
threshold,
|
|
||||||
csic,
|
|
||||||
} => scanner::train::run(scanner::train::TrainScannerArgs {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
wordlists,
|
|
||||||
threshold,
|
|
||||||
csic,
|
|
||||||
}),
|
|
||||||
Commands::DownloadDatasets => {
|
Commands::DownloadDatasets => {
|
||||||
sunbeam_proxy::dataset::download::download_all()
|
sunbeam_proxy::dataset::download::download_all()
|
||||||
},
|
},
|
||||||
Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics } => {
|
Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics, inject_csic, inject_modsec } => {
|
||||||
sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs {
|
sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs {
|
||||||
input, owasp, wordlists, output, seed, heuristics,
|
input, owasp, wordlists, output, seed, heuristics, inject_csic, inject_modsec,
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
#[cfg(feature = "training")]
|
#[cfg(feature = "training")]
|
||||||
Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
|
Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight } => {
|
||||||
sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs {
|
sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs {
|
||||||
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
|
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight,
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
#[cfg(feature = "training")]
|
#[cfg(feature = "training")]
|
||||||
Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
|
Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight } => {
|
||||||
sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs {
|
sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs {
|
||||||
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
|
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight,
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
Commands::AutotuneDdos {
|
#[cfg(feature = "training")]
|
||||||
input,
|
Commands::SweepCookieWeight { dataset, detector, weights, tree_max_depth, tree_min_purity, min_samples_leaf } => {
|
||||||
output,
|
sunbeam_proxy::training::sweep::run_cookie_sweep(
|
||||||
trials,
|
&dataset, &detector, weights.as_deref(), tree_max_depth, tree_min_purity, min_samples_leaf,
|
||||||
beta,
|
)
|
||||||
trial_log,
|
},
|
||||||
} => autotune::ddos::run_autotune(autotune::ddos::AutotuneDdosArgs {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
trials,
|
|
||||||
beta,
|
|
||||||
trial_log,
|
|
||||||
}),
|
|
||||||
Commands::AutotuneScanner {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
wordlists,
|
|
||||||
csic,
|
|
||||||
trials,
|
|
||||||
beta,
|
|
||||||
trial_log,
|
|
||||||
} => autotune::scanner::run_autotune(autotune::scanner::AutotuneScannerArgs {
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
wordlists,
|
|
||||||
csic,
|
|
||||||
trials,
|
|
||||||
beta,
|
|
||||||
trial_log,
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,46 +216,19 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
|||||||
// 1b. Spawn metrics HTTP server (needs a tokio runtime for the TCP listener).
|
// 1b. Spawn metrics HTTP server (needs a tokio runtime for the TCP listener).
|
||||||
let metrics_port = cfg.telemetry.metrics_port;
|
let metrics_port = cfg.telemetry.metrics_port;
|
||||||
|
|
||||||
// 2. Load DDoS detection model if configured.
|
// 2. Init DDoS detector if configured (ensemble: compiled-in weights).
|
||||||
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
|
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
|
||||||
if ddos_cfg.enabled {
|
if ddos_cfg.enabled {
|
||||||
if ddos_cfg.use_ensemble {
|
let detector = Arc::new(sunbeam_proxy::ddos::detector::DDoSDetector::new(ddos_cfg));
|
||||||
// Ensemble path: compiled-in weights, no model file needed.
|
tracing::info!(
|
||||||
// We still need a TrainedModel for the struct, but it won't be used.
|
threshold = ddos_cfg.threshold,
|
||||||
let dummy_model = ddos::model::TrainedModel::empty(ddos_cfg.k, ddos_cfg.threshold);
|
observe_only = ddos_cfg.observe_only,
|
||||||
let detector = Arc::new(ddos::detector::DDoSDetector::new_ensemble(dummy_model, ddos_cfg));
|
"DDoS ensemble detector enabled"
|
||||||
tracing::info!(
|
);
|
||||||
k = ddos_cfg.k,
|
if ddos_cfg.observe_only {
|
||||||
threshold = ddos_cfg.threshold,
|
tracing::warn!("DDoS detector in OBSERVE-ONLY mode — decisions are logged but traffic is never blocked");
|
||||||
"DDoS ensemble detector enabled"
|
|
||||||
);
|
|
||||||
Some(detector)
|
|
||||||
} else if let Some(ref model_path) = ddos_cfg.model_path {
|
|
||||||
match ddos::model::TrainedModel::load(
|
|
||||||
std::path::Path::new(model_path),
|
|
||||||
Some(ddos_cfg.k),
|
|
||||||
Some(ddos_cfg.threshold),
|
|
||||||
) {
|
|
||||||
Ok(model) => {
|
|
||||||
let point_count = model.point_count();
|
|
||||||
let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg));
|
|
||||||
tracing::info!(
|
|
||||||
points = point_count,
|
|
||||||
k = ddos_cfg.k,
|
|
||||||
threshold = ddos_cfg.threshold,
|
|
||||||
"DDoS detector loaded"
|
|
||||||
);
|
|
||||||
Some(detector)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(error = %e, "failed to load DDoS model; detection disabled");
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracing::warn!("DDoS enabled but no model_path and use_ensemble=false; detection disabled");
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
Some(detector)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@@ -435,88 +261,35 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// 2c. Load scanner model if configured.
|
// 2c. Init scanner detector if configured (ensemble: compiled-in weights).
|
||||||
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
|
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
|
||||||
if scanner_cfg.enabled {
|
if scanner_cfg.enabled {
|
||||||
if scanner_cfg.use_ensemble {
|
let detector = scanner::detector::ScannerDetector::new(&cfg.routes);
|
||||||
// Ensemble path: compiled-in weights, no model file needed.
|
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
|
||||||
let detector = scanner::detector::ScannerDetector::new_ensemble(&cfg.routes);
|
|
||||||
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
|
|
||||||
|
|
||||||
// Start bot allowlist if rules are configured.
|
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
|
||||||
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
|
let al = scanner::allowlist::BotAllowlist::spawn(
|
||||||
let al = scanner::allowlist::BotAllowlist::spawn(
|
&scanner_cfg.allowlist,
|
||||||
&scanner_cfg.allowlist,
|
scanner_cfg.bot_cache_ttl_secs,
|
||||||
scanner_cfg.bot_cache_ttl_secs,
|
|
||||||
);
|
|
||||||
tracing::info!(
|
|
||||||
rules = scanner_cfg.allowlist.len(),
|
|
||||||
"bot allowlist enabled"
|
|
||||||
);
|
|
||||||
Some(al)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
tracing::info!(
|
|
||||||
threshold = scanner_cfg.threshold,
|
|
||||||
"scanner ensemble detector enabled"
|
|
||||||
);
|
);
|
||||||
(Some(handle), bot_allowlist)
|
tracing::info!(
|
||||||
} else if let Some(ref model_path) = scanner_cfg.model_path {
|
rules = scanner_cfg.allowlist.len(),
|
||||||
match scanner::model::ScannerModel::load(std::path::Path::new(model_path)) {
|
"bot allowlist enabled"
|
||||||
Ok(mut model) => {
|
);
|
||||||
let fragment_count = model.fragments.len();
|
Some(al)
|
||||||
model.threshold = scanner_cfg.threshold;
|
|
||||||
let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes);
|
|
||||||
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
|
|
||||||
|
|
||||||
// Start bot allowlist if rules are configured.
|
|
||||||
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
|
|
||||||
let al = scanner::allowlist::BotAllowlist::spawn(
|
|
||||||
&scanner_cfg.allowlist,
|
|
||||||
scanner_cfg.bot_cache_ttl_secs,
|
|
||||||
);
|
|
||||||
tracing::info!(
|
|
||||||
rules = scanner_cfg.allowlist.len(),
|
|
||||||
"bot allowlist enabled"
|
|
||||||
);
|
|
||||||
Some(al)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start background file watcher for hot-reload.
|
|
||||||
if scanner_cfg.poll_interval_secs > 0 {
|
|
||||||
let watcher_handle = handle.clone();
|
|
||||||
let watcher_model_path = std::path::PathBuf::from(model_path);
|
|
||||||
let threshold = scanner_cfg.threshold;
|
|
||||||
let routes = cfg.routes.clone();
|
|
||||||
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
|
|
||||||
std::thread::spawn(move || {
|
|
||||||
scanner::watcher::watch_scanner_model(
|
|
||||||
watcher_handle, watcher_model_path, threshold, routes, interval,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::info!(
|
|
||||||
fragments = fragment_count,
|
|
||||||
threshold = scanner_cfg.threshold,
|
|
||||||
poll_interval_secs = scanner_cfg.poll_interval_secs,
|
|
||||||
"scanner detector loaded"
|
|
||||||
);
|
|
||||||
(Some(handle), bot_allowlist)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled");
|
|
||||||
(None, None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
tracing::warn!("scanner enabled but no model_path and use_ensemble=false; scanner detection disabled");
|
None
|
||||||
(None, None)
|
};
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
threshold = scanner_cfg.threshold,
|
||||||
|
observe_only = scanner_cfg.observe_only,
|
||||||
|
"scanner ensemble detector enabled"
|
||||||
|
);
|
||||||
|
if scanner_cfg.observe_only {
|
||||||
|
tracing::warn!("scanner detector in OBSERVE-ONLY mode — decisions are logged but traffic is never blocked");
|
||||||
}
|
}
|
||||||
|
(Some(handle), bot_allowlist)
|
||||||
} else {
|
} else {
|
||||||
(None, None)
|
(None, None)
|
||||||
}
|
}
|
||||||
@@ -617,6 +390,8 @@ fn run_serve(upgrade: bool) -> Result<()> {
|
|||||||
&cfg.rate_limit.as_ref().map(|rl| rl.bypass_cidrs.clone()).unwrap_or_default(),
|
&cfg.rate_limit.as_ref().map(|rl| rl.bypass_cidrs.clone()).unwrap_or_default(),
|
||||||
),
|
),
|
||||||
cluster: cluster_handle,
|
cluster: cluster_handle,
|
||||||
|
ddos_observe_only: cfg.ddos.as_ref().map(|d| d.observe_only).unwrap_or(false),
|
||||||
|
scanner_observe_only: cfg.scanner.as_ref().map(|s| s.observe_only).unwrap_or(false),
|
||||||
};
|
};
|
||||||
let mut svc = http_proxy_service(&server.configuration, proxy);
|
let mut svc = http_proxy_service(&server.configuration, proxy);
|
||||||
|
|
||||||
|
|||||||
35
src/proxy.rs
35
src/proxy.rs
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use crate::acme::AcmeRoutes;
|
use crate::acme::AcmeRoutes;
|
||||||
use crate::cluster::ClusterHandle;
|
use crate::cluster::ClusterHandle;
|
||||||
use crate::config::RouteConfig;
|
use crate::config::RouteConfig;
|
||||||
@@ -32,9 +35,9 @@ pub struct SunbeamProxy {
|
|||||||
pub routes: Vec<RouteConfig>,
|
pub routes: Vec<RouteConfig>,
|
||||||
/// Per-challenge route table populated by the Ingress watcher.
|
/// Per-challenge route table populated by the Ingress watcher.
|
||||||
pub acme_routes: AcmeRoutes,
|
pub acme_routes: AcmeRoutes,
|
||||||
/// Optional KNN-based DDoS detector.
|
/// Optional DDoS detector (ensemble: decision tree + MLP).
|
||||||
pub ddos_detector: Option<Arc<DDoSDetector>>,
|
pub ddos_detector: Option<Arc<DDoSDetector>>,
|
||||||
/// Optional per-request scanner detector (hot-reloadable via ArcSwap).
|
/// Optional per-request scanner detector (ensemble: decision tree + MLP).
|
||||||
pub scanner_detector: Option<Arc<ArcSwap<ScannerDetector>>>,
|
pub scanner_detector: Option<Arc<ArcSwap<ScannerDetector>>>,
|
||||||
/// Optional verified-bot allowlist (bypasses scanner for known crawlers/agents).
|
/// Optional verified-bot allowlist (bypasses scanner for known crawlers/agents).
|
||||||
pub bot_allowlist: Option<Arc<BotAllowlist>>,
|
pub bot_allowlist: Option<Arc<BotAllowlist>>,
|
||||||
@@ -48,6 +51,10 @@ pub struct SunbeamProxy {
|
|||||||
pub pipeline_bypass_cidrs: Vec<crate::rate_limit::cidr::CidrBlock>,
|
pub pipeline_bypass_cidrs: Vec<crate::rate_limit::cidr::CidrBlock>,
|
||||||
/// Optional cluster handle for multi-node bandwidth tracking.
|
/// Optional cluster handle for multi-node bandwidth tracking.
|
||||||
pub cluster: Option<Arc<ClusterHandle>>,
|
pub cluster: Option<Arc<ClusterHandle>>,
|
||||||
|
/// When true, DDoS detector logs decisions but never blocks traffic.
|
||||||
|
pub ddos_observe_only: bool,
|
||||||
|
/// When true, scanner detector logs decisions but never blocks traffic.
|
||||||
|
pub scanner_observe_only: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RequestCtx {
|
pub struct RequestCtx {
|
||||||
@@ -341,7 +348,7 @@ impl ProxyHttp for SunbeamProxy {
|
|||||||
|
|
||||||
metrics::DDOS_DECISIONS.with_label_values(&[decision]).inc();
|
metrics::DDOS_DECISIONS.with_label_values(&[decision]).inc();
|
||||||
|
|
||||||
if matches!(ddos_action, DDoSAction::Block) {
|
if matches!(ddos_action, DDoSAction::Block) && !self.ddos_observe_only {
|
||||||
let mut resp = ResponseHeader::build(429, None)?;
|
let mut resp = ResponseHeader::build(429, None)?;
|
||||||
resp.insert_header("Retry-After", "60")?;
|
resp.insert_header("Retry-After", "60")?;
|
||||||
resp.insert_header("Content-Length", "0")?;
|
resp.insert_header("Content-Length", "0")?;
|
||||||
@@ -426,7 +433,7 @@ impl ProxyHttp for SunbeamProxy {
|
|||||||
.with_label_values(&[decision, reason])
|
.with_label_values(&[decision, reason])
|
||||||
.inc();
|
.inc();
|
||||||
|
|
||||||
if decision == "block" {
|
if decision == "block" && !self.scanner_observe_only {
|
||||||
let mut resp = ResponseHeader::build(403, None)?;
|
let mut resp = ResponseHeader::build(403, None)?;
|
||||||
resp.insert_header("Content-Length", "0")?;
|
resp.insert_header("Content-Length", "0")?;
|
||||||
session.write_response_header(Box::new(resp), true).await?;
|
session.write_response_header(Box::new(resp), true).await?;
|
||||||
@@ -1150,6 +1157,21 @@ impl ProxyHttp for SunbeamProxy {
|
|||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("-");
|
.unwrap_or("-");
|
||||||
let query = session.req_header().uri.query().unwrap_or("");
|
let query = session.req_header().uri.query().unwrap_or("");
|
||||||
|
let response_bytes = session.body_bytes_sent();
|
||||||
|
let http_version = format!("{:?}", session.req_header().version);
|
||||||
|
let header_count = session.req_header().headers.len() as u16;
|
||||||
|
let accept_encoding = session
|
||||||
|
.req_header()
|
||||||
|
.headers
|
||||||
|
.get("accept-encoding")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("-");
|
||||||
|
let connection = session
|
||||||
|
.req_header()
|
||||||
|
.headers
|
||||||
|
.get("connection")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("-");
|
||||||
|
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
target = "audit",
|
target = "audit",
|
||||||
@@ -1162,14 +1184,19 @@ impl ProxyHttp for SunbeamProxy {
|
|||||||
status,
|
status,
|
||||||
duration_ms,
|
duration_ms,
|
||||||
content_length,
|
content_length,
|
||||||
|
response_bytes,
|
||||||
user_agent,
|
user_agent,
|
||||||
referer,
|
referer,
|
||||||
accept_language,
|
accept_language,
|
||||||
accept,
|
accept,
|
||||||
|
accept_encoding,
|
||||||
has_cookies,
|
has_cookies,
|
||||||
cf_country,
|
cf_country,
|
||||||
backend,
|
backend,
|
||||||
error = error_str,
|
error = error_str,
|
||||||
|
http_version,
|
||||||
|
header_count,
|
||||||
|
connection,
|
||||||
"request"
|
"request"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Fetch and convert the CSIC 2010 HTTP dataset into labeled training samples.
|
//! Fetch and convert the CSIC 2010 HTTP dataset into labeled training samples.
|
||||||
//!
|
//!
|
||||||
//! The CSIC 2010 dataset contains raw HTTP/1.1 requests (normal + anomalous)
|
//! The CSIC 2010 dataset contains raw HTTP/1.1 requests (normal + anomalous)
|
||||||
@@ -65,6 +68,7 @@ struct ParsedRequest {
|
|||||||
content_length: u64,
|
content_length: u64,
|
||||||
referer: String,
|
referer: String,
|
||||||
accept_language: String,
|
accept_language: String,
|
||||||
|
accept: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_csic_content(content: &str) -> Vec<ParsedRequest> {
|
fn parse_csic_content(content: &str) -> Vec<ParsedRequest> {
|
||||||
@@ -158,6 +162,7 @@ fn parse_single_request(lines: &[&str]) -> Option<ParsedRequest> {
|
|||||||
content_length,
|
content_length,
|
||||||
referer: get_header("Referer").unwrap_or("-").to_string(),
|
referer: get_header("Referer").unwrap_or("-").to_string(),
|
||||||
accept_language: get_header("Accept-Language").unwrap_or("-").to_string(),
|
accept_language: get_header("Accept-Language").unwrap_or("-").to_string(),
|
||||||
|
accept: get_header("Accept").unwrap_or("-").to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,11 +224,12 @@ fn to_audit_fields(
|
|||||||
// For anomalous samples, simulate real scanner behavior:
|
// For anomalous samples, simulate real scanner behavior:
|
||||||
// strip cookies/referer/accept-language that CSIC attacks have from their session.
|
// strip cookies/referer/accept-language that CSIC attacks have from their session.
|
||||||
let (has_cookies, referer, accept_language, user_agent) = if label != "normal" {
|
let (has_cookies, referer, accept_language, user_agent) = if label != "normal" {
|
||||||
let referer = None;
|
let referer = "-".to_string();
|
||||||
let accept_language = if rng.next_f64() < 0.8 {
|
let accept_language = if rng.next_f64() < 0.8 {
|
||||||
None
|
"-".to_string()
|
||||||
} else {
|
} else {
|
||||||
Some(req.accept_language.clone()).filter(|a| a != "-")
|
let al = req.accept_language.clone();
|
||||||
|
if al == "-" { "-".to_string() } else { al }
|
||||||
};
|
};
|
||||||
let r = rng.next_f64();
|
let r = rng.next_f64();
|
||||||
let user_agent = if r < 0.15 {
|
let user_agent = if r < 0.15 {
|
||||||
@@ -241,12 +247,26 @@ fn to_audit_fields(
|
|||||||
} else {
|
} else {
|
||||||
(
|
(
|
||||||
req.has_cookies,
|
req.has_cookies,
|
||||||
Some(req.referer.clone()).filter(|r| r != "-"),
|
if req.referer == "-" { "-".to_string() } else { req.referer.clone() },
|
||||||
Some(req.accept_language.clone()).filter(|a| a != "-"),
|
if req.accept_language == "-" { "-".to_string() } else { req.accept_language.clone() },
|
||||||
req.user_agent.clone(),
|
req.user_agent.clone(),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// For normal traffic, preserve Accept header from CSIC request.
|
||||||
|
// For attacks, degrade it to simulate scanner behavior.
|
||||||
|
let accept = if label == "normal" {
|
||||||
|
if req.accept == "-" || req.accept.is_empty() {
|
||||||
|
"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8".to_string()
|
||||||
|
} else {
|
||||||
|
req.accept.clone()
|
||||||
|
}
|
||||||
|
} else if rng.next_f64() < 0.6 {
|
||||||
|
"*/*".to_string()
|
||||||
|
} else {
|
||||||
|
req.accept.clone()
|
||||||
|
};
|
||||||
|
|
||||||
AuditFields {
|
AuditFields {
|
||||||
method: req.method.clone(),
|
method: req.method.clone(),
|
||||||
host,
|
host,
|
||||||
@@ -263,9 +283,10 @@ fn to_audit_fields(
|
|||||||
duration_ms: rng.next_usize(50) as u64 + 1,
|
duration_ms: rng.next_usize(50) as u64 + 1,
|
||||||
content_length: req.content_length,
|
content_length: req.content_length,
|
||||||
user_agent,
|
user_agent,
|
||||||
has_cookies: Some(has_cookies),
|
has_cookies,
|
||||||
referer,
|
referer,
|
||||||
accept_language,
|
accept_language,
|
||||||
|
accept,
|
||||||
backend: if label == "normal" {
|
backend: if label == "normal" {
|
||||||
format!("{host_prefix}-svc:8080")
|
format!("{host_prefix}-svc:8080")
|
||||||
} else {
|
} else {
|
||||||
@@ -274,6 +295,7 @@ fn to_audit_fields(
|
|||||||
label: Some(
|
label: Some(
|
||||||
if label == "normal" { "normal" } else { "attack" }.to_string(),
|
if label == "normal" { "normal" } else { "attack" }.to_string(),
|
||||||
),
|
),
|
||||||
|
..AuditFields::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,6 +365,7 @@ mod tests {
|
|||||||
assert_eq!(req.path, "/index.html");
|
assert_eq!(req.path, "/index.html");
|
||||||
assert!(req.has_cookies);
|
assert!(req.has_cookies);
|
||||||
assert_eq!(req.user_agent, "Mozilla/5.0");
|
assert_eq!(req.user_agent, "Mozilla/5.0");
|
||||||
|
assert_eq!(req.accept, "text/html");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -374,11 +397,12 @@ mod tests {
|
|||||||
content_length: 100,
|
content_length: 100,
|
||||||
referer: "https://example.com".to_string(),
|
referer: "https://example.com".to_string(),
|
||||||
accept_language: "en-US".to_string(),
|
accept_language: "en-US".to_string(),
|
||||||
|
accept: "text/html".to_string(),
|
||||||
};
|
};
|
||||||
let mut rng = Rng::new(42);
|
let mut rng = Rng::new(42);
|
||||||
let fields = to_audit_fields(&req, "normal", DEFAULT_HOSTS, &mut rng);
|
let fields = to_audit_fields(&req, "normal", DEFAULT_HOSTS, &mut rng);
|
||||||
assert_eq!(fields.label.as_deref(), Some("normal"));
|
assert_eq!(fields.label.as_deref(), Some("normal"));
|
||||||
assert!(fields.has_cookies.unwrap_or(false));
|
assert!(fields.has_cookies);
|
||||||
assert!(fields.host.ends_with(".sunbeam.pt"));
|
assert!(fields.host.ends_with(".sunbeam.pt"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -393,11 +417,12 @@ mod tests {
|
|||||||
content_length: 0,
|
content_length: 0,
|
||||||
referer: "https://example.com".to_string(),
|
referer: "https://example.com".to_string(),
|
||||||
accept_language: "en-US".to_string(),
|
accept_language: "en-US".to_string(),
|
||||||
|
accept: "text/html".to_string(),
|
||||||
};
|
};
|
||||||
let mut rng = Rng::new(42);
|
let mut rng = Rng::new(42);
|
||||||
let fields = to_audit_fields(&req, "anomalous", DEFAULT_HOSTS, &mut rng);
|
let fields = to_audit_fields(&req, "anomalous", DEFAULT_HOSTS, &mut rng);
|
||||||
assert_eq!(fields.label.as_deref(), Some("attack"));
|
assert_eq!(fields.label.as_deref(), Some("attack"));
|
||||||
assert!(!fields.has_cookies.unwrap_or(true));
|
assert!(!fields.has_cookies);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use crate::config::RouteConfig;
|
use crate::config::RouteConfig;
|
||||||
use crate::scanner::features::{
|
use crate::scanner::features::{self, fx_hash_bytes, SUSPICIOUS_EXTENSIONS_LIST};
|
||||||
self, fx_hash_bytes, ScannerNormParams, SUSPICIOUS_EXTENSIONS_LIST, NUM_SCANNER_FEATURES,
|
use crate::scanner::model::{ScannerAction, ScannerVerdict};
|
||||||
NUM_SCANNER_WEIGHTS,
|
|
||||||
};
|
|
||||||
use crate::scanner::model::{ScannerAction, ScannerModel, ScannerVerdict};
|
|
||||||
use rustc_hash::FxHashSet;
|
use rustc_hash::FxHashSet;
|
||||||
|
|
||||||
/// Immutable, zero-state per-request scanner detector.
|
/// Immutable, zero-state per-request scanner detector.
|
||||||
@@ -12,44 +12,10 @@ pub struct ScannerDetector {
|
|||||||
fragment_hashes: FxHashSet<u64>,
|
fragment_hashes: FxHashSet<u64>,
|
||||||
extension_hashes: FxHashSet<u64>,
|
extension_hashes: FxHashSet<u64>,
|
||||||
configured_hosts: FxHashSet<u64>,
|
configured_hosts: FxHashSet<u64>,
|
||||||
weights: [f64; NUM_SCANNER_WEIGHTS],
|
|
||||||
threshold: f64,
|
|
||||||
norm_params: ScannerNormParams,
|
|
||||||
use_ensemble: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ScannerDetector {
|
impl ScannerDetector {
|
||||||
pub fn new(model: &ScannerModel, routes: &[RouteConfig]) -> Self {
|
pub fn new(routes: &[RouteConfig]) -> Self {
|
||||||
let fragment_hashes: FxHashSet<u64> = model
|
|
||||||
.fragments
|
|
||||||
.iter()
|
|
||||||
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let extension_hashes: FxHashSet<u64> = SUSPICIOUS_EXTENSIONS_LIST
|
|
||||||
.iter()
|
|
||||||
.map(|e| fx_hash_bytes(e.as_bytes()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let configured_hosts: FxHashSet<u64> = routes
|
|
||||||
.iter()
|
|
||||||
.map(|r| fx_hash_bytes(r.host_prefix.as_bytes()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
fragment_hashes,
|
|
||||||
extension_hashes,
|
|
||||||
configured_hosts,
|
|
||||||
weights: model.weights,
|
|
||||||
threshold: model.threshold,
|
|
||||||
norm_params: model.norm_params.clone(),
|
|
||||||
use_ensemble: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a detector that uses the ensemble (decision tree + MLP) path
|
|
||||||
/// instead of the linear model. No model file needed — weights are compiled in.
|
|
||||||
pub fn new_ensemble(routes: &[RouteConfig]) -> Self {
|
|
||||||
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
||||||
.iter()
|
.iter()
|
||||||
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
||||||
@@ -69,13 +35,6 @@ impl ScannerDetector {
|
|||||||
fragment_hashes,
|
fragment_hashes,
|
||||||
extension_hashes,
|
extension_hashes,
|
||||||
configured_hosts,
|
configured_hosts,
|
||||||
weights: [0.0; NUM_SCANNER_WEIGHTS],
|
|
||||||
threshold: 0.5,
|
|
||||||
norm_params: ScannerNormParams {
|
|
||||||
mins: [0.0; NUM_SCANNER_FEATURES],
|
|
||||||
maxs: [1.0; NUM_SCANNER_FEATURES],
|
|
||||||
},
|
|
||||||
use_ensemble: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,8 +57,6 @@ impl ScannerDetector {
|
|||||||
content_length: u64,
|
content_length: u64,
|
||||||
) -> ScannerVerdict {
|
) -> ScannerVerdict {
|
||||||
// Hard allowlist: obviously legitimate traffic bypasses the model.
|
// Hard allowlist: obviously legitimate traffic bypasses the model.
|
||||||
// This prevents model drift from ever blocking real users and ensures
|
|
||||||
// the training pipeline always has clean positive labels.
|
|
||||||
let host_known = {
|
let host_known = {
|
||||||
let hash = features::fx_hash_bytes(host_prefix.as_bytes());
|
let hash = features::fx_hash_bytes(host_prefix.as_bytes());
|
||||||
self.configured_hosts.contains(&hash)
|
self.configured_hosts.contains(&hash)
|
||||||
@@ -121,95 +78,32 @@ impl ScannerDetector {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.use_ensemble {
|
// Ensemble path: extract f32 features → decision tree + MLP.
|
||||||
// Ensemble path: extract f32 features → decision tree + MLP.
|
let raw_f32 = features::extract_features_f32(
|
||||||
let raw_f32 = features::extract_features_f32(
|
method, path, host_prefix,
|
||||||
method, path, host_prefix,
|
has_cookies, has_referer, has_accept_language,
|
||||||
has_cookies, has_referer, has_accept_language,
|
accept, user_agent, content_length,
|
||||||
accept, user_agent, content_length,
|
&self.fragment_hashes, &self.extension_hashes, &self.configured_hosts,
|
||||||
&self.fragment_hashes, &self.extension_hashes, &self.configured_hosts,
|
|
||||||
);
|
|
||||||
let ev = crate::ensemble::scanner::scanner_ensemble_predict(&raw_f32);
|
|
||||||
crate::metrics::SCANNER_ENSEMBLE_PATH
|
|
||||||
.with_label_values(&[match ev.path {
|
|
||||||
crate::ensemble::scanner::EnsemblePath::TreeBlock => "tree_block",
|
|
||||||
crate::ensemble::scanner::EnsemblePath::TreeAllow => "tree_allow",
|
|
||||||
crate::ensemble::scanner::EnsemblePath::Mlp => "mlp",
|
|
||||||
}])
|
|
||||||
.inc();
|
|
||||||
return ev.into();
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. Extract 12 features
|
|
||||||
let raw = features::extract_features(
|
|
||||||
method,
|
|
||||||
path,
|
|
||||||
host_prefix,
|
|
||||||
has_cookies,
|
|
||||||
has_referer,
|
|
||||||
has_accept_language,
|
|
||||||
accept,
|
|
||||||
user_agent,
|
|
||||||
content_length,
|
|
||||||
&self.fragment_hashes,
|
|
||||||
&self.extension_hashes,
|
|
||||||
&self.configured_hosts,
|
|
||||||
);
|
);
|
||||||
|
let ev = crate::ensemble::scanner::scanner_ensemble_predict(&raw_f32);
|
||||||
// 2. Normalize
|
crate::metrics::SCANNER_ENSEMBLE_PATH
|
||||||
let f = self.norm_params.normalize(&raw);
|
.with_label_values(&[match ev.path {
|
||||||
|
crate::ensemble::scanner::EnsemblePath::TreeBlock => "tree_block",
|
||||||
// 3. Compute score = bias + dot(weights, features) + interaction terms
|
crate::ensemble::scanner::EnsemblePath::TreeAllow => "tree_allow",
|
||||||
let mut score = self.weights[NUM_SCANNER_FEATURES + 2]; // bias (index 14)
|
crate::ensemble::scanner::EnsemblePath::Mlp => "mlp",
|
||||||
for (i, &fi) in f.iter().enumerate().take(NUM_SCANNER_FEATURES) {
|
}])
|
||||||
score += self.weights[i] * fi;
|
.inc();
|
||||||
}
|
ev.into()
|
||||||
// Interaction: suspicious_path AND no_cookies
|
|
||||||
score += self.weights[12] * f[0] * (1.0 - f[3]);
|
|
||||||
// Interaction: unknown_host AND no_accept_language
|
|
||||||
score += self.weights[13] * (1.0 - f[9]) * (1.0 - f[5]);
|
|
||||||
|
|
||||||
// 4. Threshold
|
|
||||||
let action = if score > self.threshold {
|
|
||||||
ScannerAction::Block
|
|
||||||
} else {
|
|
||||||
ScannerAction::Allow
|
|
||||||
};
|
|
||||||
|
|
||||||
ScannerVerdict {
|
|
||||||
action,
|
|
||||||
score,
|
|
||||||
reason: "model",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::scanner::features::NUM_SCANNER_FEATURES;
|
use crate::config::RouteConfig;
|
||||||
|
|
||||||
fn make_detector(weights: [f64; NUM_SCANNER_WEIGHTS], threshold: f64) -> ScannerDetector {
|
fn test_routes() -> Vec<RouteConfig> {
|
||||||
let model = ScannerModel {
|
vec![RouteConfig {
|
||||||
weights,
|
|
||||||
threshold,
|
|
||||||
norm_params: ScannerNormParams {
|
|
||||||
mins: [0.0; NUM_SCANNER_FEATURES],
|
|
||||||
maxs: [1.0; NUM_SCANNER_FEATURES],
|
|
||||||
},
|
|
||||||
fragments: vec![
|
|
||||||
".env".into(),
|
|
||||||
"wp-admin".into(),
|
|
||||||
"wp-login".into(),
|
|
||||||
"phpinfo".into(),
|
|
||||||
"phpmyadmin".into(),
|
|
||||||
".git".into(),
|
|
||||||
"cgi-bin".into(),
|
|
||||||
".htaccess".into(),
|
|
||||||
".htpasswd".into(),
|
|
||||||
],
|
|
||||||
};
|
|
||||||
let routes = vec![RouteConfig {
|
|
||||||
host_prefix: "app".into(),
|
host_prefix: "app".into(),
|
||||||
backend: "http://127.0.0.1:8080".into(),
|
backend: "http://127.0.0.1:8080".into(),
|
||||||
websocket: false,
|
websocket: false,
|
||||||
@@ -221,35 +115,12 @@ mod tests {
|
|||||||
body_rewrites: vec![],
|
body_rewrites: vec![],
|
||||||
response_headers: vec![],
|
response_headers: vec![],
|
||||||
cache: None,
|
cache: None,
|
||||||
}];
|
}]
|
||||||
ScannerDetector::new(&model, &routes)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Weights tuned to block scanner-like requests:
|
|
||||||
/// High weight on suspicious_path (w[0]), no_cookies interaction (w[12]),
|
|
||||||
/// has_suspicious_extension (w[2]), traversal (w[11]).
|
|
||||||
/// Negative weight on has_cookies (w[3]), has_referer (w[4]),
|
|
||||||
/// accept_quality (w[6]), ua_category (w[7]), host_is_configured (w[9]).
|
|
||||||
fn attack_tuned_weights() -> [f64; NUM_SCANNER_WEIGHTS] {
|
|
||||||
let mut w = [0.0; NUM_SCANNER_WEIGHTS];
|
|
||||||
w[0] = 2.0; // suspicious_path_score
|
|
||||||
w[2] = 2.0; // has_suspicious_extension
|
|
||||||
w[3] = -2.0; // has_cookies (negative = good)
|
|
||||||
w[4] = -1.0; // has_referer (negative = good)
|
|
||||||
w[5] = -1.0; // has_accept_language (negative = good)
|
|
||||||
w[6] = -0.5; // accept_quality (negative = good)
|
|
||||||
w[7] = -1.0; // ua_category (negative = browser is good)
|
|
||||||
w[9] = -1.5; // host_is_configured (negative = known host is good)
|
|
||||||
w[11] = 2.0; // path_has_traversal
|
|
||||||
w[12] = 1.5; // interaction: suspicious_path AND no_cookies
|
|
||||||
w[13] = 1.0; // interaction: unknown_host AND no_accept_lang
|
|
||||||
w[14] = 0.5; // bias
|
|
||||||
w
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_normal_browser_request_allowed() {
|
fn test_normal_browser_request_allowed() {
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
let detector = ScannerDetector::new(&test_routes());
|
||||||
let verdict = detector.check(
|
let verdict = detector.check(
|
||||||
"GET",
|
"GET",
|
||||||
"/blog/hello-world",
|
"/blog/hello-world",
|
||||||
@@ -267,7 +138,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_api_client_with_auth_allowed() {
|
fn test_api_client_with_auth_allowed() {
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
let detector = ScannerDetector::new(&test_routes());
|
||||||
let verdict = detector.check(
|
let verdict = detector.check(
|
||||||
"POST",
|
"POST",
|
||||||
"/api/v1/data",
|
"/api/v1/data",
|
||||||
@@ -285,81 +156,24 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_env_probe_blocked() {
|
fn test_env_probe_blocked() {
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
let detector = ScannerDetector::new(&test_routes());
|
||||||
let verdict = detector.check(
|
let verdict = detector.check(
|
||||||
"GET",
|
"GET",
|
||||||
"/.env",
|
"/.env",
|
||||||
"unknown",
|
"unknown",
|
||||||
false, // no cookies
|
false,
|
||||||
false, // no referer
|
false,
|
||||||
false, // no accept-language
|
false,
|
||||||
"*/*",
|
"*/*",
|
||||||
"curl/7.0",
|
"curl/7.0",
|
||||||
0,
|
0,
|
||||||
);
|
);
|
||||||
assert_eq!(verdict.action, ScannerAction::Block);
|
assert_eq!(verdict.action, ScannerAction::Block);
|
||||||
assert_eq!(verdict.reason, "model");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_wordpress_scan_blocked() {
|
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
|
||||||
let verdict = detector.check(
|
|
||||||
"GET",
|
|
||||||
"/wp-admin/install.php",
|
|
||||||
"unknown",
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"*/*",
|
|
||||||
"",
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(verdict.action, ScannerAction::Block);
|
|
||||||
assert_eq!(verdict.reason, "model");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_path_traversal_blocked() {
|
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
|
||||||
let verdict = detector.check(
|
|
||||||
"GET",
|
|
||||||
"/etc/../../../passwd",
|
|
||||||
"unknown",
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"*/*",
|
|
||||||
"python-requests/2.28",
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(verdict.action, ScannerAction::Block);
|
|
||||||
assert_eq!(verdict.reason, "model");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_legitimate_php_path_allowed() {
|
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
|
||||||
// "/blog/php-is-dead" — "php-is-dead" is not a known fragment
|
|
||||||
// has_cookies=true + known host "app" → hits allowlist
|
|
||||||
let verdict = detector.check(
|
|
||||||
"GET",
|
|
||||||
"/blog/php-is-dead",
|
|
||||||
"app",
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
"text/html",
|
|
||||||
"Mozilla/5.0 Chrome/120",
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(verdict.action, ScannerAction::Allow);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_allowlist_browser_on_known_host() {
|
fn test_allowlist_browser_on_known_host() {
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
let detector = ScannerDetector::new(&test_routes());
|
||||||
// No cookies but browser UA + accept-language + known host → allowlist
|
|
||||||
let verdict = detector.check(
|
let verdict = detector.check(
|
||||||
"GET",
|
"GET",
|
||||||
"/",
|
"/",
|
||||||
@@ -374,22 +188,4 @@ mod tests {
|
|||||||
assert_eq!(verdict.action, ScannerAction::Allow);
|
assert_eq!(verdict.action, ScannerAction::Allow);
|
||||||
assert_eq!(verdict.reason, "allowlist:host+browser");
|
assert_eq!(verdict.reason, "allowlist:host+browser");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_model_path_for_non_allowlisted() {
|
|
||||||
let detector = make_detector(attack_tuned_weights(), 0.5);
|
|
||||||
// Unknown host, no cookies, curl UA → goes through model
|
|
||||||
let verdict = detector.check(
|
|
||||||
"GET",
|
|
||||||
"/robots.txt",
|
|
||||||
"unknown",
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
false,
|
|
||||||
"*/*",
|
|
||||||
"curl/7.0",
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
assert_eq!(verdict.reason, "model");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
pub mod allowlist;
|
pub mod allowlist;
|
||||||
pub mod csic;
|
pub mod csic;
|
||||||
pub mod detector;
|
pub mod detector;
|
||||||
pub mod features;
|
pub mod features;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod train;
|
pub mod train;
|
||||||
pub mod watcher;
|
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
use crate::scanner::features::{ScannerNormParams, NUM_SCANNER_WEIGHTS};
|
// Copyright Sunbeam Studios 2026
|
||||||
use anyhow::{Context, Result};
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub enum ScannerAction {
|
pub enum ScannerAction {
|
||||||
@@ -16,74 +14,3 @@ pub struct ScannerVerdict {
|
|||||||
/// Why this decision was made: "model", "allowlist", etc.
|
/// Why this decision was made: "model", "allowlist", etc.
|
||||||
pub reason: &'static str,
|
pub reason: &'static str,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ScannerModel {
|
|
||||||
pub weights: [f64; NUM_SCANNER_WEIGHTS],
|
|
||||||
pub threshold: f64,
|
|
||||||
pub norm_params: ScannerNormParams,
|
|
||||||
/// Suspicious path fragments used during training — kept for reproducibility.
|
|
||||||
pub fragments: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ScannerModel {
|
|
||||||
pub fn save(&self, path: &Path) -> Result<()> {
|
|
||||||
let data = bincode::serialize(self).context("serializing scanner model")?;
|
|
||||||
std::fs::write(path, data)
|
|
||||||
.with_context(|| format!("writing scanner model to {}", path.display()))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn load(path: &Path) -> Result<Self> {
|
|
||||||
let data = std::fs::read(path)
|
|
||||||
.with_context(|| format!("reading scanner model from {}", path.display()))?;
|
|
||||||
bincode::deserialize(&data).context("deserializing scanner model")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::scanner::features::NUM_SCANNER_FEATURES;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_serialization_roundtrip() {
|
|
||||||
let model = ScannerModel {
|
|
||||||
weights: [0.1; NUM_SCANNER_WEIGHTS],
|
|
||||||
threshold: 0.5,
|
|
||||||
norm_params: ScannerNormParams {
|
|
||||||
mins: [0.0; NUM_SCANNER_FEATURES],
|
|
||||||
maxs: [1.0; NUM_SCANNER_FEATURES],
|
|
||||||
},
|
|
||||||
fragments: vec![".env".into(), "wp-admin".into()],
|
|
||||||
};
|
|
||||||
let data = bincode::serialize(&model).unwrap();
|
|
||||||
let loaded: ScannerModel = bincode::deserialize(&data).unwrap();
|
|
||||||
assert_eq!(loaded.weights, model.weights);
|
|
||||||
assert_eq!(loaded.threshold, model.threshold);
|
|
||||||
assert_eq!(loaded.fragments, model.fragments);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_save_load_file() {
|
|
||||||
let dir = std::env::temp_dir().join("scanner_model_test");
|
|
||||||
std::fs::create_dir_all(&dir).unwrap();
|
|
||||||
let path = dir.join("test_model.bin");
|
|
||||||
|
|
||||||
let model = ScannerModel {
|
|
||||||
weights: [0.5; NUM_SCANNER_WEIGHTS],
|
|
||||||
threshold: 0.42,
|
|
||||||
norm_params: ScannerNormParams {
|
|
||||||
mins: [0.0; NUM_SCANNER_FEATURES],
|
|
||||||
maxs: [1.0; NUM_SCANNER_FEATURES],
|
|
||||||
},
|
|
||||||
fragments: vec!["phpinfo".into()],
|
|
||||||
};
|
|
||||||
model.save(&path).unwrap();
|
|
||||||
let loaded = ScannerModel::load(&path).unwrap();
|
|
||||||
assert_eq!(loaded.threshold, 0.42);
|
|
||||||
assert_eq!(loaded.fragments, vec!["phpinfo"]);
|
|
||||||
|
|
||||||
let _ = std::fs::remove_dir_all(&dir);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,9 +1,30 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
use crate::ddos::audit_log::{AuditLog, AuditFields};
|
use crate::ddos::audit_log::{AuditLog, AuditFields};
|
||||||
use crate::scanner::features::{
|
use crate::scanner::features::{
|
||||||
self, fx_hash_bytes, ScannerFeatureVector, ScannerNormParams, NUM_SCANNER_FEATURES,
|
self, fx_hash_bytes, ScannerFeatureVector, ScannerNormParams, NUM_SCANNER_FEATURES,
|
||||||
NUM_SCANNER_WEIGHTS,
|
NUM_SCANNER_WEIGHTS,
|
||||||
};
|
};
|
||||||
use crate::scanner::model::ScannerModel;
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Legacy linear scanner model — kept for the `train-scanner` CLI command.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ScannerModel {
|
||||||
|
pub weights: [f64; NUM_SCANNER_WEIGHTS],
|
||||||
|
pub threshold: f64,
|
||||||
|
pub norm_params: ScannerNormParams,
|
||||||
|
pub fragments: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScannerModel {
|
||||||
|
pub fn save(&self, path: &Path) -> Result<()> {
|
||||||
|
let data = bincode::serialize(self).context("serializing scanner model")?;
|
||||||
|
std::fs::write(path, data)
|
||||||
|
.with_context(|| format!("writing scanner model to {}", path.display()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use rustc_hash::FxHashSet;
|
use rustc_hash::FxHashSet;
|
||||||
use std::io::BufRead;
|
use std::io::BufRead;
|
||||||
@@ -88,17 +109,9 @@ pub fn train_and_evaluate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (fields, host_prefix) in &parsed_entries {
|
for (fields, host_prefix) in &parsed_entries {
|
||||||
let has_cookies = fields.has_cookies.unwrap_or(false);
|
let has_cookies = fields.has_cookies;
|
||||||
let has_referer = fields
|
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
|
||||||
.referer
|
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
|
||||||
.as_ref()
|
|
||||||
.map(|r| r != "-" && !r.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let has_accept_language = fields
|
|
||||||
.accept_language
|
|
||||||
.as_ref()
|
|
||||||
.map(|a| a != "-" && !a.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let feats = features::extract_features(
|
let feats = features::extract_features(
|
||||||
&fields.method,
|
&fields.method,
|
||||||
@@ -149,17 +162,9 @@ pub fn train_and_evaluate(
|
|||||||
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
|
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
|
||||||
}
|
}
|
||||||
for (fields, host_prefix) in &csic_entries {
|
for (fields, host_prefix) in &csic_entries {
|
||||||
let has_cookies = fields.has_cookies.unwrap_or(false);
|
let has_cookies = fields.has_cookies;
|
||||||
let has_referer = fields
|
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
|
||||||
.referer
|
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
|
||||||
.as_ref()
|
|
||||||
.map(|r| r != "-" && !r.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let has_accept_language = fields
|
|
||||||
.accept_language
|
|
||||||
.as_ref()
|
|
||||||
.map(|a| a != "-" && !a.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let feats = features::extract_features(
|
let feats = features::extract_features(
|
||||||
&fields.method,
|
&fields.method,
|
||||||
@@ -288,17 +293,9 @@ pub fn run(args: TrainScannerArgs) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (fields, host_prefix) in &parsed_entries {
|
for (fields, host_prefix) in &parsed_entries {
|
||||||
let has_cookies = fields.has_cookies.unwrap_or(false);
|
let has_cookies = fields.has_cookies;
|
||||||
let has_referer = fields
|
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
|
||||||
.referer
|
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
|
||||||
.as_ref()
|
|
||||||
.map(|r| r != "-" && !r.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let has_accept_language = fields
|
|
||||||
.accept_language
|
|
||||||
.as_ref()
|
|
||||||
.map(|a| a != "-" && !a.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let feats = features::extract_features(
|
let feats = features::extract_features(
|
||||||
&fields.method,
|
&fields.method,
|
||||||
@@ -352,17 +349,9 @@ pub fn run(args: TrainScannerArgs) -> Result<()> {
|
|||||||
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
|
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
|
||||||
}
|
}
|
||||||
for (fields, host_prefix) in &csic_entries {
|
for (fields, host_prefix) in &csic_entries {
|
||||||
let has_cookies = fields.has_cookies.unwrap_or(false);
|
let has_cookies = fields.has_cookies;
|
||||||
let has_referer = fields
|
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
|
||||||
.referer
|
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
|
||||||
.as_ref()
|
|
||||||
.map(|r| r != "-" && !r.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
let has_accept_language = fields
|
|
||||||
.accept_language
|
|
||||||
.as_ref()
|
|
||||||
.map(|a| a != "-" && !a.is_empty())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let feats = features::extract_features(
|
let feats = features::extract_features(
|
||||||
&fields.method,
|
&fields.method,
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
use crate::config::RouteConfig;
|
|
||||||
use crate::scanner::detector::ScannerDetector;
|
|
||||||
use crate::scanner::model::ScannerModel;
|
|
||||||
use arc_swap::ArcSwap;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
/// Poll the scanner model file for mtime changes and hot-swap the detector.
|
|
||||||
/// Runs forever on a dedicated OS thread — never returns.
|
|
||||||
pub fn watch_scanner_model(
|
|
||||||
handle: Arc<ArcSwap<ScannerDetector>>,
|
|
||||||
model_path: PathBuf,
|
|
||||||
threshold: f64,
|
|
||||||
routes: Vec<RouteConfig>,
|
|
||||||
poll_interval: Duration,
|
|
||||||
) {
|
|
||||||
let mut last_mtime = std::fs::metadata(&model_path)
|
|
||||||
.and_then(|m| m.modified())
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
loop {
|
|
||||||
std::thread::sleep(poll_interval);
|
|
||||||
|
|
||||||
let current_mtime = match std::fs::metadata(&model_path).and_then(|m| m.modified()) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(_) => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
if Some(current_mtime) == last_mtime {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
match ScannerModel::load(&model_path) {
|
|
||||||
Ok(mut model) => {
|
|
||||||
model.threshold = threshold;
|
|
||||||
let fragment_count = model.fragments.len();
|
|
||||||
let detector = ScannerDetector::new(&model, &routes);
|
|
||||||
handle.store(Arc::new(detector));
|
|
||||||
last_mtime = Some(current_mtime);
|
|
||||||
tracing::info!(
|
|
||||||
fragments = fragment_count,
|
|
||||||
"scanner model hot-reloaded"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::warn!(error = %e, "failed to reload scanner model; keeping current");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
112
src/training/batch.rs
Normal file
112
src/training/batch.rs
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
//! Shared training infrastructure: Dataset adapter, Batcher, and batch types
|
||||||
|
//! for use with burn's SupervisedTraining.
|
||||||
|
|
||||||
|
use crate::dataset::sample::TrainingSample;
|
||||||
|
|
||||||
|
use burn::data::dataloader::batcher::Batcher;
|
||||||
|
use burn::data::dataloader::Dataset;
|
||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
/// A single normalized training item ready for batching.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct TrainingItem {
|
||||||
|
pub features: Vec<f32>,
|
||||||
|
pub label: i32,
|
||||||
|
pub weight: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A batch of training items as tensors.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct TrainingBatch<B: Backend> {
|
||||||
|
pub features: Tensor<B, 2>,
|
||||||
|
pub labels: Tensor<B, 1, Int>,
|
||||||
|
pub weights: Tensor<B, 2>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wraps a `Vec<TrainingSample>` as a burn `Dataset`, applying min-max
|
||||||
|
/// normalization to features at construction time.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SampleDataset {
|
||||||
|
items: Vec<TrainingItem>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SampleDataset {
|
||||||
|
pub fn new(samples: &[TrainingSample], mins: &[f32], maxs: &[f32]) -> Self {
|
||||||
|
let items = samples
|
||||||
|
.iter()
|
||||||
|
.map(|s| {
|
||||||
|
let features: Vec<f32> = s
|
||||||
|
.features
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, &v)| {
|
||||||
|
let range = maxs[i] - mins[i];
|
||||||
|
if range > f32::EPSILON {
|
||||||
|
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
TrainingItem {
|
||||||
|
features,
|
||||||
|
label: if s.label >= 0.5 { 1 } else { 0 },
|
||||||
|
weight: s.weight,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Self { items }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Dataset<TrainingItem> for SampleDataset {
|
||||||
|
fn get(&self, index: usize) -> Option<TrainingItem> {
|
||||||
|
self.items.get(index).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn len(&self) -> usize {
|
||||||
|
self.items.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a `Vec<TrainingItem>` into a `TrainingBatch` of tensors.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct SampleBatcher;
|
||||||
|
|
||||||
|
impl SampleBatcher {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Batcher<B, TrainingItem, TrainingBatch<B>> for SampleBatcher {
|
||||||
|
fn batch(&self, items: Vec<TrainingItem>, device: &B::Device) -> TrainingBatch<B> {
|
||||||
|
let batch_size = items.len();
|
||||||
|
let num_features = items[0].features.len();
|
||||||
|
|
||||||
|
let flat_features: Vec<f32> = items
|
||||||
|
.iter()
|
||||||
|
.flat_map(|item| item.features.iter().copied())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let labels: Vec<i32> = items.iter().map(|item| item.label).collect();
|
||||||
|
let weights: Vec<f32> = items.iter().map(|item| item.weight).collect();
|
||||||
|
|
||||||
|
let features = Tensor::<B, 1>::from_floats(flat_features.as_slice(), device)
|
||||||
|
.reshape([batch_size, num_features]);
|
||||||
|
|
||||||
|
let labels = Tensor::<B, 1, Int>::from_ints(labels.as_slice(), device);
|
||||||
|
|
||||||
|
let weights = Tensor::<B, 1>::from_floats(weights.as_slice(), device)
|
||||||
|
.reshape([batch_size, 1]);
|
||||||
|
|
||||||
|
TrainingBatch {
|
||||||
|
features,
|
||||||
|
labels,
|
||||||
|
weights,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! Weight export: converts trained models into standalone Rust `const` arrays
|
//! Weight export: converts trained models into standalone Rust `const` arrays
|
||||||
//! and optionally Lean 4 definitions.
|
//! and optionally Lean 4 definitions.
|
||||||
//!
|
//!
|
||||||
@@ -54,7 +57,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
|
|||||||
writeln!(s).unwrap();
|
writeln!(s).unwrap();
|
||||||
|
|
||||||
// Threshold.
|
// Threshold.
|
||||||
writeln!(s, "pub const THRESHOLD: f32 = {:.8};", model.threshold).unwrap();
|
writeln!(s, "pub const THRESHOLD: f32 = {:.8};", sanitize(model.threshold)).unwrap();
|
||||||
writeln!(s).unwrap();
|
writeln!(s).unwrap();
|
||||||
|
|
||||||
// Normalization params.
|
// Normalization params.
|
||||||
@@ -74,7 +77,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
|
|||||||
if i > 0 {
|
if i > 0 {
|
||||||
write!(s, ", ").unwrap();
|
write!(s, ", ").unwrap();
|
||||||
}
|
}
|
||||||
write!(s, "{:.8}", v).unwrap();
|
write!(s, "{:.8}", sanitize(*v)).unwrap();
|
||||||
}
|
}
|
||||||
writeln!(s, "],").unwrap();
|
writeln!(s, "],").unwrap();
|
||||||
}
|
}
|
||||||
@@ -88,7 +91,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
|
|||||||
write_f32_array(&mut s, "W2", &model.w2);
|
write_f32_array(&mut s, "W2", &model.w2);
|
||||||
|
|
||||||
// B2.
|
// B2.
|
||||||
writeln!(s, "pub const B2: f32 = {:.8};", model.b2).unwrap();
|
writeln!(s, "pub const B2: f32 = {:.8};", sanitize(model.b2)).unwrap();
|
||||||
writeln!(s).unwrap();
|
writeln!(s).unwrap();
|
||||||
|
|
||||||
// Tree nodes.
|
// Tree nodes.
|
||||||
@@ -207,6 +210,11 @@ pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> {
|
|||||||
// Helpers
|
// Helpers
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Sanitize a float for Rust source: replace NaN/Inf with 0.0.
|
||||||
|
fn sanitize(v: f32) -> f32 {
|
||||||
|
if v.is_finite() { v } else { 0.0 }
|
||||||
|
}
|
||||||
|
|
||||||
fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
|
fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
|
||||||
writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap();
|
writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap();
|
||||||
write!(s, " ").unwrap();
|
write!(s, " ").unwrap();
|
||||||
@@ -218,7 +226,7 @@ fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
|
|||||||
if i > 0 && i % 8 == 0 {
|
if i > 0 && i % 8 == 0 {
|
||||||
write!(s, "\n ").unwrap();
|
write!(s, "\n ").unwrap();
|
||||||
}
|
}
|
||||||
write!(s, "{:.8}", v).unwrap();
|
write!(s, "{:.8}", sanitize(*v)).unwrap();
|
||||||
}
|
}
|
||||||
writeln!(s, "\n];").unwrap();
|
writeln!(s, "\n];").unwrap();
|
||||||
writeln!(s).unwrap();
|
writeln!(s).unwrap();
|
||||||
|
|||||||
@@ -1,11 +1,18 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
//! burn-rs MLP model definition for ensemble training.
|
//! burn-rs MLP model definition for ensemble training.
|
||||||
//!
|
//!
|
||||||
//! A two-layer network (linear -> ReLU -> linear -> sigmoid) used as the
|
//! A two-layer network (linear -> ReLU -> linear -> sigmoid) used as the
|
||||||
//! "uncertain region" classifier in the tree+MLP ensemble.
|
//! "uncertain region" classifier in the tree+MLP ensemble.
|
||||||
|
|
||||||
|
use crate::training::batch::TrainingBatch;
|
||||||
|
|
||||||
use burn::module::Module;
|
use burn::module::Module;
|
||||||
use burn::nn::{Linear, LinearConfig};
|
use burn::nn::{Linear, LinearConfig};
|
||||||
use burn::prelude::*;
|
use burn::prelude::*;
|
||||||
|
use burn::tensor::backend::AutodiffBackend;
|
||||||
|
use burn::train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep};
|
||||||
|
|
||||||
/// Two-layer MLP: input -> hidden (ReLU) -> output (sigmoid).
|
/// Two-layer MLP: input -> hidden (ReLU) -> output (sigmoid).
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
@@ -34,24 +41,79 @@ impl MlpConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> MlpModel<B> {
|
impl<B: Backend> MlpModel<B> {
|
||||||
/// Forward pass: ReLU hidden activation, sigmoid output.
|
/// Forward pass returning raw logits (pre-sigmoid).
|
||||||
///
|
///
|
||||||
/// Input shape: `[batch, input_dim]`
|
/// Input shape: `[batch, input_dim]`
|
||||||
/// Output shape: `[batch, 1]`
|
/// Output shape: `[batch, 1]`
|
||||||
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
pub fn forward_logits(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
let h = self.linear1.forward(x);
|
let h = self.linear1.forward(x);
|
||||||
let h = burn::tensor::activation::relu(h);
|
let h = burn::tensor::activation::relu(h);
|
||||||
let out = self.linear2.forward(h);
|
self.linear2.forward(h)
|
||||||
burn::tensor::activation::sigmoid(out)
|
}
|
||||||
|
|
||||||
|
/// Forward pass with sigmoid activation for inference/export.
|
||||||
|
///
|
||||||
|
/// Input shape: `[batch, input_dim]`
|
||||||
|
/// Output shape: `[batch, 1]` (values in [0, 1])
|
||||||
|
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||||
|
burn::tensor::activation::sigmoid(self.forward_logits(x))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward pass returning a `ClassificationOutput` for burn's training loop.
|
||||||
|
///
|
||||||
|
/// Uses raw logits for BCE (which applies sigmoid internally) and converts
|
||||||
|
/// to two-column format `[1-p, p]` for AccuracyMetric (which uses argmax).
|
||||||
|
pub fn forward_classification(
|
||||||
|
&self,
|
||||||
|
batch: TrainingBatch<B>,
|
||||||
|
) -> ClassificationOutput<B> {
|
||||||
|
let logits = self.forward_logits(batch.features); // [batch, 1]
|
||||||
|
let logits_1d = logits.clone().squeeze::<1>(); // [batch]
|
||||||
|
|
||||||
|
// Numerically stable BCE from logits:
|
||||||
|
// loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
|
||||||
|
// This avoids log(0) and exp(large) overflow.
|
||||||
|
let targets_float = batch.labels.clone().float(); // [batch]
|
||||||
|
let zeros = Tensor::zeros_like(&logits_1d);
|
||||||
|
let relu_logits = logits_1d.clone().max_pair(zeros); // max(logits, 0)
|
||||||
|
let neg_abs = logits_1d.clone().abs().neg(); // -|logits|
|
||||||
|
let log_term = neg_abs.exp().log1p(); // log(1 + exp(-|logits|))
|
||||||
|
let per_sample = relu_logits - logits_1d.clone() * targets_float + log_term;
|
||||||
|
let loss = per_sample.mean(); // scalar [1]
|
||||||
|
|
||||||
|
// AccuracyMetric expects [batch, num_classes] and uses argmax.
|
||||||
|
let neg_logits = logits.clone().neg();
|
||||||
|
let output_2col = Tensor::cat(vec![neg_logits, logits], 1); // [batch, 2]
|
||||||
|
|
||||||
|
ClassificationOutput::new(loss, output_2col, batch.labels)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: AutodiffBackend> TrainStep for MlpModel<B> {
|
||||||
|
type Input = TrainingBatch<B>;
|
||||||
|
type Output = ClassificationOutput<B>;
|
||||||
|
|
||||||
|
fn step(&self, batch: Self::Input) -> TrainOutput<Self::Output> {
|
||||||
|
let item = self.forward_classification(batch);
|
||||||
|
TrainOutput::new(self, item.loss.backward(), item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> InferenceStep for MlpModel<B> {
|
||||||
|
type Input = TrainingBatch<B>;
|
||||||
|
type Output = ClassificationOutput<B>;
|
||||||
|
|
||||||
|
fn step(&self, batch: Self::Input) -> Self::Output {
|
||||||
|
self.forward_classification(batch)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use burn::backend::NdArray;
|
use burn::backend::Wgpu;
|
||||||
|
|
||||||
type TestBackend = NdArray<f32>;
|
type TestBackend = Wgpu<f32, i32>;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_forward_pass_shape() {
|
fn test_forward_pass_shape() {
|
||||||
@@ -80,7 +142,6 @@ mod tests {
|
|||||||
};
|
};
|
||||||
let model = config.init::<TestBackend>(&device);
|
let model = config.init::<TestBackend>(&device);
|
||||||
|
|
||||||
// Random-ish input values.
|
|
||||||
let input = Tensor::<TestBackend, 2>::from_data(
|
let input = Tensor::<TestBackend, 2>::from_data(
|
||||||
[[1.0, -2.0, 0.5, 3.0], [0.0, 0.0, 0.0, 0.0]],
|
[[1.0, -2.0, 0.5, 3.0], [0.0, 0.0, 0.0, 0.0]],
|
||||||
&device,
|
&device,
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
pub mod mlp;
|
pub mod mlp;
|
||||||
|
pub mod batch;
|
||||||
pub mod export;
|
pub mod export;
|
||||||
pub mod train_scanner;
|
pub mod train_scanner;
|
||||||
pub mod train_ddos;
|
pub mod train_ddos;
|
||||||
|
pub mod sweep;
|
||||||
|
|||||||
103
src/training/sweep.rs
Normal file
103
src/training/sweep.rs
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
//! Cookie weight sweep: trains full tree+MLP ensembles (GPU via wgpu) across a
|
||||||
|
//! range of cookie_weight values and reports accuracy metrics for each.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use crate::dataset::sample::load_dataset;
|
||||||
|
use crate::training::train_scanner::TrainScannerMlpArgs;
|
||||||
|
use crate::training::train_ddos::TrainDdosMlpArgs;
|
||||||
|
|
||||||
|
/// Run a sweep across cookie_weight values for either scanner or ddos.
|
||||||
|
///
|
||||||
|
/// Each trial does a full GPU training run (tree + MLP) with the specified
|
||||||
|
/// cookie_weight, writing artifacts to a temp directory.
|
||||||
|
pub fn run_cookie_sweep(
|
||||||
|
dataset_path: &str,
|
||||||
|
detector: &str,
|
||||||
|
weights_csv: Option<&str>,
|
||||||
|
tree_max_depth: usize,
|
||||||
|
tree_min_purity: f32,
|
||||||
|
min_samples_leaf: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
// Validate dataset exists and has samples.
|
||||||
|
let manifest = load_dataset(Path::new(dataset_path))
|
||||||
|
.context("loading dataset manifest")?;
|
||||||
|
|
||||||
|
let (cookie_idx, sample_count) = match detector {
|
||||||
|
"scanner" => (3usize, manifest.scanner_samples.len()),
|
||||||
|
"ddos" => (10usize, manifest.ddos_samples.len()),
|
||||||
|
other => anyhow::bail!("unknown detector '{}', expected 'scanner' or 'ddos'", other),
|
||||||
|
};
|
||||||
|
|
||||||
|
anyhow::ensure!(sample_count > 0, "no {} samples in dataset", detector);
|
||||||
|
drop(manifest); // Free memory before training loop.
|
||||||
|
|
||||||
|
let weights: Vec<f32> = if let Some(csv) = weights_csv {
|
||||||
|
csv.split(',')
|
||||||
|
.map(|s| s.trim().parse::<f32>())
|
||||||
|
.collect::<std::result::Result<Vec<_>, _>>()
|
||||||
|
.context("parsing --weights as comma-separated floats")?
|
||||||
|
} else {
|
||||||
|
(0..=10).map(|i| i as f32 / 10.0).collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"[sweep] {} detector, {} samples, cookie feature index: {}",
|
||||||
|
detector, sample_count, cookie_idx,
|
||||||
|
);
|
||||||
|
println!("[sweep] training {} trials with full tree+MLP (wgpu)\n", weights.len());
|
||||||
|
|
||||||
|
let sweep_dir = tempfile::tempdir().context("creating temp dir for sweep")?;
|
||||||
|
|
||||||
|
for (trial, &cw) in weights.iter().enumerate() {
|
||||||
|
let trial_dir = sweep_dir.path().join(format!("trial_{}", trial));
|
||||||
|
std::fs::create_dir_all(&trial_dir)?;
|
||||||
|
let trial_dir_str = trial_dir.to_string_lossy().to_string();
|
||||||
|
|
||||||
|
println!("━━━ Trial {}/{}: cookie_weight={:.2} ━━━", trial + 1, weights.len(), cw);
|
||||||
|
|
||||||
|
match detector {
|
||||||
|
"scanner" => {
|
||||||
|
crate::training::train_scanner::run(TrainScannerMlpArgs {
|
||||||
|
dataset_path: dataset_path.to_string(),
|
||||||
|
output_dir: trial_dir_str,
|
||||||
|
hidden_dim: 32,
|
||||||
|
epochs: 100,
|
||||||
|
learning_rate: 0.0001,
|
||||||
|
batch_size: 64,
|
||||||
|
tree_max_depth,
|
||||||
|
tree_min_purity,
|
||||||
|
min_samples_leaf,
|
||||||
|
cookie_weight: cw,
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
"ddos" => {
|
||||||
|
crate::training::train_ddos::run(TrainDdosMlpArgs {
|
||||||
|
dataset_path: dataset_path.to_string(),
|
||||||
|
output_dir: trial_dir_str,
|
||||||
|
hidden_dim: 32,
|
||||||
|
epochs: 100,
|
||||||
|
learning_rate: 0.0001,
|
||||||
|
batch_size: 64,
|
||||||
|
tree_max_depth,
|
||||||
|
tree_min_purity,
|
||||||
|
min_samples_leaf,
|
||||||
|
cookie_weight: cw,
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("[sweep] All {} trials complete.", weights.len());
|
||||||
|
println!("[sweep] Tip: compare tree structures and validation accuracy above.");
|
||||||
|
println!("[sweep] Look for a cookie_weight where FP rate drops without FN rate spiking.");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,19 +1,27 @@
|
|||||||
//! DDoS MLP+tree training loop.
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
//! DDoS MLP+tree training loop using burn's SupervisedTraining.
|
||||||
//!
|
//!
|
||||||
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
|
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP
|
||||||
//! then exports the combined ensemble weights as a Rust source file that can
|
//! with cosine annealing + early stopping, then exports the combined ensemble
|
||||||
//! be dropped into `src/ensemble/gen/ddos_weights.rs`.
|
//! weights as a Rust source file for `src/ensemble/gen/ddos_weights.rs`.
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use burn::backend::ndarray::NdArray;
|
|
||||||
use burn::backend::Autodiff;
|
use burn::backend::Autodiff;
|
||||||
use burn::module::AutodiffModule;
|
use burn::backend::Wgpu;
|
||||||
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
use burn::data::dataloader::DataLoaderBuilder;
|
||||||
|
use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
|
||||||
|
use burn::optim::AdamConfig;
|
||||||
use burn::prelude::*;
|
use burn::prelude::*;
|
||||||
|
use burn::record::CompactRecorder;
|
||||||
|
use burn::train::metric::{AccuracyMetric, LossMetric};
|
||||||
|
use burn::train::{Learner, SupervisedTraining};
|
||||||
|
|
||||||
use crate::dataset::sample::{load_dataset, TrainingSample};
|
use crate::dataset::sample::{load_dataset, TrainingSample};
|
||||||
|
use crate::training::batch::{SampleBatcher, SampleDataset};
|
||||||
use crate::training::export::{export_to_file, ExportedModel};
|
use crate::training::export::{export_to_file, ExportedModel};
|
||||||
use crate::training::mlp::MlpConfig;
|
use crate::training::mlp::MlpConfig;
|
||||||
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||||
@@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
|||||||
/// Number of DDoS features (matches `crate::ddos::features::NUM_FEATURES`).
|
/// Number of DDoS features (matches `crate::ddos::features::NUM_FEATURES`).
|
||||||
const NUM_FEATURES: usize = 14;
|
const NUM_FEATURES: usize = 14;
|
||||||
|
|
||||||
type TrainBackend = Autodiff<NdArray<f32>>;
|
type TrainBackend = Autodiff<Wgpu<f32, i32>>;
|
||||||
|
|
||||||
/// Arguments for the DDoS MLP training command.
|
/// Arguments for the DDoS MLP training command.
|
||||||
pub struct TrainDdosMlpArgs {
|
pub struct TrainDdosMlpArgs {
|
||||||
@@ -37,10 +45,14 @@ pub struct TrainDdosMlpArgs {
|
|||||||
pub learning_rate: f64,
|
pub learning_rate: f64,
|
||||||
/// Mini-batch size (default 64).
|
/// Mini-batch size (default 64).
|
||||||
pub batch_size: usize,
|
pub batch_size: usize,
|
||||||
/// CART max depth (default 6).
|
/// CART max depth (default 8).
|
||||||
pub tree_max_depth: usize,
|
pub tree_max_depth: usize,
|
||||||
/// CART leaf purity threshold (default 0.90).
|
/// CART leaf purity threshold (default 0.98).
|
||||||
pub tree_min_purity: f32,
|
pub tree_min_purity: f32,
|
||||||
|
/// Minimum samples in a leaf node (default 2).
|
||||||
|
pub min_samples_leaf: usize,
|
||||||
|
/// Weight for cookie feature (feature 10: cookie_ratio). 0.0 = ignore, 1.0 = full weight.
|
||||||
|
pub cookie_weight: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for TrainDdosMlpArgs {
|
impl Default for TrainDdosMlpArgs {
|
||||||
@@ -50,14 +62,19 @@ impl Default for TrainDdosMlpArgs {
|
|||||||
output_dir: ".".into(),
|
output_dir: ".".into(),
|
||||||
hidden_dim: 32,
|
hidden_dim: 32,
|
||||||
epochs: 100,
|
epochs: 100,
|
||||||
learning_rate: 0.001,
|
learning_rate: 0.0001,
|
||||||
batch_size: 64,
|
batch_size: 64,
|
||||||
tree_max_depth: 6,
|
tree_max_depth: 8,
|
||||||
tree_min_purity: 0.90,
|
tree_min_purity: 0.98,
|
||||||
|
min_samples_leaf: 2,
|
||||||
|
cookie_weight: 1.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Index of the cookie_ratio feature in the DDoS feature vector.
|
||||||
|
const COOKIE_FEATURE_IDX: usize = 10;
|
||||||
|
|
||||||
/// Entry point: train DDoS ensemble and export weights.
|
/// Entry point: train DDoS ensemble and export weights.
|
||||||
pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
||||||
// 1. Load dataset.
|
// 1. Load dataset.
|
||||||
@@ -86,6 +103,23 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
|||||||
// 2. Compute normalization params from training data.
|
// 2. Compute normalization params from training data.
|
||||||
let (norm_mins, norm_maxs) = compute_norm_params(samples);
|
let (norm_mins, norm_maxs) = compute_norm_params(samples);
|
||||||
|
|
||||||
|
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||||
|
println!(
|
||||||
|
"[ddos] cookie_weight={:.2} (feature {} influence reduced)",
|
||||||
|
args.cookie_weight, COOKIE_FEATURE_IDX,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLP norm adjustment: scale cookie feature's normalization range.
|
||||||
|
let mut mlp_norm_maxs = norm_maxs.clone();
|
||||||
|
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||||
|
let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX];
|
||||||
|
if range > f32::EPSILON && args.cookie_weight > f32::EPSILON {
|
||||||
|
mlp_norm_maxs[COOKIE_FEATURE_IDX] =
|
||||||
|
range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 3. Stratified 80/20 split.
|
// 3. Stratified 80/20 split.
|
||||||
let (train_set, val_set) = stratified_split(samples, 0.8);
|
let (train_set, val_set) = stratified_split(samples, 0.8);
|
||||||
println!(
|
println!(
|
||||||
@@ -94,15 +128,16 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
|||||||
val_set.len()
|
val_set.len()
|
||||||
);
|
);
|
||||||
|
|
||||||
// 4. Train CART tree.
|
// 4. Train CART tree (with cookie feature masking for reduced weight).
|
||||||
|
let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight);
|
||||||
let tree_config = TreeConfig {
|
let tree_config = TreeConfig {
|
||||||
max_depth: args.tree_max_depth,
|
max_depth: args.tree_max_depth,
|
||||||
min_samples_leaf: 5,
|
min_samples_leaf: args.min_samples_leaf,
|
||||||
min_purity: args.tree_min_purity,
|
min_purity: args.tree_min_purity,
|
||||||
num_features: NUM_FEATURES,
|
num_features: NUM_FEATURES,
|
||||||
};
|
};
|
||||||
let tree_nodes = train_tree(&train_set, &tree_config);
|
let tree_nodes = train_tree(&tree_train_set, &tree_config);
|
||||||
println!("[ddos] CART tree: {} nodes", tree_nodes.len());
|
println!("[ddos] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth);
|
||||||
|
|
||||||
// Evaluate tree on validation set.
|
// Evaluate tree on validation set.
|
||||||
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
|
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
|
||||||
@@ -112,23 +147,27 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
|||||||
tree_deferred * 100.0,
|
tree_deferred * 100.0,
|
||||||
);
|
);
|
||||||
|
|
||||||
// 5. Train MLP on the full training set.
|
// 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling).
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
let mlp_config = MlpConfig {
|
let mlp_config = MlpConfig {
|
||||||
input_dim: NUM_FEATURES,
|
input_dim: NUM_FEATURES,
|
||||||
hidden_dim: args.hidden_dim,
|
hidden_dim: args.hidden_dim,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let artifact_dir = Path::new(&args.output_dir).join("ddos_artifacts");
|
||||||
|
std::fs::create_dir_all(&artifact_dir).ok();
|
||||||
|
|
||||||
let model = train_mlp(
|
let model = train_mlp(
|
||||||
&train_set,
|
&train_set,
|
||||||
&val_set,
|
&val_set,
|
||||||
&mlp_config,
|
&mlp_config,
|
||||||
&norm_mins,
|
&norm_mins,
|
||||||
&norm_maxs,
|
&mlp_norm_maxs,
|
||||||
args.epochs,
|
args.epochs,
|
||||||
args.learning_rate,
|
args.learning_rate,
|
||||||
args.batch_size,
|
args.batch_size,
|
||||||
&device,
|
&device,
|
||||||
|
&artifact_dir,
|
||||||
);
|
);
|
||||||
|
|
||||||
// 6. Extract weights from trained model.
|
// 6. Extract weights from trained model.
|
||||||
@@ -136,9 +175,9 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
|||||||
&model,
|
&model,
|
||||||
"ddos",
|
"ddos",
|
||||||
&tree_nodes,
|
&tree_nodes,
|
||||||
0.5, // threshold
|
0.5,
|
||||||
&norm_mins,
|
&norm_mins,
|
||||||
&norm_maxs,
|
&mlp_norm_maxs,
|
||||||
&device,
|
&device,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -153,6 +192,37 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Cookie feature masking for CART trees
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fn mask_cookie_feature(
|
||||||
|
samples: &[TrainingSample],
|
||||||
|
cookie_idx: usize,
|
||||||
|
cookie_weight: f32,
|
||||||
|
) -> Vec<TrainingSample> {
|
||||||
|
if cookie_weight >= 1.0 - f32::EPSILON {
|
||||||
|
return samples.to_vec();
|
||||||
|
}
|
||||||
|
samples
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, s)| {
|
||||||
|
let mut s2 = s.clone();
|
||||||
|
if cookie_weight < f32::EPSILON {
|
||||||
|
s2.features[cookie_idx] = 0.5;
|
||||||
|
} else {
|
||||||
|
let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42);
|
||||||
|
let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32;
|
||||||
|
if r > cookie_weight {
|
||||||
|
s2.features[cookie_idx] = 0.5;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s2
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Normalization
|
// Normalization
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -170,21 +240,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
|
|||||||
(mins, maxs)
|
(mins, maxs)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
|
||||||
features
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, &v)| {
|
|
||||||
let range = maxs[i] - mins[i];
|
|
||||||
if range > f32::EPSILON {
|
|
||||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
|
||||||
} else {
|
|
||||||
0.0
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Stratified split
|
// Stratified split
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -272,8 +327,23 @@ fn eval_tree(
|
|||||||
(accuracy, defer_rate)
|
(accuracy, defer_rate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
||||||
|
features
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, &v)| {
|
||||||
|
let range = maxs[i] - mins[i];
|
||||||
|
if range > f32::EPSILON {
|
||||||
|
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// MLP training
|
// MLP training via SupervisedTraining
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
fn train_mlp(
|
fn train_mlp(
|
||||||
@@ -286,117 +356,47 @@ fn train_mlp(
|
|||||||
learning_rate: f64,
|
learning_rate: f64,
|
||||||
batch_size: usize,
|
batch_size: usize,
|
||||||
device: &<TrainBackend as Backend>::Device,
|
device: &<TrainBackend as Backend>::Device,
|
||||||
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
|
artifact_dir: &Path,
|
||||||
let mut model = config.init::<TrainBackend>(device);
|
) -> crate::training::mlp::MlpModel<Wgpu<f32, i32>> {
|
||||||
let mut optim = AdamConfig::new().init();
|
let model = config.init::<TrainBackend>(device);
|
||||||
|
|
||||||
// Pre-normalize all training data.
|
let train_dataset = SampleDataset::new(train_set, mins, maxs);
|
||||||
let train_features: Vec<Vec<f32>> = train_set
|
let val_dataset = SampleDataset::new(val_set, mins, maxs);
|
||||||
.iter()
|
|
||||||
.map(|s| normalize_features(&s.features, mins, maxs))
|
|
||||||
.collect();
|
|
||||||
let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
|
|
||||||
let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();
|
|
||||||
|
|
||||||
let n = train_features.len();
|
let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new())
|
||||||
|
.batch_size(batch_size)
|
||||||
|
.shuffle(42)
|
||||||
|
.num_workers(1)
|
||||||
|
.build(train_dataset);
|
||||||
|
|
||||||
for epoch in 0..epochs {
|
let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new())
|
||||||
let mut epoch_loss = 0.0f32;
|
.batch_size(batch_size)
|
||||||
let mut batches = 0usize;
|
.num_workers(1)
|
||||||
|
.build(val_dataset);
|
||||||
|
|
||||||
let mut offset = 0;
|
// Cosine annealing: initial_lr must be in (0.0, 1.0].
|
||||||
while offset < n {
|
let lr = learning_rate.min(1.0);
|
||||||
let end = (offset + batch_size).min(n);
|
let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs)
|
||||||
let batch_n = end - offset;
|
.init()
|
||||||
|
.expect("valid cosine annealing config");
|
||||||
|
|
||||||
// Build input tensor [batch, features].
|
let learner = Learner::new(
|
||||||
let flat: Vec<f32> = train_features[offset..end]
|
model,
|
||||||
.iter()
|
AdamConfig::new().init(),
|
||||||
.flat_map(|f| f.iter().copied())
|
lr_scheduler,
|
||||||
.collect();
|
);
|
||||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
|
||||||
.reshape([batch_n, NUM_FEATURES]);
|
|
||||||
|
|
||||||
// Labels [batch, 1].
|
let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid)
|
||||||
let y = Tensor::<TrainBackend, 1>::from_floats(
|
.metric_train_numeric(AccuracyMetric::new())
|
||||||
&train_labels[offset..end],
|
.metric_valid_numeric(AccuracyMetric::new())
|
||||||
device,
|
.metric_train_numeric(LossMetric::new())
|
||||||
)
|
.metric_valid_numeric(LossMetric::new())
|
||||||
.reshape([batch_n, 1]);
|
.with_file_checkpointer(CompactRecorder::new())
|
||||||
|
.num_epochs(epochs)
|
||||||
|
.summary()
|
||||||
|
.launch(learner);
|
||||||
|
|
||||||
// Sample weights [batch, 1].
|
result.model
|
||||||
let w = Tensor::<TrainBackend, 1>::from_floats(
|
|
||||||
&train_weights[offset..end],
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
.reshape([batch_n, 1]);
|
|
||||||
|
|
||||||
// Forward pass.
|
|
||||||
let pred = model.forward(x);
|
|
||||||
|
|
||||||
// Binary cross-entropy with sample weights.
|
|
||||||
let eps = 1e-7;
|
|
||||||
let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
|
|
||||||
let bce = (y.clone() * pred_clamped.clone().log()
|
|
||||||
+ (y.clone().neg().add_scalar(1.0))
|
|
||||||
* pred_clamped.neg().add_scalar(1.0).log())
|
|
||||||
.neg();
|
|
||||||
let weighted_bce = bce * w;
|
|
||||||
let loss = weighted_bce.mean();
|
|
||||||
|
|
||||||
epoch_loss += loss.clone().into_scalar().elem::<f32>();
|
|
||||||
batches += 1;
|
|
||||||
|
|
||||||
// Backward + optimizer step.
|
|
||||||
let grads = loss.backward();
|
|
||||||
let grads = GradientsParams::from_grads(grads, &model);
|
|
||||||
model = optim.step(learning_rate, model, grads);
|
|
||||||
|
|
||||||
offset = end;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (epoch + 1) % 10 == 0 || epoch == 0 {
|
|
||||||
let avg_loss = epoch_loss / batches as f32;
|
|
||||||
let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
|
|
||||||
println!(
|
|
||||||
"[ddos] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
|
|
||||||
epoch + 1,
|
|
||||||
epochs,
|
|
||||||
avg_loss,
|
|
||||||
val_acc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
model.valid()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn eval_mlp_accuracy(
|
|
||||||
model: &crate::training::mlp::MlpModel<TrainBackend>,
|
|
||||||
val_set: &[TrainingSample],
|
|
||||||
mins: &[f32],
|
|
||||||
maxs: &[f32],
|
|
||||||
device: &<TrainBackend as Backend>::Device,
|
|
||||||
) -> f64 {
|
|
||||||
let flat: Vec<f32> = val_set
|
|
||||||
.iter()
|
|
||||||
.flat_map(|s| normalize_features(&s.features, mins, maxs))
|
|
||||||
.collect();
|
|
||||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
|
||||||
.reshape([val_set.len(), NUM_FEATURES]);
|
|
||||||
|
|
||||||
let pred = model.forward(x);
|
|
||||||
let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");
|
|
||||||
|
|
||||||
let mut correct = 0usize;
|
|
||||||
for (i, s) in val_set.iter().enumerate() {
|
|
||||||
let p = pred_data[i];
|
|
||||||
let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
|
|
||||||
if (predicted_label - s.label).abs() < 0.1 {
|
|
||||||
correct += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
correct as f64 / val_set.len() as f64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -404,13 +404,13 @@ fn eval_mlp_accuracy(
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
fn extract_weights(
|
fn extract_weights(
|
||||||
model: &crate::training::mlp::MlpModel<NdArray<f32>>,
|
model: &crate::training::mlp::MlpModel<Wgpu<f32, i32>>,
|
||||||
name: &str,
|
name: &str,
|
||||||
tree_nodes: &[(u8, f32, u16, u16)],
|
tree_nodes: &[(u8, f32, u16, u16)],
|
||||||
threshold: f32,
|
threshold: f32,
|
||||||
norm_mins: &[f32],
|
norm_mins: &[f32],
|
||||||
norm_maxs: &[f32],
|
norm_maxs: &[f32],
|
||||||
_device: &<NdArray<f32> as Backend>::Device,
|
_device: &<Wgpu<f32, i32> as Backend>::Device,
|
||||||
) -> ExportedModel {
|
) -> ExportedModel {
|
||||||
let w1_tensor = model.linear1.weight.val();
|
let w1_tensor = model.linear1.weight.val();
|
||||||
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
|
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
|
||||||
|
|||||||
@@ -1,19 +1,27 @@
|
|||||||
//! Scanner MLP+tree training loop.
|
// Copyright Sunbeam Studios 2026
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
//! Scanner MLP+tree training loop using burn's SupervisedTraining.
|
||||||
//!
|
//!
|
||||||
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
|
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP
|
||||||
//! then exports the combined ensemble weights as a Rust source file that can
|
//! with cosine annealing + early stopping, then exports the combined ensemble
|
||||||
//! be dropped into `src/ensemble/gen/scanner_weights.rs`.
|
//! weights as a Rust source file for `src/ensemble/gen/scanner_weights.rs`.
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use burn::backend::ndarray::NdArray;
|
|
||||||
use burn::backend::Autodiff;
|
use burn::backend::Autodiff;
|
||||||
use burn::module::AutodiffModule;
|
use burn::backend::Wgpu;
|
||||||
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
use burn::data::dataloader::DataLoaderBuilder;
|
||||||
|
use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
|
||||||
|
use burn::optim::AdamConfig;
|
||||||
use burn::prelude::*;
|
use burn::prelude::*;
|
||||||
|
use burn::record::CompactRecorder;
|
||||||
|
use burn::train::metric::{AccuracyMetric, LossMetric};
|
||||||
|
use burn::train::{Learner, SupervisedTraining};
|
||||||
|
|
||||||
use crate::dataset::sample::{load_dataset, TrainingSample};
|
use crate::dataset::sample::{load_dataset, TrainingSample};
|
||||||
|
use crate::training::batch::{SampleBatcher, SampleDataset};
|
||||||
use crate::training::export::{export_to_file, ExportedModel};
|
use crate::training::export::{export_to_file, ExportedModel};
|
||||||
use crate::training::mlp::MlpConfig;
|
use crate::training::mlp::MlpConfig;
|
||||||
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
||||||
@@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
|
|||||||
/// Number of scanner features (matches `crate::scanner::features::NUM_SCANNER_FEATURES`).
|
/// Number of scanner features (matches `crate::scanner::features::NUM_SCANNER_FEATURES`).
|
||||||
const NUM_FEATURES: usize = 12;
|
const NUM_FEATURES: usize = 12;
|
||||||
|
|
||||||
type TrainBackend = Autodiff<NdArray<f32>>;
|
type TrainBackend = Autodiff<Wgpu<f32, i32>>;
|
||||||
|
|
||||||
/// Arguments for the scanner MLP training command.
|
/// Arguments for the scanner MLP training command.
|
||||||
pub struct TrainScannerMlpArgs {
|
pub struct TrainScannerMlpArgs {
|
||||||
@@ -37,10 +45,14 @@ pub struct TrainScannerMlpArgs {
|
|||||||
pub learning_rate: f64,
|
pub learning_rate: f64,
|
||||||
/// Mini-batch size (default 64).
|
/// Mini-batch size (default 64).
|
||||||
pub batch_size: usize,
|
pub batch_size: usize,
|
||||||
/// CART max depth (default 6).
|
/// CART max depth (default 8).
|
||||||
pub tree_max_depth: usize,
|
pub tree_max_depth: usize,
|
||||||
/// CART leaf purity threshold (default 0.90).
|
/// CART leaf purity threshold (default 0.98).
|
||||||
pub tree_min_purity: f32,
|
pub tree_min_purity: f32,
|
||||||
|
/// Minimum samples in a leaf node (default 2).
|
||||||
|
pub min_samples_leaf: usize,
|
||||||
|
/// Weight for cookie feature (feature 3: has_cookies). 0.0 = ignore, 1.0 = full weight.
|
||||||
|
pub cookie_weight: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for TrainScannerMlpArgs {
|
impl Default for TrainScannerMlpArgs {
|
||||||
@@ -50,14 +62,19 @@ impl Default for TrainScannerMlpArgs {
|
|||||||
output_dir: ".".into(),
|
output_dir: ".".into(),
|
||||||
hidden_dim: 32,
|
hidden_dim: 32,
|
||||||
epochs: 100,
|
epochs: 100,
|
||||||
learning_rate: 0.001,
|
learning_rate: 0.0001,
|
||||||
batch_size: 64,
|
batch_size: 64,
|
||||||
tree_max_depth: 6,
|
tree_max_depth: 8,
|
||||||
tree_min_purity: 0.90,
|
tree_min_purity: 0.98,
|
||||||
|
min_samples_leaf: 2,
|
||||||
|
cookie_weight: 1.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Index of the has_cookies feature in the scanner feature vector.
|
||||||
|
const COOKIE_FEATURE_IDX: usize = 3;
|
||||||
|
|
||||||
/// Entry point: train scanner ensemble and export weights.
|
/// Entry point: train scanner ensemble and export weights.
|
||||||
pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
||||||
// 1. Load dataset.
|
// 1. Load dataset.
|
||||||
@@ -86,6 +103,27 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
|||||||
// 2. Compute normalization params from training data.
|
// 2. Compute normalization params from training data.
|
||||||
let (norm_mins, norm_maxs) = compute_norm_params(samples);
|
let (norm_mins, norm_maxs) = compute_norm_params(samples);
|
||||||
|
|
||||||
|
// Apply cookie_weight: for the MLP, we scale the normalization range so
|
||||||
|
// the feature contributes less gradient signal. For the CART tree, scaling
|
||||||
|
// doesn't help (the tree just adjusts its threshold), so we mask the feature
|
||||||
|
// to a constant on a fraction of training samples to degrade its Gini gain.
|
||||||
|
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||||
|
println!(
|
||||||
|
"[scanner] cookie_weight={:.2} (feature {} influence reduced)",
|
||||||
|
args.cookie_weight, COOKIE_FEATURE_IDX,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLP norm adjustment: scale the cookie feature's normalization range.
|
||||||
|
let mut mlp_norm_maxs = norm_maxs.clone();
|
||||||
|
if args.cookie_weight < 1.0 - f32::EPSILON {
|
||||||
|
let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX];
|
||||||
|
if range > f32::EPSILON && args.cookie_weight > f32::EPSILON {
|
||||||
|
mlp_norm_maxs[COOKIE_FEATURE_IDX] =
|
||||||
|
range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 3. Stratified 80/20 split.
|
// 3. Stratified 80/20 split.
|
||||||
let (train_set, val_set) = stratified_split(samples, 0.8);
|
let (train_set, val_set) = stratified_split(samples, 0.8);
|
||||||
println!(
|
println!(
|
||||||
@@ -94,17 +132,18 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
|||||||
val_set.len()
|
val_set.len()
|
||||||
);
|
);
|
||||||
|
|
||||||
// 4. Train CART tree.
|
// 4. Train CART tree (with cookie feature masking for reduced weight).
|
||||||
|
let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight);
|
||||||
let tree_config = TreeConfig {
|
let tree_config = TreeConfig {
|
||||||
max_depth: args.tree_max_depth,
|
max_depth: args.tree_max_depth,
|
||||||
min_samples_leaf: 5,
|
min_samples_leaf: args.min_samples_leaf,
|
||||||
min_purity: args.tree_min_purity,
|
min_purity: args.tree_min_purity,
|
||||||
num_features: NUM_FEATURES,
|
num_features: NUM_FEATURES,
|
||||||
};
|
};
|
||||||
let tree_nodes = train_tree(&train_set, &tree_config);
|
let tree_nodes = train_tree(&tree_train_set, &tree_config);
|
||||||
println!("[scanner] CART tree: {} nodes", tree_nodes.len());
|
println!("[scanner] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth);
|
||||||
|
|
||||||
// Evaluate tree on validation set.
|
// Evaluate tree on validation set (use original norms — tree learned on masked features).
|
||||||
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
|
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
|
||||||
println!(
|
println!(
|
||||||
"[scanner] tree validation: {:.2}% correct (of decided), {:.1}% deferred",
|
"[scanner] tree validation: {:.2}% correct (of decided), {:.1}% deferred",
|
||||||
@@ -112,35 +151,38 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
|||||||
tree_deferred * 100.0,
|
tree_deferred * 100.0,
|
||||||
);
|
);
|
||||||
|
|
||||||
// 5. Train MLP on the full training set (the MLP only fires on Defer
|
// 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling).
|
||||||
// at inference time, but we train it on all data so it learns the
|
|
||||||
// full decision boundary).
|
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
let mlp_config = MlpConfig {
|
let mlp_config = MlpConfig {
|
||||||
input_dim: NUM_FEATURES,
|
input_dim: NUM_FEATURES,
|
||||||
hidden_dim: args.hidden_dim,
|
hidden_dim: args.hidden_dim,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let artifact_dir = Path::new(&args.output_dir).join("scanner_artifacts");
|
||||||
|
std::fs::create_dir_all(&artifact_dir).ok();
|
||||||
|
|
||||||
let model = train_mlp(
|
let model = train_mlp(
|
||||||
&train_set,
|
&train_set,
|
||||||
&val_set,
|
&val_set,
|
||||||
&mlp_config,
|
&mlp_config,
|
||||||
&norm_mins,
|
&norm_mins,
|
||||||
&norm_maxs,
|
&mlp_norm_maxs,
|
||||||
args.epochs,
|
args.epochs,
|
||||||
args.learning_rate,
|
args.learning_rate,
|
||||||
args.batch_size,
|
args.batch_size,
|
||||||
&device,
|
&device,
|
||||||
|
&artifact_dir,
|
||||||
);
|
);
|
||||||
|
|
||||||
// 6. Extract weights from trained model.
|
// 6. Extract weights from trained model (export mlp_norm_maxs so inference
|
||||||
|
// automatically applies the same cookie scaling).
|
||||||
let exported = extract_weights(
|
let exported = extract_weights(
|
||||||
&model,
|
&model,
|
||||||
"scanner",
|
"scanner",
|
||||||
&tree_nodes,
|
&tree_nodes,
|
||||||
0.5, // threshold
|
0.5,
|
||||||
&norm_mins,
|
&norm_mins,
|
||||||
&norm_maxs,
|
&mlp_norm_maxs,
|
||||||
&device,
|
&device,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -155,6 +197,46 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Cookie feature masking for CART trees
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Mask the cookie feature to reduce its influence on CART tree training.
|
||||||
|
///
|
||||||
|
/// Scaling a binary feature doesn't reduce its Gini gain — the tree just adjusts
|
||||||
|
/// the split threshold. Instead, we mask (set to 0.5) a fraction of samples so
|
||||||
|
/// the feature's apparent class-separation degrades.
|
||||||
|
///
|
||||||
|
/// - `cookie_weight = 0.0` → fully masked (feature is constant 0.5, zero info gain)
|
||||||
|
/// - `cookie_weight = 0.5` → 50% of samples masked (noisy, reduced gain)
|
||||||
|
/// - `cookie_weight = 1.0` → no masking (full feature)
|
||||||
|
fn mask_cookie_feature(
|
||||||
|
samples: &[TrainingSample],
|
||||||
|
cookie_idx: usize,
|
||||||
|
cookie_weight: f32,
|
||||||
|
) -> Vec<TrainingSample> {
|
||||||
|
if cookie_weight >= 1.0 - f32::EPSILON {
|
||||||
|
return samples.to_vec();
|
||||||
|
}
|
||||||
|
samples
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, s)| {
|
||||||
|
let mut s2 = s.clone();
|
||||||
|
if cookie_weight < f32::EPSILON {
|
||||||
|
s2.features[cookie_idx] = 0.5;
|
||||||
|
} else {
|
||||||
|
let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42);
|
||||||
|
let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32;
|
||||||
|
if r > cookie_weight {
|
||||||
|
s2.features[cookie_idx] = 0.5;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s2
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Normalization
|
// Normalization
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -172,21 +254,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
|
|||||||
(mins, maxs)
|
(mins, maxs)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
|
||||||
features
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, &v)| {
|
|
||||||
let range = maxs[i] - mins[i];
|
|
||||||
if range > f32::EPSILON {
|
|
||||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
|
||||||
} else {
|
|
||||||
0.0
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Stratified split
|
// Stratified split
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -195,7 +262,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<Traini
|
|||||||
let mut attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label >= 0.5).collect();
|
let mut attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label >= 0.5).collect();
|
||||||
let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect();
|
let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect();
|
||||||
|
|
||||||
// Deterministic shuffle using a simple index permutation seeded by length.
|
|
||||||
deterministic_shuffle(&mut attacks);
|
deterministic_shuffle(&mut attacks);
|
||||||
deterministic_shuffle(&mut normals);
|
deterministic_shuffle(&mut normals);
|
||||||
|
|
||||||
@@ -224,7 +290,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<Traini
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn deterministic_shuffle<T>(items: &mut [T]) {
|
fn deterministic_shuffle<T>(items: &mut [T]) {
|
||||||
// Simple Fisher-Yates with a fixed LCG seed for reproducibility.
|
|
||||||
let mut rng = 42u64;
|
let mut rng = 42u64;
|
||||||
for i in (1..items.len()).rev() {
|
for i in (1..items.len()).rev() {
|
||||||
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
|
||||||
@@ -276,8 +341,23 @@ fn eval_tree(
|
|||||||
(accuracy, defer_rate)
|
(accuracy, defer_rate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
|
||||||
|
features
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, &v)| {
|
||||||
|
let range = maxs[i] - mins[i];
|
||||||
|
if range > f32::EPSILON {
|
||||||
|
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// MLP training
|
// MLP training via SupervisedTraining
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
fn train_mlp(
|
fn train_mlp(
|
||||||
@@ -290,119 +370,47 @@ fn train_mlp(
|
|||||||
learning_rate: f64,
|
learning_rate: f64,
|
||||||
batch_size: usize,
|
batch_size: usize,
|
||||||
device: &<TrainBackend as Backend>::Device,
|
device: &<TrainBackend as Backend>::Device,
|
||||||
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
|
artifact_dir: &Path,
|
||||||
let mut model = config.init::<TrainBackend>(device);
|
) -> crate::training::mlp::MlpModel<Wgpu<f32, i32>> {
|
||||||
let mut optim = AdamConfig::new().init();
|
let model = config.init::<TrainBackend>(device);
|
||||||
|
|
||||||
// Pre-normalize all training data.
|
let train_dataset = SampleDataset::new(train_set, mins, maxs);
|
||||||
let train_features: Vec<Vec<f32>> = train_set
|
let val_dataset = SampleDataset::new(val_set, mins, maxs);
|
||||||
.iter()
|
|
||||||
.map(|s| normalize_features(&s.features, mins, maxs))
|
|
||||||
.collect();
|
|
||||||
let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
|
|
||||||
let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();
|
|
||||||
|
|
||||||
let n = train_features.len();
|
let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new())
|
||||||
|
.batch_size(batch_size)
|
||||||
|
.shuffle(42)
|
||||||
|
.num_workers(1)
|
||||||
|
.build(train_dataset);
|
||||||
|
|
||||||
for epoch in 0..epochs {
|
let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new())
|
||||||
let mut epoch_loss = 0.0f32;
|
.batch_size(batch_size)
|
||||||
let mut batches = 0usize;
|
.num_workers(1)
|
||||||
|
.build(val_dataset);
|
||||||
|
|
||||||
let mut offset = 0;
|
// Cosine annealing: initial_lr must be in (0.0, 1.0].
|
||||||
while offset < n {
|
let lr = learning_rate.min(1.0);
|
||||||
let end = (offset + batch_size).min(n);
|
let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs)
|
||||||
let batch_n = end - offset;
|
.init()
|
||||||
|
.expect("valid cosine annealing config");
|
||||||
|
|
||||||
// Build input tensor [batch, features].
|
let learner = Learner::new(
|
||||||
let flat: Vec<f32> = train_features[offset..end]
|
model,
|
||||||
.iter()
|
AdamConfig::new().init(),
|
||||||
.flat_map(|f| f.iter().copied())
|
lr_scheduler,
|
||||||
.collect();
|
);
|
||||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
|
||||||
.reshape([batch_n, NUM_FEATURES]);
|
|
||||||
|
|
||||||
// Labels [batch, 1].
|
let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid)
|
||||||
let y = Tensor::<TrainBackend, 1>::from_floats(
|
.metric_train_numeric(AccuracyMetric::new())
|
||||||
&train_labels[offset..end],
|
.metric_valid_numeric(AccuracyMetric::new())
|
||||||
device,
|
.metric_train_numeric(LossMetric::new())
|
||||||
)
|
.metric_valid_numeric(LossMetric::new())
|
||||||
.reshape([batch_n, 1]);
|
.with_file_checkpointer(CompactRecorder::new())
|
||||||
|
.num_epochs(epochs)
|
||||||
|
.summary()
|
||||||
|
.launch(learner);
|
||||||
|
|
||||||
// Sample weights [batch, 1].
|
result.model
|
||||||
let w = Tensor::<TrainBackend, 1>::from_floats(
|
|
||||||
&train_weights[offset..end],
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
.reshape([batch_n, 1]);
|
|
||||||
|
|
||||||
// Forward pass.
|
|
||||||
let pred = model.forward(x);
|
|
||||||
|
|
||||||
// Binary cross-entropy with sample weights:
|
|
||||||
// loss = -w * [y * log(p) + (1-y) * log(1-p)]
|
|
||||||
let eps = 1e-7;
|
|
||||||
let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
|
|
||||||
let bce = (y.clone() * pred_clamped.clone().log()
|
|
||||||
+ (y.clone().neg().add_scalar(1.0))
|
|
||||||
* pred_clamped.neg().add_scalar(1.0).log())
|
|
||||||
.neg();
|
|
||||||
let weighted_bce = bce * w;
|
|
||||||
let loss = weighted_bce.mean();
|
|
||||||
|
|
||||||
epoch_loss += loss.clone().into_scalar().elem::<f32>();
|
|
||||||
batches += 1;
|
|
||||||
|
|
||||||
// Backward + optimizer step.
|
|
||||||
let grads = loss.backward();
|
|
||||||
let grads = GradientsParams::from_grads(grads, &model);
|
|
||||||
model = optim.step(learning_rate, model, grads);
|
|
||||||
|
|
||||||
offset = end;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (epoch + 1) % 10 == 0 || epoch == 0 {
|
|
||||||
let avg_loss = epoch_loss / batches as f32;
|
|
||||||
let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
|
|
||||||
println!(
|
|
||||||
"[scanner] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
|
|
||||||
epoch + 1,
|
|
||||||
epochs,
|
|
||||||
avg_loss,
|
|
||||||
val_acc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the inner (non-autodiff) model for weight extraction.
|
|
||||||
model.valid()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn eval_mlp_accuracy(
|
|
||||||
model: &crate::training::mlp::MlpModel<TrainBackend>,
|
|
||||||
val_set: &[TrainingSample],
|
|
||||||
mins: &[f32],
|
|
||||||
maxs: &[f32],
|
|
||||||
device: &<TrainBackend as Backend>::Device,
|
|
||||||
) -> f64 {
|
|
||||||
let flat: Vec<f32> = val_set
|
|
||||||
.iter()
|
|
||||||
.flat_map(|s| normalize_features(&s.features, mins, maxs))
|
|
||||||
.collect();
|
|
||||||
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
|
|
||||||
.reshape([val_set.len(), NUM_FEATURES]);
|
|
||||||
|
|
||||||
let pred = model.forward(x);
|
|
||||||
let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");
|
|
||||||
|
|
||||||
let mut correct = 0usize;
|
|
||||||
for (i, s) in val_set.iter().enumerate() {
|
|
||||||
let p = pred_data[i];
|
|
||||||
let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
|
|
||||||
if (predicted_label - s.label).abs() < 0.1 {
|
|
||||||
correct += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
correct as f64 / val_set.len() as f64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -410,19 +418,14 @@ fn eval_mlp_accuracy(
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
fn extract_weights(
|
fn extract_weights(
|
||||||
model: &crate::training::mlp::MlpModel<NdArray<f32>>,
|
model: &crate::training::mlp::MlpModel<Wgpu<f32, i32>>,
|
||||||
name: &str,
|
name: &str,
|
||||||
tree_nodes: &[(u8, f32, u16, u16)],
|
tree_nodes: &[(u8, f32, u16, u16)],
|
||||||
threshold: f32,
|
threshold: f32,
|
||||||
norm_mins: &[f32],
|
norm_mins: &[f32],
|
||||||
norm_maxs: &[f32],
|
norm_maxs: &[f32],
|
||||||
_device: &<NdArray<f32> as Backend>::Device,
|
_device: &<Wgpu<f32, i32> as Backend>::Device,
|
||||||
) -> ExportedModel {
|
) -> ExportedModel {
|
||||||
// Extract weight tensors from the model.
|
|
||||||
// linear1.weight: [hidden_dim, input_dim]
|
|
||||||
// linear1.bias: [hidden_dim]
|
|
||||||
// linear2.weight: [1, hidden_dim]
|
|
||||||
// linear2.bias: [1]
|
|
||||||
let w1_tensor = model.linear1.weight.val();
|
let w1_tensor = model.linear1.weight.val();
|
||||||
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
|
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
|
||||||
let w2_tensor = model.linear2.weight.val();
|
let w2_tensor = model.linear2.weight.val();
|
||||||
@@ -436,7 +439,6 @@ fn extract_weights(
|
|||||||
let hidden_dim = b1_data.len();
|
let hidden_dim = b1_data.len();
|
||||||
let input_dim = w1_data.len() / hidden_dim;
|
let input_dim = w1_data.len() / hidden_dim;
|
||||||
|
|
||||||
// Reshape W1 into [hidden_dim][input_dim].
|
|
||||||
let w1: Vec<Vec<f32>> = (0..hidden_dim)
|
let w1: Vec<Vec<f32>> = (0..hidden_dim)
|
||||||
.map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec())
|
.map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec())
|
||||||
.collect();
|
.collect();
|
||||||
@@ -485,9 +487,8 @@ mod tests {
|
|||||||
let train_attacks = train.iter().filter(|s| s.label >= 0.5).count();
|
let train_attacks = train.iter().filter(|s| s.label >= 0.5).count();
|
||||||
let val_attacks = val.iter().filter(|s| s.label >= 0.5).count();
|
let val_attacks = val.iter().filter(|s| s.label >= 0.5).count();
|
||||||
|
|
||||||
// Should preserve the 80/20 attack ratio approximately.
|
assert_eq!(train_attacks, 16);
|
||||||
assert_eq!(train_attacks, 16); // 80% of 20
|
assert_eq!(val_attacks, 4);
|
||||||
assert_eq!(val_attacks, 4); // 20% of 20
|
|
||||||
assert_eq!(train.len() + val.len(), 100);
|
assert_eq!(train.len() + val.len(), 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -503,13 +504,4 @@ mod tests {
|
|||||||
assert_eq!(mins[1], 10.0);
|
assert_eq!(mins[1], 10.0);
|
||||||
assert_eq!(maxs[1], 20.0);
|
assert_eq!(maxs[1], 20.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_normalize_features() {
|
|
||||||
let mins = vec![0.0, 10.0];
|
|
||||||
let maxs = vec![1.0, 20.0];
|
|
||||||
let normed = normalize_features(&[0.5, 15.0], &mins, &maxs);
|
|
||||||
assert!((normed[0] - 0.5).abs() < 1e-6);
|
|
||||||
assert!((normed[1] - 0.5).abs() < 1e-6);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user