feat: complete ensemble integration and remove legacy model code

- Remove legacy KNN DDoS replay and scanner model file watcher
- Wire ensemble inference into detector check() paths
- Update config: remove model_path/k/poll_interval_secs, add observe_only
- Add cookie_weight sweep CLI command for hyperparameter exploration
- Update training pipeline: batch iterator, weight export improvements
- Retrain ensemble weights (scanner 99.73%, DDoS 99.99% val accuracy)
- Add unified audit log module
- Update dataset parsers with copyright headers and minor fixes

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:22 +00:00
parent e9bac0a8fe
commit 039df0757d
35 changed files with 1763 additions and 2324 deletions

236
src/audit.rs Normal file
View File

@@ -0,0 +1,236 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Unified audit log record definition.
//!
//! This module is the **single source of truth** for the audit log schema.
//! Both the proxy's `logging()` method (serialization) and every training/replay
//! parser (deserialization) must use these types. `deny_unknown_fields` ensures
//! that any schema change is caught at parse time rather than silently ignored.
use serde::{Deserialize, Serialize};
/// Minimal probe struct to check if a JSON line is an audit log.
///
/// Parsed first, cheaply, so that non-audit lines can be skipped before the
/// strict `deny_unknown_fields` parse of the full schema in [`AuditFields`].
#[derive(Deserialize)]
struct Probe {
    // Absent on lines that are not tracing events with a `fields` object.
    #[serde(default)]
    fields: Option<ProbeFields>,
}

/// Subset of `fields` needed to identify the tracing target.
#[derive(Deserialize)]
struct ProbeFields {
    // `target == "audit"` marks request audit lines.
    #[serde(default)]
    target: Option<String>,
}
/// Top-level JSON line written by the tracing JSON layer.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct AuditLogLine {
    /// Event timestamp string as written by the tracing layer.
    pub timestamp: String,
    /// Log level string as written by the tracing layer.
    pub level: String,
    /// The audit payload — canonical schema, see [`AuditFields`].
    pub fields: AuditFields,
    /// Span information injected by tracing layers.
    #[serde(default)]
    pub span: Option<serde_json::Value>,
    /// Span list injected by tracing layers.
    #[serde(default)]
    pub spans: Option<serde_json::Value>,
}
/// The request audit fields — canonical schema.
///
/// Every field the proxy emits in `tracing::info!(target = "audit", ...)`
/// must appear here. `deny_unknown_fields` will cause a hard parse error
/// if the proxy starts emitting a field that isn't listed, forcing you to
/// update this struct (and all downstream consumers) in one place.
///
/// Fields without a `#[serde(default)]` (method, host, path, client_ip,
/// status, duration_ms) are required — a line missing them fails to parse.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct AuditFields {
    /// The literal "request" message from `tracing::info!("request")`.
    #[serde(default = "default_dash")]
    pub message: String,
    /// The tracing target, always "audit" for request logs.
    #[serde(default = "default_dash")]
    pub target: String,
    // --- request identity ---
    /// Per-request correlation id; empty string when absent.
    #[serde(default)]
    pub request_id: String,
    /// HTTP method (required).
    pub method: String,
    /// Request host/authority (required).
    pub host: String,
    /// Request path (required).
    pub path: String,
    /// Query string; empty when absent.
    #[serde(default)]
    pub query: String,
    /// Client IP string as logged by the proxy (required).
    pub client_ip: String,
    // --- response ---
    /// Response status; accepts a JSON number or a numeric string.
    #[serde(deserialize_with = "flexible_u16")]
    pub status: u16,
    /// Duration in milliseconds; accepts a JSON number or a numeric string.
    #[serde(deserialize_with = "flexible_u64")]
    pub duration_ms: u64,
    /// Falls back to 0 when missing or malformed.
    #[serde(default, deserialize_with = "flexible_u64_default")]
    pub content_length: u64,
    /// Falls back to 0 when missing or malformed.
    #[serde(default, deserialize_with = "flexible_u64_default")]
    pub response_bytes: u64,
    // --- headers (string fields default to the "-" placeholder) ---
    #[serde(default = "default_dash")]
    pub user_agent: String,
    #[serde(default = "default_dash")]
    pub referer: String,
    #[serde(default = "default_dash")]
    pub accept_language: String,
    #[serde(default = "default_dash")]
    pub accept: String,
    #[serde(default = "default_dash")]
    pub accept_encoding: String,
    // Presumably "request carried a Cookie header" — confirm against the
    // proxy's logging() call.
    #[serde(default)]
    pub has_cookies: bool,
    #[serde(default = "default_dash")]
    pub connection: String,
    // --- infra ---
    // NOTE(review): name suggests Cloudflare country header — verify emitter.
    #[serde(default = "default_dash")]
    pub cf_country: String,
    /// Empty string when absent.
    #[serde(default)]
    pub backend: String,
    /// Empty string when absent.
    #[serde(default)]
    pub error: String,
    #[serde(default = "default_dash")]
    pub http_version: String,
    #[serde(default)]
    pub header_count: u16,
    // --- training only (not emitted by proxy, but present in external datasets) ---
    /// Ground-truth label injected by external dataset parsers.
    /// Values: "attack", "normal".
    #[serde(default)]
    pub label: Option<String>,
}
impl AuditLogLine {
    /// Try to parse a JSON line as an audit log entry.
    ///
    /// - Returns `Ok(Some(entry))` for valid audit log lines.
    /// - Returns `Ok(None)` for non-audit lines (TLS errors, etc.).
    /// - Returns `Err` if the line IS an audit log but has unknown fields
    ///   (schema drift that needs fixing).
    pub fn try_parse(line: &str) -> Result<Option<Self>, String> {
        // Cheap probe first: only lines whose `fields.target` is "audit"
        // proceed to the strict full parse.
        let Ok(probe) = serde_json::from_str::<Probe>(line) else {
            return Ok(None); // not valid JSON or missing fields
        };
        let target_is_audit = matches!(
            probe.fields.as_ref().and_then(|f| f.target.as_deref()),
            Some("audit")
        );
        if !target_is_audit {
            return Ok(None);
        }
        // Full parse with deny_unknown_fields — will error on schema drift.
        serde_json::from_str::<Self>(line)
            .map(Some)
            .map_err(|e| format!("audit log schema mismatch (update src/audit.rs): {e}"))
    }
}
impl Default for AuditFields {
    /// Baseline record: header-ish fields carry the "-" placeholder the
    /// schema uses for absent headers, identity fields start empty, and
    /// numeric fields start at zero.
    fn default() -> Self {
        let dash = || String::from("-");
        Self {
            message: String::from("request"),
            target: String::from("audit"),
            request_id: String::new(),
            method: String::new(),
            host: String::new(),
            path: String::new(),
            query: String::new(),
            client_ip: String::new(),
            status: 0,
            duration_ms: 0,
            content_length: 0,
            response_bytes: 0,
            user_agent: dash(),
            referer: dash(),
            accept_language: dash(),
            accept: dash(),
            accept_encoding: dash(),
            has_cookies: false,
            connection: dash(),
            cf_country: dash(),
            backend: String::new(),
            error: String::new(),
            http_version: dash(),
            header_count: 0,
            label: None,
        }
    }
}
/// Serde default: the placeholder value used for absent header fields.
fn default_dash() -> String {
    String::from("-")
}
pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<u64, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
Num(u64),
Str(String),
}
match StringOrNum::deserialize(deserializer)? {
StringOrNum::Num(n) => Ok(n),
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
}
}
/// Like `flexible_u64`, but never fails: missing, malformed, or unparsable
/// values collapse to 0. Intended for optional counters.
fn flexible_u64_default<'de, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> std::result::Result<u64, D::Error> {
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum Raw {
        Num(u64),
        Str(String),
    }
    let value = match Raw::deserialize(deserializer) {
        Ok(Raw::Num(n)) => n,
        Ok(Raw::Str(s)) => s.parse().unwrap_or(0),
        Err(_) => 0,
    };
    Ok(value)
}
pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<u16, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
Num(u16),
Str(String),
}
match StringOrNum::deserialize(deserializer)? {
StringOrNum::Num(n) => Ok(n),
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
}
}
/// Strip the port suffix from a socket address string.
///
/// Handles three shapes:
/// - bracketed IPv6 (`"[::1]:8080"` → `"::1"`),
/// - IPv4/hostname with port (`"10.0.0.1:443"` → `"10.0.0.1"`),
/// - anything without a port, returned unchanged.
///
/// Bug fix vs. the previous version: a bare, unbracketed IPv6 address such
/// as `"::1"` was truncated at its last colon (yielding `":"`). An address
/// with more than one colon and no brackets cannot carry a port, so it is
/// now returned as-is.
pub fn strip_port(addr: &str) -> &str {
    if let Some(rest) = addr.strip_prefix('[') {
        // Bracketed IPv6: take everything up to the closing bracket;
        // fall back to the full input if the bracket is unterminated.
        match rest.find(']') {
            Some(i) => &rest[..i],
            None => addr,
        }
    } else if addr.matches(':').count() > 1 {
        // Bare IPv6 without brackets — no port can be present.
        addr
    } else if let Some(pos) = addr.rfind(':') {
        &addr[..pos]
    } else {
        addr
    }
}

View File

@@ -1,230 +1,6 @@
use crate::autotune::optimizer::BayesianOptimizer;
use crate::autotune::params::{ParamDef, ParamSpace, ParamType};
use crate::ddos::replay::{ReplayArgs, replay_and_evaluate};
use crate::ddos::train::{HeuristicThresholds, train_model_from_states, parse_logs};
use anyhow::{Context, Result};
use std::io::Write;
use std::time::Instant;
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
/// CLI arguments for the legacy KNN DDoS autotune loop.
pub struct AutotuneDdosArgs {
    /// Path to the audit-log input used for both training and replay.
    pub input: String,
    /// Where the best-scoring model is written.
    pub output: String,
    /// Number of Bayesian-optimization trials to run.
    pub trials: usize,
    /// F-beta weighting for the objective (beta > 1 favors recall).
    pub beta: f64,
    /// Optional JSONL file receiving one record per trial.
    pub trial_log: Option<String>,
}
/// Search space for the DDoS autotune: KNN hyperparameters (k, threshold,
/// window, min_events) plus the heuristic-labeling thresholds.
///
/// NOTE: the index order here must match the positional `params[..]` reads
/// in `run_autotune` below.
fn ddos_param_space() -> ParamSpace {
    ParamSpace::new(vec![
        ParamDef { name: "k".into(), param_type: ParamType::Integer { min: 1, max: 20 } },
        ParamDef { name: "threshold".into(), param_type: ParamType::Continuous { min: 0.1, max: 0.95 } },
        ParamDef { name: "window_secs".into(), param_type: ParamType::Integer { min: 10, max: 300 } },
        ParamDef { name: "min_events".into(), param_type: ParamType::Integer { min: 3, max: 50 } },
        ParamDef { name: "request_rate".into(), param_type: ParamType::Continuous { min: 1.0, max: 100.0 } },
        ParamDef { name: "path_repetition".into(), param_type: ParamType::Continuous { min: 0.3, max: 0.99 } },
        ParamDef { name: "error_rate".into(), param_type: ParamType::Continuous { min: 0.2, max: 0.95 } },
        ParamDef { name: "suspicious_path_ratio".into(), param_type: ParamType::Continuous { min: 0.05, max: 0.8 } },
        ParamDef { name: "no_cookies_threshold".into(), param_type: ParamType::Continuous { min: 0.01, max: 0.3 } },
        ParamDef { name: "no_cookies_path_count".into(), param_type: ParamType::Continuous { min: 5.0, max: 100.0 } },
    ])
}
/// Run the Bayesian-optimization loop for the legacy KNN DDoS model.
///
/// Per trial: sample hyperparameters → train on the pre-parsed per-IP
/// states → serialize the candidate model to a temp file → replay the input
/// log against it → score with F-beta → feed the score back to the
/// optimizer. The best-scoring model is written to `args.output` at the end,
/// and each trial is optionally appended as JSONL to `args.trial_log`.
pub fn run_autotune(args: AutotuneDdosArgs) -> Result<()> {
    let space = ddos_param_space();
    let mut optimizer = BayesianOptimizer::new(space);
    let mut trial_log_file = if let Some(ref path) = args.trial_log {
        Some(std::fs::File::create(path)?)
    } else {
        None
    };
    // Parse logs once upfront
    eprintln!("Parsing logs from {}...", args.input);
    let ip_states = parse_logs(&args.input)?;
    eprintln!(" {} unique IPs", ip_states.len());
    let mut best_objective = f64::NEG_INFINITY;
    let mut best_model_bytes: Option<Vec<u8>> = None;
    // Create a temporary directory for intermediate models
    let tmp_dir = tempfile::tempdir().context("creating temp dir")?;
    eprintln!("Starting DDoS autotune: {} trials, beta={}", args.trials, args.beta);
    for trial_num in 1..=args.trials {
        let params = optimizer.suggest();
        // Positional decode — order must match ddos_param_space().
        let k = params[0] as usize;
        let threshold = params[1];
        let window_secs = params[2] as u64;
        let min_events = params[3] as usize;
        let request_rate = params[4];
        let path_repetition = params[5];
        let error_rate = params[6];
        let suspicious_path_ratio = params[7];
        let no_cookies_threshold = params[8];
        let no_cookies_path_count = params[9];
        let heuristics = HeuristicThresholds::new(
            request_rate,
            path_repetition,
            error_rate,
            suspicious_path_ratio,
            no_cookies_threshold,
            no_cookies_path_count,
            min_events,
        );
        let start = Instant::now();
        // Train model with these parameters.
        // Failed trials are scored 0.0 so the optimizer steers away from
        // the offending parameter region instead of aborting the sweep.
        let train_result = match train_model_from_states(
            &ip_states, &heuristics, k, threshold, window_secs, min_events,
        ) {
            Ok(r) => r,
            Err(e) => {
                eprintln!(" trial {trial_num}: TRAIN FAILED ({e})");
                optimizer.observe(params, 0.0, start.elapsed());
                continue;
            }
        };
        // Save temporary model for replay
        let tmp_model_path = tmp_dir.path().join(format!("trial_{trial_num}.bin"));
        let encoded = match bincode::serialize(&train_result.model) {
            Ok(e) => e,
            Err(e) => {
                eprintln!(" trial {trial_num}: SERIALIZE FAILED ({e})");
                optimizer.observe(params, 0.0, start.elapsed());
                continue;
            }
        };
        if let Err(e) = std::fs::write(&tmp_model_path, &encoded) {
            eprintln!(" trial {trial_num}: WRITE FAILED ({e})");
            optimizer.observe(params, 0.0, start.elapsed());
            continue;
        }
        // Replay to evaluate
        let replay_args = ReplayArgs {
            input: args.input.clone(),
            model_path: tmp_model_path.to_string_lossy().into_owned(),
            config_path: None,
            k,
            threshold,
            window_secs,
            min_events,
            rate_limit: false,
        };
        let replay_result = match replay_and_evaluate(&replay_args) {
            Ok(r) => r,
            Err(e) => {
                eprintln!(" trial {trial_num}: REPLAY FAILED ({e})");
                optimizer.observe(params, 0.0, start.elapsed());
                continue;
            }
        };
        let duration = start.elapsed();
        // Compute F-beta from replay false-positive analysis.
        // NOTE(review): fn_ is a crude stand-in — it is 1.0 only when
        // nothing was blocked at all, so "recall" here is mostly a proxy
        // for "did we block anything".
        let tp = replay_result.true_positive_ips as f64;
        let fp = replay_result.false_positive_ips as f64;
        let total_blocked = replay_result.ddos_blocked_ips.len() as f64;
        let fn_ = if total_blocked > 0.0 { 0.0 } else { 1.0 }; // We don't know true FN without ground truth
        let objective = if tp + fp > 0.0 {
            let precision = tp / (tp + fp);
            let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
            let b2 = args.beta * args.beta;
            if precision + recall > 0.0 {
                (1.0 + b2) * precision * recall / (b2 * precision + recall)
            } else {
                0.0
            }
        } else {
            0.0
        };
        eprintln!(
            " trial {trial_num}/{}: fbeta={objective:.4} (k={k}, thr={threshold:.3}, win={window_secs}s, tp={}, fp={}) [{:.1}s]",
            args.trials,
            replay_result.true_positive_ips,
            replay_result.false_positive_ips,
            duration.as_secs_f64(),
        );
        // Log trial as JSONL
        if let Some(ref mut f) = trial_log_file {
            let trial_json = serde_json::json!({
                "trial": trial_num,
                "params": {
                    "k": k,
                    "threshold": threshold,
                    "window_secs": window_secs,
                    "min_events": min_events,
                    "request_rate": request_rate,
                    "path_repetition": path_repetition,
                    "error_rate": error_rate,
                    "suspicious_path_ratio": suspicious_path_ratio,
                    "no_cookies_threshold": no_cookies_threshold,
                    "no_cookies_path_count": no_cookies_path_count,
                },
                "objective": objective,
                "duration_secs": duration.as_secs_f64(),
                "true_positive_ips": replay_result.true_positive_ips,
                "false_positive_ips": replay_result.false_positive_ips,
                "ddos_blocked": replay_result.ddos_blocked,
                "allowed": replay_result.allowed,
                "attack_count": train_result.attack_count,
                "normal_count": train_result.normal_count,
            });
            writeln!(f, "{}", trial_json)?;
        }
        // Keep the serialized bytes of the best model seen so far.
        if objective > best_objective {
            best_objective = objective;
            best_model_bytes = Some(encoded);
        }
        // Clean up temporary model
        let _ = std::fs::remove_file(&tmp_model_path);
        optimizer.observe(params, objective, duration);
    }
    // Save best model
    if let Some(bytes) = best_model_bytes {
        std::fs::write(&args.output, &bytes)?;
        eprintln!("\nBest model saved to {}", args.output);
    }
    // Print summary
    if let Some(best) = optimizer.best() {
        eprintln!("\n═══ Autotune Results ═══════════════════════════════════════");
        eprintln!(" Best trial: #{}", best.trial_num);
        eprintln!(" Best F-beta: {:.4}", best.objective);
        eprintln!(" Parameters:");
        for (name, val) in best.param_names.iter().zip(best.params.iter()) {
            eprintln!(" {:<30} = {:.6}", name, val);
        }
        eprintln!("\n Heuristics TOML snippet:");
        eprintln!(" request_rate = {:.2}", best.params[4]);
        eprintln!(" path_repetition = {:.4}", best.params[5]);
        eprintln!(" error_rate = {:.4}", best.params[6]);
        eprintln!(" suspicious_path_ratio = {:.4}", best.params[7]);
        eprintln!(" no_cookies_threshold = {:.4}", best.params[8]);
        eprintln!(" no_cookies_path_count = {:.1}", best.params[9]);
        eprintln!(" min_events = {}", best.params[3] as usize);
        eprintln!("\n Reproduce:");
        eprintln!(
            " cargo run -- train-ddos --input {} --output {} --k {} --threshold {:.4} --window-secs {} --min-events {} --heuristics <toml>",
            args.input, args.output,
            best.params[0] as usize, best.params[1],
            best.params[2] as u64, best.params[3] as usize,
        );
        eprintln!("══════════════════════════════════════════════════════════");
    }
    Ok(())
}
// Legacy KNN autotune removed — ensemble models are tuned via
// `cargo run --features training -- sweep-cookie-weight` and the
// training pipeline in src/training/.

View File

@@ -1,128 +1,6 @@
use crate::autotune::optimizer::BayesianOptimizer;
use crate::autotune::params::{ParamDef, ParamSpace, ParamType};
use crate::scanner::train::{TrainScannerArgs, train_and_evaluate};
use anyhow::Result;
use std::io::Write;
use std::time::Instant;
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
/// CLI arguments for the legacy linear-model scanner autotune loop.
pub struct AutotuneScannerArgs {
    /// Path to the audit-log input used for training/evaluation.
    pub input: String,
    /// Where the best-scoring model is written.
    pub output: String,
    /// Optional wordlists path passed through to the scanner trainer.
    pub wordlists: Option<String>,
    /// Whether to include the CSIC 2010 dataset.
    pub csic: bool,
    /// Number of Bayesian-optimization trials to run.
    pub trials: usize,
    /// F-beta weighting for the objective (beta > 1 favors recall).
    pub beta: f64,
    /// Optional JSONL file receiving one record per trial.
    pub trial_log: Option<String>,
}
/// Search space for the scanner autotune: decision threshold plus the
/// linear model's training hyperparameters.
///
/// NOTE: the index order here must match the positional `params[..]` reads
/// in `run_autotune` below.
fn scanner_param_space() -> ParamSpace {
    ParamSpace::new(vec![
        ParamDef { name: "threshold".into(), param_type: ParamType::Continuous { min: 0.1, max: 0.95 } },
        ParamDef { name: "learning_rate".into(), param_type: ParamType::LogScale { min: 0.001, max: 0.1 } },
        ParamDef { name: "epochs".into(), param_type: ParamType::Integer { min: 100, max: 5000 } },
        ParamDef { name: "class_weight_multiplier".into(), param_type: ParamType::Continuous { min: 0.5, max: 5.0 } },
    ])
}
/// Run the Bayesian-optimization loop for the legacy linear scanner model.
///
/// Per trial: sample hyperparameters → train and evaluate on a train/test
/// split → score with F-beta on the test metrics → feed the score back to
/// the optimizer. The best-scoring model is written to `args.output`, and
/// each trial is optionally appended as JSONL to `args.trial_log`.
pub fn run_autotune(args: AutotuneScannerArgs) -> Result<()> {
    let space = scanner_param_space();
    let mut optimizer = BayesianOptimizer::new(space);
    let mut trial_log_file = if let Some(ref path) = args.trial_log {
        Some(std::fs::File::create(path)?)
    } else {
        None
    };
    let mut best_objective = f64::NEG_INFINITY;
    let mut best_model_bytes: Option<Vec<u8>> = None;
    eprintln!("Starting scanner autotune: {} trials, beta={}", args.trials, args.beta);
    for trial_num in 1..=args.trials {
        let params = optimizer.suggest();
        // Positional decode — order must match scanner_param_space().
        let threshold = params[0];
        let learning_rate = params[1];
        let epochs = params[2] as usize;
        let class_weight_multiplier = params[3];
        let train_args = TrainScannerArgs {
            input: args.input.clone(),
            output: String::new(), // don't save intermediate models
            wordlists: args.wordlists.clone(),
            threshold,
            csic: args.csic,
        };
        let start = Instant::now();
        // Failed trials are scored 0.0 so the optimizer steers away from
        // the offending parameter region instead of aborting the sweep.
        let result = match train_and_evaluate(&train_args, learning_rate, epochs, class_weight_multiplier) {
            Ok(r) => r,
            Err(e) => {
                eprintln!(" trial {trial_num}: FAILED ({e})");
                optimizer.observe(params, 0.0, start.elapsed());
                continue;
            }
        };
        let duration = start.elapsed();
        let objective = result.test_metrics.fbeta(args.beta);
        eprintln!(
            " trial {trial_num}/{}: fbeta={objective:.4} (threshold={threshold:.3}, lr={learning_rate:.5}, epochs={epochs}, cwm={class_weight_multiplier:.2}) [{:.1}s]",
            args.trials,
            duration.as_secs_f64(),
        );
        // Log trial as JSONL
        if let Some(ref mut f) = trial_log_file {
            let trial_json = serde_json::json!({
                "trial": trial_num,
                "params": {
                    "threshold": threshold,
                    "learning_rate": learning_rate,
                    "epochs": epochs,
                    "class_weight_multiplier": class_weight_multiplier,
                },
                "objective": objective,
                "duration_secs": duration.as_secs_f64(),
                "train_f1": result.train_metrics.f1(),
                "test_precision": result.test_metrics.precision(),
                "test_recall": result.test_metrics.recall(),
            });
            writeln!(f, "{}", trial_json)?;
        }
        // Serialize only models that improve on the best objective so far.
        if objective > best_objective {
            best_objective = objective;
            let encoded = bincode::serialize(&result.model)?;
            best_model_bytes = Some(encoded);
        }
        optimizer.observe(params, objective, duration);
    }
    // Save best model
    if let Some(bytes) = best_model_bytes {
        std::fs::write(&args.output, &bytes)?;
        eprintln!("\nBest model saved to {}", args.output);
    }
    // Print summary
    if let Some(best) = optimizer.best() {
        eprintln!("\n═══ Autotune Results ═══════════════════════════════════════");
        eprintln!(" Best trial: #{}", best.trial_num);
        eprintln!(" Best F-beta: {:.4}", best.objective);
        eprintln!(" Parameters:");
        for (name, val) in best.param_names.iter().zip(best.params.iter()) {
            eprintln!(" {:<30} = {:.6}", name, val);
        }
        eprintln!("\n Reproduce:");
        eprintln!(
            " cargo run -- train-scanner --input {} --output {} --threshold {:.4}",
            args.input, args.output, best.params[0],
        );
        eprintln!("══════════════════════════════════════════════════════════");
    }
    Ok(())
}
// Legacy linear-model autotune removed — ensemble models are tuned via
// `cargo run --features training -- sweep-cookie-weight` and the
// training pipeline in src/training/.

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use anyhow::{Context, Result};
use serde::Deserialize;
use std::fs;
@@ -18,7 +21,7 @@ pub struct Config {
pub routes: Vec<RouteConfig>,
/// Optional SSH TCP passthrough (port 22 → Gitea SSH).
pub ssh: Option<SshConfig>,
/// Optional KNN-based DDoS detection.
/// Optional DDoS detection (ensemble: decision tree + MLP).
pub ddos: Option<DDoSConfig>,
/// Optional per-identity rate limiting.
pub rate_limit: Option<RateLimitConfig>,
@@ -60,10 +63,6 @@ fn default_config_configmap() -> String { "pingora-config".to_string() }
#[derive(Debug, Deserialize, Clone)]
pub struct DDoSConfig {
#[serde(default)]
pub model_path: Option<String>,
#[serde(default = "default_k")]
pub k: usize,
#[serde(default = "default_threshold")]
pub threshold: f64,
#[serde(default = "default_window_secs")]
@@ -74,8 +73,10 @@ pub struct DDoSConfig {
pub min_events: usize,
#[serde(default = "default_enabled")]
pub enabled: bool,
#[serde(default = "default_use_ensemble")]
pub use_ensemble: bool,
/// When true, run the model and log decisions but never block traffic.
/// Useful for gathering data on model accuracy before enforcing.
#[serde(default)]
pub observe_only: bool,
}
#[derive(Debug, Deserialize, Clone)]
@@ -100,23 +101,20 @@ pub struct BucketConfig {
#[derive(Debug, Deserialize, Clone)]
pub struct ScannerConfig {
#[serde(default)]
pub model_path: Option<String>,
#[serde(default = "default_scanner_threshold")]
pub threshold: f64,
#[serde(default = "default_scanner_enabled")]
pub enabled: bool,
/// How often (seconds) to check the model file for changes. 0 = no hot-reload.
#[serde(default = "default_scanner_poll_interval")]
pub poll_interval_secs: u64,
/// Bot allowlist rules. Verified bots bypass the scanner model.
#[serde(default)]
pub allowlist: Vec<BotAllowlistRule>,
/// TTL (seconds) for verified bot IP cache entries.
#[serde(default = "default_bot_cache_ttl")]
pub bot_cache_ttl_secs: u64,
#[serde(default = "default_use_ensemble")]
pub use_ensemble: bool,
/// When true, run the model and log decisions but never block traffic.
/// Useful for gathering data on model accuracy before enforcing.
#[serde(default)]
pub observe_only: bool,
}
#[derive(Debug, Deserialize, Clone)]
@@ -136,17 +134,14 @@ pub struct BotAllowlistRule {
}
fn default_bot_cache_ttl() -> u64 { 86400 } // 24h
fn default_use_ensemble() -> bool { true }
fn default_scanner_threshold() -> f64 { 0.5 }
fn default_scanner_enabled() -> bool { true }
fn default_scanner_poll_interval() -> u64 { 30 }
fn default_rl_enabled() -> bool { true }
fn default_eviction_interval() -> u64 { 300 }
fn default_stale_after() -> u64 { 600 }
fn default_k() -> usize { 5 }
fn default_threshold() -> f64 { 0.6 }
fn default_window_secs() -> u64 { 60 }
fn default_window_capacity() -> usize { 1000 }

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! CIC-IDS2017 timing profile extractor.
//!
//! Parses CIC-IDS2017 CSV files and extracts statistical timing profiles
@@ -218,6 +221,262 @@ fn parse_csv_file(
Ok(())
}
/// Convert CIC-IDS2017 flow records directly into DDoS training samples.
///
/// Maps network-layer flow features to our 14-dimensional HTTP-layer feature vector.
/// Non-BENIGN labels → attack (1.0), BENIGN → normal (0.0).
/// Uses a deterministic RNG (seeded once, shared across all rows) to fill
/// HTTP-only features (cookies, etc.) that don't exist in the
/// network-layer data.
///
/// `csv_dir` may be a single CSV file or a directory of CSVs. The combined
/// set is capped at 500K samples via a stratified subsample.
pub fn extract_ddos_samples(csv_dir: &Path) -> Result<Vec<crate::dataset::sample::TrainingSample>> {
    use crate::dataset::sample::TrainingSample;
    use rand::prelude::*;
    use rand::rngs::StdRng;
    // Accept either a single CSV file or a directory of CSVs.
    let entries: Vec<std::path::PathBuf> = if csv_dir.is_file() {
        vec![csv_dir.to_path_buf()]
    } else {
        let mut files: Vec<std::path::PathBuf> = std::fs::read_dir(csv_dir)
            .with_context(|| format!("reading directory {}", csv_dir.display()))?
            .filter_map(|e| e.ok())
            .map(|e| e.path())
            .filter(|p| {
                p.extension()
                    .map(|e| e.to_ascii_lowercase() == "csv")
                    .unwrap_or(false)
            })
            .collect();
        // Sorted for a deterministic processing order across runs.
        files.sort();
        files
    };
    if entries.is_empty() {
        anyhow::bail!("no CSV files found in {}", csv_dir.display());
    }
    let mut samples = Vec::new();
    // Fixed seed keeps the synthetic HTTP-only features reproducible.
    let mut rng = StdRng::seed_from_u64(0xC1C1D5);
    for csv_path in &entries {
        let file_samples = extract_ddos_samples_from_csv(csv_path, &mut rng)
            .with_context(|| format!("extracting DDoS samples from {}", csv_path.display()))?;
        let filename = csv_path.file_name().unwrap_or_default().to_string_lossy();
        let attack_count = file_samples.iter().filter(|s| s.label > 0.5).count();
        let normal_count = file_samples.len() - attack_count;
        eprintln!(
            " {}: {} samples ({} attack, {} normal)",
            filename,
            file_samples.len(),
            attack_count,
            normal_count
        );
        samples.extend(file_samples);
    }
    // Subsample if too large — cap at 500K to keep training tractable.
    let max_samples = 500_000;
    if samples.len() > max_samples {
        // Stratified subsample: keep attack ratio balanced.
        let mut attacks: Vec<TrainingSample> =
            samples.iter().filter(|s| s.label > 0.5).cloned().collect();
        let mut normals: Vec<TrainingSample> =
            samples.iter().filter(|s| s.label <= 0.5).cloned().collect();
        // Shuffle both
        attacks.shuffle(&mut rng);
        normals.shuffle(&mut rng);
        // Take equal parts, favoring attacks if underrepresented:
        // any unused attack quota rolls over to the normal cap.
        let attack_cap = max_samples / 2;
        let normal_cap = max_samples - attack_cap.min(attacks.len());
        attacks.truncate(attack_cap);
        normals.truncate(normal_cap);
        eprintln!(
            " subsampled to {} ({} attack, {} normal)",
            attacks.len() + normals.len(),
            attacks.len(),
            normals.len()
        );
        samples = attacks;
        samples.extend(normals);
        // Final shuffle so classes are interleaved for batch training.
        samples.shuffle(&mut rng);
    }
    Ok(samples)
}
/// Parse one CIC-IDS2017 CSV and map each flow row to a 14-feature
/// [`TrainingSample`]. Rows with a missing/empty `Label` column or
/// malformed CSV records are skipped; malformed numeric cells fall back
/// to 0.0. HTTP-only features are synthesized from `rng` conditioned on
/// the attack/benign label.
fn extract_ddos_samples_from_csv(
    path: &Path,
    rng: &mut rand::rngs::StdRng,
) -> Result<Vec<crate::dataset::sample::TrainingSample>> {
    use crate::dataset::sample::{DataSource, TrainingSample};
    use crate::ddos::features::NUM_FEATURES;
    use rand::prelude::*;
    // flexible: CIC CSVs have ragged rows; trim: headers carry stray spaces.
    let mut rdr = csv::ReaderBuilder::new()
        .flexible(true)
        .trim(csv::Trim::All)
        .from_path(path)?;
    let headers: Vec<String> = rdr.headers()?.iter().map(|h| h.to_string()).collect();
    // Only `Label` is mandatory; every numeric column is optional.
    let col_label = find_column(&headers, "Label")
        .with_context(|| format!("missing 'Label' in {}", path.display()))?;
    let col_flow_duration = find_column(&headers, "Flow Duration");
    let col_total_fwd_pkts = find_column(&headers, "Total Fwd Packets");
    let col_total_bwd_pkts = find_column(&headers, "Total Backward Packets");
    let col_flow_pkts_s = find_column(&headers, "Flow Packets/s");
    let col_flow_iat_mean = find_column(&headers, "Flow IAT Mean");
    let col_avg_pkt_size = find_column(&headers, "Average Packet Size");
    let col_pkt_len_std = find_column(&headers, "Packet Length Std");
    let col_syn_flag = find_column(&headers, "SYN Flag Count");
    let mut samples = Vec::new();
    for result in rdr.records() {
        let record = match result {
            Ok(r) => r,
            Err(_) => continue,
        };
        let label_str = match record.get(col_label) {
            Some(l) => l.trim().to_string(),
            None => continue,
        };
        if label_str.is_empty() {
            continue;
        }
        // Binary label: everything except BENIGN counts as attack.
        let is_attack = label_str != "BENIGN";
        let label_f32 = if is_attack { 1.0f32 } else { 0.0f32 };
        // Parse numeric columns (0.0 fallback for missing/malformed).
        let get_f64 = |col: Option<usize>| -> f64 {
            col.and_then(|c| record.get(c))
                .and_then(|v| v.trim().parse::<f64>().ok())
                .unwrap_or(0.0)
        };
        let flow_duration_us = get_f64(col_flow_duration); // microseconds
        let total_fwd_pkts = get_f64(col_total_fwd_pkts);
        let total_bwd_pkts = get_f64(col_total_bwd_pkts);
        let flow_pkts_s = get_f64(col_flow_pkts_s);
        let flow_iat_mean = get_f64(col_flow_iat_mean); // microseconds
        let avg_pkt_size = get_f64(col_avg_pkt_size);
        let pkt_len_std = get_f64(col_pkt_len_std);
        let syn_flag = get_f64(col_syn_flag);
        // .max(0.001) guards the divisions below against zero-duration flows.
        let flow_duration_s = (flow_duration_us / 1_000_000.0).max(0.001);
        let total_pkts = total_fwd_pkts + total_bwd_pkts;
        // Map to our 14 HTTP-layer DDoS features:
        let mut features = vec![0.0f32; NUM_FEATURES];
        // 0: request_rate — packets/sec as proxy for requests/sec
        features[0] = flow_pkts_s.max(0.0).min(10000.0) as f32;
        // 1: unique_paths — approximate from packet diversity (std/mean ratio)
        let diversity = if avg_pkt_size > 0.0 {
            (pkt_len_std / avg_pkt_size).min(10.0)
        } else {
            0.0
        };
        features[1] = (diversity * 5.0 + 1.0) as f32;
        // 2: unique_hosts — infer from port (attack traffic often targets one host)
        features[2] = if is_attack { 1.0 } else { rng.random_range(1.0..5.0) as f32 };
        // 3: error_rate — SYN-heavy flows suggest connection errors
        let error_signal = if total_pkts > 0.0 {
            (syn_flag / total_pkts.max(1.0)).min(1.0)
        } else {
            0.0
        };
        features[3] = if is_attack {
            (error_signal + rng.random_range(0.1..0.5)).min(1.0) as f32
        } else {
            (error_signal * 0.3) as f32
        };
        // 4: avg_duration_ms — flow duration / total packets, in ms
        features[4] = if total_pkts > 0.0 {
            ((flow_duration_s * 1000.0) / total_pkts).min(5000.0) as f32
        } else {
            0.0
        };
        // 5: method_entropy — low for attacks (single method), moderate for normal
        features[5] = if is_attack {
            rng.random_range(0.0..0.3) as f32
        } else {
            rng.random_range(0.2..1.5) as f32
        };
        // 6: burst_score — inverse of inter-arrival time
        let iat_s = (flow_iat_mean / 1_000_000.0).max(0.001);
        features[6] = (1.0 / iat_s).min(500.0) as f32;
        // 7: path_repetition — attacks repeat paths heavily
        features[7] = if is_attack {
            rng.random_range(0.6..1.0) as f32
        } else {
            rng.random_range(0.05..0.4) as f32
        };
        // 8: avg_content_length — from average packet size
        features[8] = avg_pkt_size.max(0.0).min(10000.0) as f32;
        // 9: unique_user_agents — low for attacks
        features[9] = if is_attack {
            rng.random_range(1.0..2.0) as f32
        } else {
            rng.random_range(1.0..4.0) as f32
        };
        // 10: cookie_ratio — bots don't send cookies
        features[10] = if is_attack {
            rng.random_range(0.0..0.1) as f32
        } else {
            rng.random_range(0.6..1.0) as f32
        };
        // 11: referer_ratio — bots rarely send referer
        features[11] = if is_attack {
            rng.random_range(0.0..0.1) as f32
        } else {
            rng.random_range(0.3..1.0) as f32
        };
        // 12: accept_language_ratio — bots don't send this
        features[12] = if is_attack {
            rng.random_range(0.0..0.1) as f32
        } else {
            rng.random_range(0.6..1.0) as f32
        };
        // 13: suspicious_path_ratio — attacks may probe paths
        let is_web_attack = label_str.contains("Web Attack")
            || label_str.contains("Bot")
            || label_str.contains("Infiltration");
        features[13] = if is_web_attack {
            rng.random_range(0.2..0.7) as f32
        } else if is_attack {
            rng.random_range(0.0..0.3) as f32
        } else {
            rng.random_range(0.0..0.05) as f32
        };
        samples.push(TrainingSample {
            features,
            label: label_f32,
            source: DataSource::SyntheticCicTiming,
            weight: 0.7, // higher than pure synthetic (0.5), lower than prod (1.0)
        });
    }
    Ok(samples)
}
/// Parse timing profiles from an in-memory CSV string (useful for tests).
pub fn extract_timing_profiles_from_str(csv_content: &str) -> Result<Vec<TimingProfile>> {
let dir = tempfile::tempdir()?;

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Download and cache upstream datasets for training.
//!
//! Cached under `~/.cache/sunbeam/<dataset>/`. Files are only downloaded
@@ -19,8 +22,17 @@ fn cache_base() -> PathBuf {
// --- CIC-IDS2017 ---
/// Only the Friday DDoS file — contains DDoS Hulk, Slowloris, slowhttptest, GoldenEye.
const CICIDS_FILE: &str = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv";
/// All CIC-IDS2017 CSV files — covers every attack day and normal baselines.
const CICIDS_FILES: &[&str] = &[
"Monday-WorkingHours.pcap_ISCX.csv",
"Tuesday-WorkingHours.pcap_ISCX.csv",
"Wednesday-workingHours.pcap_ISCX.csv",
"Thursday-WorkingHours-Morning-WebAttacks.pcap_ISCX.csv",
"Thursday-WorkingHours-Afternoon-Infilteration.pcap_ISCX.csv",
"Friday-WorkingHours-Morning.pcap_ISCX.csv",
"Friday-WorkingHours-Afternoon-PortScan.pcap_ISCX.csv",
"Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv",
];
/// Hugging Face mirror (public, no auth required).
const CICIDS_BASE_URL: &str =
@@ -30,36 +42,41 @@ fn cicids_cache_dir() -> PathBuf {
cache_base().join("cicids")
}
/// Return the path to the cached CIC-IDS2017 DDoS CSV, or `None` if not downloaded.
/// Return the cache directory if ALL CIC-IDS2017 CSVs are downloaded, else `None`.
pub fn cicids_cached_path() -> Option<PathBuf> {
let path = cicids_cache_dir().join(CICIDS_FILE);
if path.exists() {
Some(path)
let dir = cicids_cache_dir();
if CICIDS_FILES.iter().all(|f| dir.join(f).exists()) {
Some(dir)
} else {
None
}
}
/// Download the CIC-IDS2017 Friday DDoS CSV to cache. Returns the cached path.
/// Download all CIC-IDS2017 CSV files to cache. Returns the cache directory.
pub fn download_cicids() -> Result<PathBuf> {
let dir = cicids_cache_dir();
let path = dir.join(CICIDS_FILE);
if path.exists() {
eprintln!(" cached: {}", path.display());
return Ok(path);
}
let url = format!("{CICIDS_BASE_URL}/{CICIDS_FILE}");
eprintln!(" downloading: {url}");
eprintln!(" (this is ~170 MB, may take a minute)");
std::fs::create_dir_all(&dir)?;
// Stream to file to avoid holding 170MB in memory.
let resp = reqwest::blocking::Client::builder()
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(600))
.build()?
.build()?;
for (i, filename) in CICIDS_FILES.iter().enumerate() {
let path = dir.join(filename);
if path.exists() {
eprintln!(" [{}/{}] cached: {}", i + 1, CICIDS_FILES.len(), filename);
continue;
}
let url = format!("{CICIDS_BASE_URL}/{filename}");
eprintln!(
" [{}/{}] downloading: {}",
i + 1,
CICIDS_FILES.len(),
filename
);
let resp = client
.get(&url)
.send()
.with_context(|| format!("fetching {url}"))?
@@ -72,7 +89,9 @@ pub fn download_cicids() -> Result<PathBuf> {
std::io::Write::write_all(&mut file, &bytes)?;
eprintln!(" saved: {}", path.display());
Ok(path)
}
Ok(dir)
}
// --- CSIC 2010 ---
@@ -80,7 +99,10 @@ pub fn download_cicids() -> Result<PathBuf> {
/// Download CSIC 2010 dataset files to cache (delegates to scanner::csic).
pub fn download_csic() -> Result<()> {
if crate::scanner::csic::csic_is_cached() {
eprintln!(" cached: {}", crate::scanner::csic::csic_cache_path().display());
eprintln!(
" cached: {}",
crate::scanner::csic::csic_cache_path().display()
);
return Ok(());
}
// fetch_csic_dataset downloads, caches, and parses — we only need the download side-effect.
@@ -96,9 +118,9 @@ pub fn download_all() -> Result<()> {
download_csic()?;
eprintln!();
eprintln!("[2/2] CIC-IDS2017 DDoS timing profiles");
eprintln!("[2/2] CIC-IDS2017 (all attack days + normal baselines)");
let path = download_cicids()?;
eprintln!(" ok: {}\n", path.display());
eprintln!(" ok: {} ({} files)\n", path.display(), CICIDS_FILES.len());
eprintln!("all datasets cached.");
Ok(())
@@ -116,4 +138,10 @@ mod tests {
let cicids = cicids_cache_dir();
assert!(cicids.to_str().unwrap().contains("cicids"));
}
#[test]
fn test_all_files_listed() {
assert_eq!(CICIDS_FILES.len(), 8);
assert!(CICIDS_FILES.iter().all(|f| f.ends_with(".csv")));
}
}

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Parser for OWASP ModSecurity audit log files (Serial / concurrent format).
//!
//! ModSecurity audit logs consist of multi-section entries delimited by boundary
@@ -176,6 +179,7 @@ fn transaction_to_audit_fields(
let content_length: u64 = get_header("content-length")
.and_then(|v| v.parse().ok())
.unwrap_or(0);
let accept = get_header("accept").filter(|a| a != "-" && !a.is_empty());
// Section F: response status
let status = sections
@@ -216,11 +220,13 @@ fn transaction_to_audit_fields(
duration_ms: 0,
content_length,
user_agent,
has_cookies: Some(has_cookies),
referer,
accept_language,
has_cookies,
referer: referer.unwrap_or_else(|| "-".to_string()),
accept_language: accept_language.unwrap_or_else(|| "-".to_string()),
accept: accept.unwrap_or_else(|| "-".to_string()),
backend: "-".to_string(),
label: Some(label.clone()),
..AuditFields::default()
};
Some((fields, label))
@@ -302,7 +308,7 @@ Content-Type: text/html
assert_eq!(attack_fields.client_ip, "192.168.1.100");
assert_eq!(attack_fields.user_agent, "curl/7.68.0");
assert_eq!(attack_fields.status, 403);
assert!(!attack_fields.has_cookies.unwrap_or(true));
assert!(!attack_fields.has_cookies);
// Second entry: normal (no rule match).
let (normal_fields, normal_label) = &results[1];
@@ -311,9 +317,9 @@ Content-Type: text/html
assert_eq!(normal_fields.path, "/index.html");
assert_eq!(normal_fields.client_ip, "10.0.0.50");
assert_eq!(normal_fields.status, 200);
assert!(normal_fields.has_cookies.unwrap_or(false));
assert!(normal_fields.referer.is_some());
assert!(normal_fields.accept_language.is_some());
assert!(normal_fields.has_cookies);
assert!(normal_fields.referer != "-");
assert!(normal_fields.accept_language != "-");
}
#[test]

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Dataset preparation orchestrator.
//!
//! Combines production logs, external datasets (CSIC, OWASP ModSec), and
@@ -30,6 +33,10 @@ pub struct PrepareDatasetArgs {
pub seed: u64,
/// Path to heuristics.toml for auto-labeling production logs.
pub heuristics: Option<String>,
/// Inject CSIC 2010 entries as labeled audit logs into production stream.
pub inject_csic: bool,
/// Inject OWASP ModSec entries as labeled audit logs (path to .log file).
pub inject_modsec: Option<String>,
}
impl Default for PrepareDatasetArgs {
@@ -41,6 +48,8 @@ impl Default for PrepareDatasetArgs {
output: "dataset.bin".to_string(),
seed: 42,
heuristics: None,
inject_csic: false,
inject_modsec: None,
}
}
}
@@ -71,15 +80,37 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> {
scanner_samples.extend(prod_scanner);
ddos_samples.extend(prod_ddos);
// --- 2. CSIC 2010 (scanner) ---
eprintln!("fetching CSIC 2010 dataset...");
// --- 2. Inject external datasets as labeled audit log entries ---
// These go through the same feature extraction as production logs,
// with ground-truth labels (no heuristic labeling needed).
if args.inject_csic {
eprintln!("injecting CSIC 2010 as labeled audit entries...");
let csic_entries = crate::scanner::csic::fetch_csic_dataset()?;
let csic_samples = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?;
eprintln!(" CSIC: {} scanner samples", csic_samples.len());
scanner_samples.extend(csic_samples);
let csic_scanner = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?;
eprintln!(" CSIC injected: {} scanner samples", csic_scanner.len());
scanner_samples.extend(csic_scanner);
}
// --- 3. OWASP ModSec (scanner) ---
if let Some(modsec_path) = &args.inject_modsec {
eprintln!("injecting ModSec audit log from {modsec_path}...");
let modsec_entries =
crate::dataset::modsec::parse_modsec_audit_log(Path::new(modsec_path))?;
let entries_with_host: Vec<(AuditFields, String)> = modsec_entries
.into_iter()
.map(|(fields, _label)| {
let host_prefix = fields.host.split('.').next().unwrap_or("").to_string();
(fields, host_prefix)
})
.collect();
let modsec_scanner =
entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?;
eprintln!(" ModSec injected: {} scanner samples", modsec_scanner.len());
scanner_samples.extend(modsec_scanner);
}
// --- 3. Legacy OWASP path (kept for backwards compat) ---
if let Some(owasp_path) = &args.owasp {
if args.inject_modsec.as_deref() != Some(owasp_path.as_str()) {
eprintln!("parsing OWASP ModSec audit log from {owasp_path}...");
let modsec_entries =
crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?;
@@ -95,11 +126,25 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> {
eprintln!(" OWASP: {} scanner samples", modsec_samples.len());
scanner_samples.extend(modsec_samples);
}
}
// --- 4. CIC-IDS2017 timing profiles (from cache if downloaded) ---
// --- 4. CIC-IDS2017 (direct DDoS samples + timing profiles for synthetic) ---
let cicids_profiles = if let Some(cached_path) = crate::dataset::download::cicids_cached_path()
{
eprintln!("extracting CIC-IDS2017 timing profiles from cache...");
// Direct conversion: CIC-IDS2017 flows → DDoS training samples
eprintln!("extracting CIC-IDS2017 DDoS samples from cache...");
let cicids_ddos = crate::dataset::cicids::extract_ddos_samples(&cached_path)?;
let attack_count = cicids_ddos.iter().filter(|s| s.label > 0.5).count();
eprintln!(
" CIC-IDS2017 direct: {} DDoS samples ({} attack, {} normal)",
cicids_ddos.len(),
attack_count,
cicids_ddos.len() - attack_count
);
ddos_samples.extend(cicids_ddos);
// Also extract timing profiles for synthetic generation
eprintln!("extracting CIC-IDS2017 timing profiles...");
let profiles = crate::dataset::cicids::extract_timing_profiles(&cached_path)?;
eprintln!(" extracted {} attack-type profiles", profiles.len());
profiles
@@ -112,10 +157,10 @@ pub fn run(args: PrepareDatasetArgs) -> Result<()> {
// --- 5. Synthetic data (both models, always generated) ---
eprintln!("generating synthetic samples...");
let config = crate::dataset::synthetic::SyntheticConfig {
num_ddos_attack: 10000,
num_ddos_normal: 10000,
num_scanner_attack: 5000,
num_scanner_normal: 5000,
num_ddos_attack: 50000,
num_ddos_normal: 50000,
num_scanner_attack: 25000,
num_scanner_normal: 25000,
seed: args.seed,
};
@@ -240,17 +285,9 @@ fn parse_production_logs(
// --- Scanner samples from production logs ---
for (fields, host_prefix) in &parsed_entries {
let has_cookies = fields.has_cookies.unwrap_or(false);
let has_referer = fields
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = fields
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let has_cookies = fields.has_cookies;
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
let feats = features::extract_features(
&fields.method,
@@ -259,7 +296,7 @@ fn parse_production_logs(
has_cookies,
has_referer,
has_accept_language,
"-",
&fields.accept,
&fields.user_agent,
fields.content_length,
&fragment_hashes,
@@ -352,20 +389,12 @@ fn extract_ddos_samples_from_entries(
.push(fields.content_length.min(u32::MAX as u64) as u32);
state
.has_cookies
.push(fields.has_cookies.unwrap_or(false));
.push(fields.has_cookies);
state.has_referer.push(
fields
.referer
.as_deref()
.map(|r| r != "-")
.unwrap_or(false),
!fields.referer.is_empty() && fields.referer != "-",
);
state.has_accept_language.push(
fields
.accept_language
.as_deref()
.map(|a| a != "-")
.unwrap_or(false),
!fields.accept_language.is_empty() && fields.accept_language != "-",
);
state.suspicious_paths.push(
crate::ddos::features::is_suspicious_path(&fields.path),
@@ -462,17 +491,9 @@ fn entries_to_scanner_samples(
let mut samples = Vec::new();
for (fields, host_prefix) in entries {
let has_cookies = fields.has_cookies.unwrap_or(false);
let has_referer = fields
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = fields
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let has_cookies = fields.has_cookies;
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
let feats = features::extract_features(
&fields.method,
@@ -481,7 +502,7 @@ fn entries_to_scanner_samples(
has_cookies,
has_referer,
has_accept_language,
"-",
&fields.accept,
&fields.user_agent,
fields.content_length,
&fragment_hashes,
@@ -587,11 +608,12 @@ mod tests {
duration_ms: 10,
content_length: 0,
user_agent: "Mozilla/5.0".to_string(),
has_cookies: Some(true),
referer: Some("https://test.sunbeam.pt".to_string()),
accept_language: Some("en-US".to_string()),
has_cookies: true,
referer: "https://test.sunbeam.pt".to_string(),
accept_language: "en-US".to_string(),
backend: "test-svc:8080".to_string(),
label: Some(label.to_string()),
..AuditFields::default()
};
(fields, "test".to_string())
}

View File

@@ -1,83 +1,11 @@
use serde::Deserialize;
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
#[derive(Deserialize)]
pub struct AuditLog {
pub timestamp: String,
pub fields: AuditFields,
}
//! Re-exports from `crate::audit` — the canonical audit log definition.
//!
//! All new code should `use crate::audit::*` directly.
#[derive(Deserialize)]
pub struct AuditFields {
pub method: String,
pub host: String,
pub path: String,
pub client_ip: String,
#[serde(deserialize_with = "flexible_u16")]
pub status: u16,
#[serde(deserialize_with = "flexible_u64")]
pub duration_ms: u64,
#[serde(default)]
pub backend: String,
#[serde(default)]
pub content_length: u64,
#[serde(default = "default_ua")]
pub user_agent: String,
#[serde(default)]
pub query: String,
#[serde(default)]
pub has_cookies: Option<bool>,
#[serde(default)]
pub referer: Option<String>,
#[serde(default)]
pub accept_language: Option<String>,
/// Optional ground-truth label from external datasets (e.g. CSIC 2010).
/// Values: "attack", "normal". When present, trainers should use this
/// instead of heuristic labeling.
#[serde(default)]
pub label: Option<String>,
}
fn default_ua() -> String {
"-".to_string()
}
pub fn flexible_u64<'de, D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<u64, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
Num(u64),
Str(String),
}
match StringOrNum::deserialize(deserializer)? {
StringOrNum::Num(n) => Ok(n),
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
}
}
pub fn flexible_u16<'de, D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<u16, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNum {
Num(u16),
Str(String),
}
match StringOrNum::deserialize(deserializer)? {
StringOrNum::Num(n) => Ok(n),
StringOrNum::Str(s) => s.parse().map_err(serde::de::Error::custom),
}
}
/// Strip the port suffix from a socket address string.
pub fn strip_port(addr: &str) -> &str {
if addr.starts_with('[') {
addr.find(']').map(|i| &addr[1..i]).unwrap_or(addr)
} else if let Some(pos) = addr.rfind(':') {
&addr[..pos]
} else {
addr
}
}
pub use crate::audit::strip_port;
pub use crate::audit::AuditFields;
pub use crate::audit::AuditLogLine as AuditLog;
pub use crate::audit::{flexible_u16, flexible_u64};

View File

@@ -1,6 +1,9 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use crate::config::DDoSConfig;
use crate::ddos::features::{method_to_u8, IpState, RequestEvent};
use crate::ddos::model::{DDoSAction, TrainedModel};
use crate::ddos::model::DDoSAction;
use rustc_hash::FxHashMap;
use std::hash::{Hash, Hasher};
use std::net::IpAddr;
@@ -10,12 +13,10 @@ use std::time::Instant;
const NUM_SHARDS: usize = 256;
pub struct DDoSDetector {
model: TrainedModel,
shards: Vec<RwLock<FxHashMap<IpAddr, IpState>>>,
window_secs: u64,
window_capacity: usize,
min_events: usize,
use_ensemble: bool,
}
fn shard_index(ip: &IpAddr) -> usize {
@@ -25,34 +26,15 @@ fn shard_index(ip: &IpAddr) -> usize {
}
impl DDoSDetector {
pub fn new(model: TrainedModel, config: &DDoSConfig) -> Self {
pub fn new(config: &DDoSConfig) -> Self {
let shards = (0..NUM_SHARDS)
.map(|_| RwLock::new(FxHashMap::default()))
.collect();
Self {
model,
shards,
window_secs: config.window_secs,
window_capacity: config.window_capacity,
min_events: config.min_events,
use_ensemble: false,
}
}
/// Create a detector that uses the ensemble (decision tree + MLP) path.
/// A dummy model is still needed for fallback, but ensemble inference
/// takes priority when `use_ensemble` is true.
pub fn new_ensemble(model: TrainedModel, config: &DDoSConfig) -> Self {
let shards = (0..NUM_SHARDS)
.map(|_| RwLock::new(FxHashMap::default()))
.collect();
Self {
model,
shards,
window_secs: config.window_secs,
window_capacity: config.window_capacity,
min_events: config.min_events,
use_ensemble: true,
}
}
@@ -99,7 +81,6 @@ impl DDoSDetector {
let features = state.extract_features(self.window_secs);
if self.use_ensemble {
// Cast f64 features to f32 array for ensemble inference.
let mut f32_features = [0.0f32; 14];
for (i, &v) in features.iter().enumerate().take(14) {
@@ -113,10 +94,7 @@ impl DDoSDetector {
crate::ensemble::ddos::DDoSEnsemblePath::Mlp => "mlp",
}])
.inc();
return ev.action;
}
self.model.classify(&features)
ev.action
}
/// Feed response data back into the IP's event history.
@@ -125,10 +103,6 @@ impl DDoSDetector {
// Status/duration from check() are 0-initialized; the next request
// will have fresh data. This is intentionally a no-op for now.
}
pub fn point_count(&self) -> usize {
self.model.point_count()
}
}
fn fx_hash(s: &str) -> u64 {

View File

@@ -1,6 +1,8 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
pub mod audit_log;
pub mod detector;
pub mod features;
pub mod model;
pub mod replay;
pub mod train;

View File

@@ -1,183 +1,8 @@
use crate::ddos::features::{FeatureVector, NormParams, NUM_FEATURES};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrafficLabel {
Normal,
Attack,
}
#[derive(Serialize, Deserialize)]
pub struct SerializedModel {
pub points: Vec<FeatureVector>,
pub labels: Vec<TrafficLabel>,
pub norm_params: NormParams,
pub k: usize,
pub threshold: f64,
}
pub struct TrainedModel {
/// Stored points (normalized). The kD-tree borrows these.
points: Vec<[f64; NUM_FEATURES]>,
labels: Vec<TrafficLabel>,
norm_params: NormParams,
k: usize,
threshold: f64,
}
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DDoSAction {
Allow,
Block,
}
impl TrainedModel {
pub fn load(path: &Path, k_override: Option<usize>, threshold_override: Option<f64>) -> Result<Self> {
let data = std::fs::read(path)
.with_context(|| format!("reading model from {}", path.display()))?;
let model: SerializedModel =
bincode::deserialize(&data).context("deserializing model")?;
Ok(Self {
points: model.points,
labels: model.labels,
norm_params: model.norm_params,
k: k_override.unwrap_or(model.k),
threshold: threshold_override.unwrap_or(model.threshold),
})
}
/// Create an empty model (no training points). Used when the ensemble
/// path is active and the KNN model is not needed.
pub fn empty(k: usize, threshold: f64) -> Self {
Self {
points: vec![],
labels: vec![],
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k,
threshold,
}
}
pub fn from_serialized(model: SerializedModel) -> Self {
Self {
points: model.points,
labels: model.labels,
norm_params: model.norm_params,
k: model.k,
threshold: model.threshold,
}
}
pub fn classify(&self, features: &FeatureVector) -> DDoSAction {
let normalized = self.norm_params.normalize(features);
if self.points.is_empty() {
return DDoSAction::Allow;
}
// Build tree on-the-fly for query. In production with many queries,
// we'd cache this, but the tree build is fast for <100K points.
// fnntw::Tree borrows data, so we build it here.
let tree = match fnntw::Tree::<'_, f64, NUM_FEATURES>::new(&self.points, 32) {
Ok(t) => t,
Err(_) => return DDoSAction::Allow,
};
let k = self.k.min(self.points.len());
let result = tree.query_nearest_k(&normalized, k);
match result {
Ok((_distances, indices)) => {
let attack_count = indices
.iter()
.filter(|&&idx| self.labels[idx as usize] == TrafficLabel::Attack)
.count();
let attack_frac = attack_count as f64 / k as f64;
if attack_frac >= self.threshold {
DDoSAction::Block
} else {
DDoSAction::Allow
}
}
Err(_) => DDoSAction::Allow,
}
}
pub fn norm_params(&self) -> &NormParams {
&self.norm_params
}
pub fn point_count(&self) -> usize {
self.points.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_classify_empty_model() {
let model = TrainedModel {
points: vec![],
labels: vec![],
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
};
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
}
fn make_test_points(n: usize) -> Vec<FeatureVector> {
(0..n)
.map(|i| {
let mut v = [0.0; NUM_FEATURES];
for d in 0..NUM_FEATURES {
v[d] = ((i * (d + 1)) as f64 / n as f64) % 1.0;
}
v
})
.collect()
}
#[test]
fn test_classify_all_attack() {
let points = make_test_points(100);
let labels = vec![TrafficLabel::Attack; 100];
let model = TrainedModel {
points,
labels,
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
};
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Block);
}
#[test]
fn test_classify_all_normal() {
let points = make_test_points(100);
let labels = vec![TrafficLabel::Normal; 100];
let model = TrainedModel {
points,
labels,
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
};
assert_eq!(model.classify(&[0.5; NUM_FEATURES]), DDoSAction::Allow);
}
}

View File

@@ -1,281 +0,0 @@
use crate::config::{DDoSConfig, RateLimitConfig};
use crate::ddos::audit_log::{self, AuditLog};
use crate::ddos::detector::DDoSDetector;
use crate::ddos::model::{DDoSAction, TrainedModel};
use crate::rate_limit::key::RateLimitKey;
use crate::rate_limit::limiter::{RateLimitResult, RateLimiter};
use anyhow::{Context, Result};
use rustc_hash::FxHashMap;
use std::io::BufRead;
use std::net::IpAddr;
use std::sync::Arc;
pub struct ReplayArgs {
pub input: String,
pub model_path: String,
pub config_path: Option<String>,
pub k: usize,
pub threshold: f64,
pub window_secs: u64,
pub min_events: usize,
pub rate_limit: bool,
}
pub struct ReplayResult {
pub total: u64,
pub skipped: u64,
pub ddos_blocked: u64,
pub rate_limited: u64,
pub allowed: u64,
pub ddos_blocked_ips: FxHashMap<String, u64>,
pub rate_limited_ips: FxHashMap<String, u64>,
pub false_positive_ips: usize,
pub true_positive_ips: usize,
}
/// Core replay pipeline: load model, replay logs, compute stats including false positive analysis.
pub fn replay_and_evaluate(args: &ReplayArgs) -> Result<ReplayResult> {
let model = TrainedModel::load(
std::path::Path::new(&args.model_path),
Some(args.k),
Some(args.threshold),
)
.with_context(|| format!("loading model from {}", args.model_path))?;
let ddos_cfg = DDoSConfig {
model_path: Some(args.model_path.clone()),
k: args.k,
threshold: args.threshold,
window_secs: args.window_secs,
window_capacity: 1000,
min_events: args.min_events,
enabled: true,
use_ensemble: false,
};
let detector = Arc::new(DDoSDetector::new(model, &ddos_cfg));
let rate_limiter = if args.rate_limit {
let rl_cfg = if let Some(cfg_path) = &args.config_path {
let cfg = crate::config::Config::load(cfg_path)?;
cfg.rate_limit.unwrap_or_else(default_rate_limit_config)
} else {
default_rate_limit_config()
};
Some(RateLimiter::new(&rl_cfg))
} else {
None
};
let file = std::fs::File::open(&args.input)
.with_context(|| format!("opening {}", args.input))?;
let reader = std::io::BufReader::new(file);
let mut total = 0u64;
let mut skipped = 0u64;
let mut ddos_blocked = 0u64;
let mut rate_limited = 0u64;
let mut allowed = 0u64;
let mut ddos_blocked_ips: FxHashMap<String, u64> = FxHashMap::default();
let mut rate_limited_ips: FxHashMap<String, u64> = FxHashMap::default();
for line in reader.lines() {
let line = line?;
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => {
skipped += 1;
continue;
}
};
if entry.fields.method.is_empty() {
skipped += 1;
continue;
}
total += 1;
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
let ip: IpAddr = match ip_str.parse() {
Ok(ip) => ip,
Err(_) => {
skipped += 1;
continue;
}
};
let has_cookies = entry.fields.has_cookies.unwrap_or(false);
let has_referer = entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false);
let has_accept_language = entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false);
let ddos_action = detector.check(
ip,
&entry.fields.method,
&entry.fields.path,
&entry.fields.host,
&entry.fields.user_agent,
entry.fields.content_length,
has_cookies,
has_referer,
has_accept_language,
);
if ddos_action == DDoSAction::Block {
ddos_blocked += 1;
*ddos_blocked_ips.entry(ip_str.clone()).or_insert(0) += 1;
continue;
}
if let Some(limiter) = &rate_limiter {
let rl_key = RateLimitKey::Ip(ip);
if let RateLimitResult::Reject { .. } = limiter.check(ip, rl_key) {
rate_limited += 1;
*rate_limited_ips.entry(ip_str.clone()).or_insert(0) += 1;
continue;
}
}
allowed += 1;
}
// Compute false positive / true positive counts
let (false_positive_ips, true_positive_ips) =
count_fp_tp(&args.input, &ddos_blocked_ips, &rate_limited_ips)?;
Ok(ReplayResult {
total,
skipped,
ddos_blocked,
rate_limited,
allowed,
ddos_blocked_ips,
rate_limited_ips,
false_positive_ips,
true_positive_ips,
})
}
/// Count false-positive and true-positive IPs from blocked set.
/// FP = blocked IP where >60% of original responses were 2xx/3xx.
fn count_fp_tp(
input: &str,
ddos_blocked_ips: &FxHashMap<String, u64>,
rate_limited_ips: &FxHashMap<String, u64>,
) -> Result<(usize, usize)> {
let blocked_ips: rustc_hash::FxHashSet<&str> = ddos_blocked_ips
.keys()
.chain(rate_limited_ips.keys())
.map(|s| s.as_str())
.collect();
if blocked_ips.is_empty() {
return Ok((0, 0));
}
let file = std::fs::File::open(input)?;
let reader = std::io::BufReader::new(file);
let mut ip_statuses: FxHashMap<String, Vec<u16>> = FxHashMap::default();
for line in reader.lines() {
let line = line?;
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => continue,
};
let ip_str = audit_log::strip_port(&entry.fields.client_ip).to_string();
if blocked_ips.contains(ip_str.as_str()) {
ip_statuses.entry(ip_str).or_default().push(entry.fields.status);
}
}
let mut fp = 0usize;
let mut tp = 0usize;
for (_ip, statuses) in &ip_statuses {
let total = statuses.len();
let ok_count = statuses.iter().filter(|&&s| (200..400).contains(&s)).count();
let ok_pct = ok_count as f64 / total as f64 * 100.0;
if ok_pct > 60.0 {
fp += 1;
} else {
tp += 1;
}
}
Ok((fp, tp))
}
pub fn run(args: ReplayArgs) -> Result<()> {
eprintln!("Loading model from {}...", args.model_path);
eprintln!("Replaying {}...\n", args.input);
let result = replay_and_evaluate(&args)?;
let total = result.total;
eprintln!("═══ Replay Results ═══════════════════════════════════════");
eprintln!(" Total requests: {total}");
eprintln!(" Skipped (parse): {}", result.skipped);
eprintln!(" Allowed: {} ({:.1}%)", result.allowed, pct(result.allowed, total));
eprintln!(" DDoS blocked: {} ({:.1}%)", result.ddos_blocked, pct(result.ddos_blocked, total));
if args.rate_limit {
eprintln!(" Rate limited: {} ({:.1}%)", result.rate_limited, pct(result.rate_limited, total));
}
if !result.ddos_blocked_ips.is_empty() {
eprintln!("\n── DDoS-blocked IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = result.ddos_blocked_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs blocked", ip, count);
}
}
if !result.rate_limited_ips.is_empty() {
eprintln!("\n── Rate-limited IPs (top 20) ─────────────────────────────");
let mut sorted: Vec<_> = result.rate_limited_ips.iter().collect();
sorted.sort_by(|a, b| b.1.cmp(a.1));
for (ip, count) in sorted.iter().take(20) {
eprintln!(" {:<40} {} reqs limited", ip, count);
}
}
eprintln!("\n── False positive check ──────────────────────────────────");
if result.false_positive_ips == 0 {
eprintln!(" No likely false positives found.");
} else {
eprintln!("{} IPs were blocked but had mostly successful responses", result.false_positive_ips);
}
eprintln!(" True positive IPs: {}", result.true_positive_ips);
eprintln!("══════════════════════════════════════════════════════════");
Ok(())
}
fn default_rate_limit_config() -> RateLimitConfig {
RateLimitConfig {
enabled: true,
bypass_cidrs: vec![
"10.0.0.0/8".into(),
"172.16.0.0/12".into(),
"192.168.0.0/16".into(),
"100.64.0.0/10".into(),
"fd00::/8".into(),
],
eviction_interval_secs: 300,
stale_after_secs: 600,
authenticated: crate::config::BucketConfig {
burst: 200,
rate: 50.0,
},
unauthenticated: crate::config::BucketConfig {
burst: 60,
rate: 15.0,
},
}
}
fn pct(n: u64, total: u64) -> f64 {
if total == 0 {
0.0
} else {
(n as f64 / total as f64) * 100.0
}
}

View File

@@ -1,13 +1,32 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use crate::ddos::audit_log::AuditLog;
use crate::ddos::audit_log;
use crate::ddos::features::{method_to_u8, FeatureVector, LogIpState, NormParams, NUM_FEATURES};
use crate::ddos::model::{SerializedModel, TrafficLabel};
use anyhow::{bail, Context, Result};
use rustc_hash::{FxHashMap, FxHashSet};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};
use std::io::BufRead;
/// Legacy KNN training types — kept for the `train-ddos` CLI command
/// which produces bincode model files for offline evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrafficLabel {
Normal,
Attack,
}
#[derive(Serialize, Deserialize)]
pub struct SerializedModel {
pub points: Vec<FeatureVector>,
pub labels: Vec<TrafficLabel>,
pub norm_params: NormParams,
pub k: usize,
pub threshold: f64,
}
#[derive(Deserialize)]
pub struct HeuristicThresholds {
/// Requests/second above which an IP is labeled attack
@@ -255,12 +274,12 @@ pub fn parse_logs(input: &str) -> Result<FxHashMap<String, LogIpState>> {
state.statuses.push(entry.fields.status);
state.durations.push(entry.fields.duration_ms.min(u32::MAX as u64) as u32);
state.content_lengths.push(entry.fields.content_length.min(u32::MAX as u64) as u32);
state.has_cookies.push(entry.fields.has_cookies.unwrap_or(false));
state.has_cookies.push(entry.fields.has_cookies);
state.has_referer.push(
entry.fields.referer.as_deref().map(|r| r != "-").unwrap_or(false),
!entry.fields.referer.is_empty() && entry.fields.referer != "-",
);
state.has_accept_language.push(
entry.fields.accept_language.as_deref().map(|a| a != "-").unwrap_or(false),
!entry.fields.accept_language.is_empty() && entry.fields.accept_language != "-",
);
state.suspicious_paths.push(
crate::ddos::features::is_suspicious_path(&entry.fields.path),

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use crate::ddos::model::DDoSAction;
use super::gen::ddos_weights;
use super::mlp::mlp_predict_32;
@@ -80,59 +83,46 @@ mod tests {
use super::*;
#[test]
fn test_tree_allow_path() {
// All zeros → feature 4 (request_rate) = 0.0 <= 0.70 → left (node 1)
// feature 10 (cookie_ratio) = 0.0 <= 0.30 → left (node 3) → Allow
fn test_tree_block_path() {
// Tree: root splits on feature 10 (cookie_ratio) at 0.14.
// All zeros → cookie_ratio normalized = 0.0 <= 0.14 → Block (node 1)
let raw = [0.0f32; 14];
let v = ddos_ensemble_predict(&raw);
assert_eq!(v.action, DDoSAction::Block);
assert_eq!(v.path, DDoSEnsemblePath::TreeBlock);
}
#[test]
fn test_tree_allow_path() {
// Tree: feature 10 (cookie_ratio) > 0.14 → node 2 (Allow leaf)
// feature 10 range [0, 1], raw 0.5 → normalized 0.5 > 0.14 → Allow
let mut raw = [0.0f32; 14];
raw[10] = 0.5;
let v = ddos_ensemble_predict(&raw);
assert_eq!(v.action, DDoSAction::Allow);
assert_eq!(v.path, DDoSEnsemblePath::TreeAllow);
assert_eq!(v.reason, "ensemble:tree_allow");
}
#[test]
fn test_tree_block_path() {
// Need: feature 4 (request_rate) > 0.70 normalized → right (node 2)
// feature 12 (accept_language_ratio) > 0.25 normalized → right (node 6) → Block
// feature 4 max = 500, so raw 400 → normalized 0.8 > 0.70 ✓
// feature 12 max = 1.0, so raw 0.5 → normalized 0.5 > 0.25 ✓
let mut raw = [0.0f32; 14];
raw[4] = 400.0;
raw[12] = 0.5;
let v = ddos_ensemble_predict(&raw);
assert_eq!(v.action, DDoSAction::Block);
assert_eq!(v.path, DDoSEnsemblePath::TreeBlock);
}
#[test]
fn test_mlp_path() {
// Need: feature 4 > 0.70 normalized → right (node 2)
// feature 12 <= 0.25 normalized → left (node 5) → Defer
// feature 4 max = 500, raw 400 → 0.8 > 0.70 ✓
// feature 12 max = 1.0, raw 0.1 → 0.1 <= 0.25 ✓
let mut raw = [0.0f32; 14];
raw[4] = 400.0;
raw[12] = 0.1;
let v = ddos_ensemble_predict(&raw);
assert_eq!(v.path, DDoSEnsemblePath::Mlp);
assert_eq!(v.reason, "ensemble:mlp");
assert!(v.score >= 0.0 && v.score <= 1.0);
}
#[test]
fn test_defer_then_mlp_allow() {
// Same Defer path as above — verify the MLP produces a valid action
let mut raw = [0.0f32; 14];
raw[4] = 400.0;
raw[12] = 0.1;
let v = ddos_ensemble_predict(&raw);
assert!(matches!(v.action, DDoSAction::Allow | DDoSAction::Block));
fn test_mlp_direct() {
// Current tree has no Defer leaves, so test MLP inference directly.
let input = [0.5f32; 14];
let score = mlp_predict_32::<14>(
&ddos_weights::W1,
&ddos_weights::B1,
&ddos_weights::W2,
ddos_weights::B2,
&input,
);
assert!(score >= 0.0 && score <= 1.0);
}
#[test]
fn test_normalize_clamps_high() {
// feature 0 max = 10000.0, raw 999999 → clamped to 1.0
let mut raw = [0.0f32; 14];
raw[0] = 999.0; // max is 100
raw[0] = 999999.0;
let normed = normalize(&raw);
assert!((normed[0] - 1.0).abs() < f32::EPSILON);
}
@@ -140,7 +130,7 @@ mod tests {
#[test]
fn test_normalize_clamps_low() {
let mut raw = [0.0f32; 14];
raw[1] = -500.0; // min is 0
raw[1] = -500.0; // min is 1.0
let normed = normalize(&raw);
assert!((normed[1] - 0.0).abs() < f32::EPSILON);
}

View File

@@ -1,71 +1,74 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Auto-generated weights for the ddos ensemble.
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-ddos-mlp`.
pub const THRESHOLD: f32 = 0.50000000;
pub const NORM_MINS: [f32; 14] = [
0.08778746, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.05001374, 0.02000000,
0.00000000, 1.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00833333, 0.02000000,
0.00000000, 1.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000
];
pub const NORM_MAXS: [f32; 14] = [
1000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49990082, 500.00000000, 1.00000000,
10000.00000000, 50.00000000, 19.00000000, 1.00000000, 7589.80468750, 1.49999166, 500.00000000, 1.00000000,
171240.28125000, 30.00000000, 1.00000000, 1.00000000, 1.00000000, 1.00000000
];
pub const W1: [[f32; 14]; 32] = [
[0.57458097, -0.10861993, 0.14037465, -0.23486336, -0.43255216, 0.16347405, 0.71766937, 0.83138502, 0.02852129, 0.56590265, 0.54848498, 0.38098580, 0.82907754, 0.61539698],
[0.06583579, 0.02305713, 0.89706898, 0.42619053, -1.20866120, -0.11974730, 1.70674825, -0.17969023, -0.26867196, 0.60768014, -0.08671998, -0.04825107, 0.58131427, -0.02062579],
[-0.40264100, 0.18836430, 0.08431315, 0.33763552, 0.44880620, -0.40894085, -0.22044741, 0.00533387, -0.61574107, 0.07670992, 0.63528854, 0.48244709, 0.20411402, -1.80697525],
[0.66713083, -0.22220801, 2.11234117, 0.41516641, -0.00165093, 0.65624571, 1.87509167, 0.63406783, -2.54182458, -0.53618753, 2.16407824, -0.61959583, -0.04717547, 0.17551991],
[-0.51027024, -0.60132700, 0.46407551, -0.57346475, -0.30902353, -0.24235034, 0.08087540, -2.14762974, -0.29429656, 0.56257033, -0.26935315, -0.16799171, 0.56852734, 1.93494022],
[-0.24938971, -0.12699288, -0.13746630, 0.64942318, 0.09490766, -0.02158179, 0.72449303, -0.28493983, -0.43053114, -0.01443988, 0.89670080, -0.34539866, -1.47019410, 0.79477930],
[0.62935185, 0.74686801, -0.15527052, -0.06635039, 0.73137009, 0.78417069, -0.06417987, 0.72259408, 0.85131824, 0.00477386, -0.14302900, 0.63481224, 0.92724019, -0.50126070],
[-0.12699059, -0.15016419, -0.48704135, 0.00581611, 0.75824696, 0.84114397, -0.08958503, 0.18609463, 0.56247348, 0.22239330, 0.43324804, 0.82077771, 0.55714250, -0.56955606],
[0.83457869, 0.40054807, 0.23281574, -0.58521581, 1.18067443, -0.49485078, 0.08600014, 0.99104887, -0.65019566, -0.44594154, 0.64507920, -0.61692268, -0.29301512, -0.11314666],
[-0.07868081, -0.18392175, 0.15165123, 0.35139060, 0.13855398, 0.16470867, 0.21025884, 1.57204449, 0.07827333, 0.05895505, 0.00810917, 1.05159700, 0.04605416, 0.38080546],
[1.47428405, -0.21614535, -0.35385504, 0.46582970, 1.26638246, 0.00133375, -3.85603786, 0.39766011, 1.92816520, 0.47828305, -0.16951409, -0.13771342, 0.49451983, 0.41184473],
[-0.23364748, 0.68134952, 0.36342716, 0.02657196, 0.07550839, 0.94861823, -0.52908695, 0.83652318, -0.05639480, 0.26536962, 0.44137934, 1.20957208, -0.60747981, -0.50647283],
[-0.16961956, -0.49570882, -0.33771378, -0.28554109, 0.95865113, -0.49269623, -0.44559151, 1.28568971, 0.79537493, -0.53175420, -3.19015551, 0.52214253, 0.86517984, 0.62523192],
[-0.16956513, -0.61727583, 0.63967121, 0.96406335, -0.28760204, 0.56459671, 0.78585202, -0.03668134, -0.14773002, -0.35764447, 0.84649116, -0.34540027, -0.12314465, -0.10070048],
[-0.34183556, -0.07760386, 0.70894319, 0.92814171, -0.19357866, 0.41449037, 0.54653358, 0.27682835, 0.81471086, 0.56383932, 0.57456553, -0.61491662, 0.92498505, 0.74495614],
[-0.38917324, -0.29217750, 1.43508542, -0.19152534, -0.18823336, 0.45097819, -0.38063127, -0.40419811, 0.56686693, -0.33231607, -0.19567636, -0.02500075, -0.04762971, 0.44703853],
[1.14234805, -0.62868208, -0.21298689, 0.00263968, -0.66115338, -1.12038326, 0.93599045, 0.77646011, -0.22770278, 1.43982041, 0.96078646, 1.15076077, -0.45110813, 0.83090556],
[0.89638984, -0.69683450, -0.29400119, 0.94997799, 0.90305328, -0.80215877, -0.09983492, -0.90757453, -0.03181892, 1.00702441, -0.97962254, -0.89580274, 0.69299418, -0.75975400],
[-0.75832003, -0.07210776, 0.07825917, 1.51633596, 0.44593197, 0.00936707, -0.12142835, -0.09877282, 0.06229200, 1.25678349, 0.25317946, 0.54112315, -0.17941843, 0.93283361],
[0.23085761, 0.53307736, 0.38696140, 0.36798462, 0.38192499, 0.23203450, 0.68225187, 0.47096270, -4.24785280, 0.18062039, 0.60047084, 0.16251479, -0.10811257, 0.48166662],
[0.10870802, 0.01576116, 0.00298645, 0.25878090, -0.16634797, 0.15850464, -0.24267951, 0.87678236, -0.27257833, 0.78637868, -0.00851476, 0.01502728, 0.92175138, -0.81292266],
[-0.74364990, -0.63139439, -0.18314177, -0.36881343, -0.53096825, -0.92442876, -0.05536628, -0.71273297, -0.94937468, -0.03863344, -0.09668982, -1.07886386, 0.58555382, 0.23351164],
[-0.09152136, 0.96538877, -0.11560653, -0.53110164, 0.89070886, 0.05664408, -0.71661353, 0.79684436, -0.00206013, 0.23857179, 0.06074178, -0.67188424, -0.15624331, 0.43436247],
[-0.28189376, -0.00535834, 0.60541785, 0.82968009, -0.21901314, -0.29874969, -0.16872653, 0.45570841, -0.25372767, -0.12359514, -1.10104620, 0.00162374, 0.07622175, 0.60413152],
[-1.13819373, -0.41320390, -5.57348347, 0.40931624, -1.59562767, 0.72510892, 0.03248254, 0.00407641, 0.57557869, 0.53510398, -0.35943517, 0.52707136, 0.61220711, -0.11644226],
[-0.02057049, 0.42545527, 0.24192038, 0.29863021, -0.22839858, -0.25318733, 0.17906551, -0.29471490, -0.04746799, 0.15909556, -0.26826856, -0.06874973, -0.03044286, 0.11770450],
[-0.18060833, -0.06301155, 0.01656315, -0.40476608, -0.35056075, 0.06344713, 0.32273614, -0.04382812, -0.18925793, 0.02124963, -0.23447622, 0.29704437, 0.19138981, -0.04584064],
[0.18248987, 0.05461208, -0.25655189, 0.16673982, 0.03251073, 0.05709980, 0.09135589, 0.06712578, -0.02372392, 0.00487196, -0.11774579, 0.34203079, 0.18477952, 0.09847298],
[-0.08292723, -0.03089223, 0.19555064, -0.18158682, -0.32060555, 0.18836822, -0.14625609, -0.83500093, -0.09893667, 0.02719803, 0.06864946, 0.00156752, 0.04342323, 0.30958080],
[-0.21274266, 0.06035644, 0.27282441, -0.01010289, -0.05599894, 0.27938741, -0.23254848, -0.20086342, -0.06775926, -0.18059292, 0.92534143, 0.09500337, 0.11612320, -0.06473339],
[-0.27279299, 0.96252358, 0.67542273, 0.64720130, 0.15221471, 1.67354584, 0.53074431, 0.65513390, 0.79840666, 0.78613347, 0.34742561, -1.83272552, 0.73313516, 0.09797212],
[-0.08888317, 0.14851266, 1.00953877, 0.19915256, -0.10076691, 0.47210938, 0.04427565, 0.19299655, 0.58729172, 0.17481442, -0.57466495, -0.16196120, 0.06293163, 1.73905540],
[0.30719012, 0.54619288, -0.20061351, 0.29659060, 0.34138879, -0.19088192, 0.34866381, -0.24303232, 0.20615512, 0.12656690, -0.16653502, 0.10961400, -0.16814700, 0.09950374],
[0.46002641, 0.49129087, -0.01386960, 0.64490628, 0.51850092, 0.69266915, 0.31095454, 0.26951542, -0.20926359, -0.09662568, 0.19281134, 0.36575633, 0.23089127, -0.03582983],
[0.04547556, -0.04854088, -0.10979871, -0.18705507, -0.44649851, 0.20949614, 0.73240960, 1.34691823, -0.13529004, 0.69439852, -0.40027520, 0.47921708, -0.43529814, 0.48781869],
[0.61547452, 1.52229679, 0.48276836, 1.27171433, -0.36176509, -0.33192506, -0.82673991, 0.67331636, -0.21094124, 1.03067887, -0.09182073, 1.44520211, 0.52611661, 0.61176163],
[-0.37162983, 0.48245564, -0.53393066, -0.20009390, -0.06583384, 0.17612432, 0.59905756, 1.39533114, 0.67457062, 0.06159161, -0.56609136, -0.29591814, 0.55239469, -0.56801152],
[0.75815558, -0.64557153, 0.84678394, 0.41179815, -0.50619060, -0.09139232, -0.64594650, -0.74464273, 0.87102652, 0.81111395, -0.35027400, 0.95135874, 0.85043454, 1.72117484],
[0.98888069, -0.04631047, -0.62931997, -0.37154037, 1.02817857, -1.04121590, 0.74848920, -0.26426360, 0.23142239, -0.17234743, -0.61689568, -0.59363395, -0.85756373, 0.53024006],
[0.59509462, 0.26622522, -0.74383926, 0.48256168, -0.75522244, 0.46806136, -0.62194610, 0.09251838, 0.25921744, 0.72987258, 0.66349596, 0.53999704, -0.25535119, -0.92465514],
[0.50981742, -1.44853806, -0.64814043, -0.78203505, -0.88038790, 0.68278509, 0.58861315, 0.55924416, -0.52396554, 0.45195666, -0.44876143, -0.11349974, 0.64508075, 0.06376592],
[0.44494510, 0.79238343, -0.22128101, 3.13757062, 0.54972911, 2.06494117, -0.20301908, 0.48413971, -0.25992882, 0.77544886, -0.18115431, 1.87130582, 0.71965748, 1.95458603],
[1.00518668, 1.80238068, -0.28449696, -0.02740687, -2.51049113, 0.56081659, -0.43591678, 3.59169340, -0.47954431, 1.82556272, 0.64387941, 0.56122434, -0.19696619, 3.49070907],
[-0.32992145, -0.03573111, 2.41438532, -0.00748284, 0.62775159, 1.78909039, 0.25103322, 0.59640545, -0.10183074, 0.83787775, 0.14171274, 0.08816884, 0.16381627, -0.04427620],
[0.09841868, 0.58517164, 0.02630968, 0.65797943, -0.03991833, 0.52833039, 0.37459302, 0.01832970, -1.20483434, 0.76000416, 0.02081347, 1.10453236, 0.46800232, 0.50707549],
[0.13568293, -0.04429439, 0.18404786, 0.74804515, -0.02402807, 0.25729915, 0.64555109, 0.09644510, 0.31338552, 0.62685025, -0.19832127, 1.95116663, 0.66340035, 1.29182386],
[-0.20969683, 0.56657153, -0.08705560, 0.71007556, -0.11011623, 1.16174579, 0.65050489, 1.31441426, 0.72755563, 1.15947676, -0.34925875, 0.01019314, -1.42810500, 0.14942981],
[-0.47017330, 2.62149596, -0.37532449, 1.17488575, 0.62930888, 0.62195790, -0.12959687, 2.36229849, -0.25786853, 0.03494137, 1.70790768, -0.02720823, 0.57822198, 1.57692003],
[-0.68229634, 0.87380433, 1.03171849, 0.35238963, -0.78998542, 0.97562903, -0.80616480, 1.07170749, -0.79917014, -0.43357334, 1.09133816, 0.49446958, 1.07970095, 0.27838916],
[-0.77235895, -0.66010702, -0.09969614, -0.38052577, -0.77211934, -0.73416811, -0.67031443, 0.62016815, 0.97461295, 1.07167208, -0.68821293, 0.51563287, -0.73027885, -0.14203216],
[0.90449816, -0.23423387, 1.11039567, 0.61329746, -0.21385542, 0.52449727, 0.42514217, 0.42172486, -0.33397049, 0.35888657, -0.54074812, 0.48481938, -0.05116262, -0.23848286],
[0.67948169, 0.50562781, 0.45344570, 0.47307885, -0.44913152, -0.11515936, 0.14361705, -0.36479098, -0.32777452, 0.11798909, -0.57137913, 0.30936614, 0.31339252, 0.51131296],
[-0.25677630, 0.25580657, -0.12398625, 0.24844812, 0.18556698, 0.21818036, 0.58248550, 0.50517905, 0.34329867, 0.15851928, -0.58440667, 0.33611965, 0.67439252, -0.52770680],
[0.66840053, -0.49819222, 0.29022828, 0.10492916, -0.06216156, 0.37093312, -0.24731418, 0.22893915, 0.32447502, 0.63166237, -0.13788179, 0.52650315, 0.15229015, 0.23656118],
[0.33978519, 0.15498674, -0.25265032, -0.42916322, 0.69121236, 0.20443739, 0.54050952, 0.08900955, -0.13801514, 0.25456557, -0.10714018, 0.08712567, 0.27245566, -0.29683220],
[-0.05526243, -1.17294025, 0.07328646, -0.07892461, 0.31488195, -0.01112767, 0.55462092, 0.65152955, -0.10721418, -0.99451303, 0.00110284, -0.53097665, 0.14362922, 0.17380728],
[-2.95768332, 0.46451911, 0.20220210, 0.76858771, 0.13804838, -0.80371422, -0.11160404, -0.15160939, 0.31488597, -0.10203149, 0.16458754, -0.08558689, -0.27082649, 0.03877234],
[-0.14654562, -0.70086712, 0.09809728, -0.60966188, 0.26278028, 0.07354698, 0.08616283, 0.36018923, 0.07040872, 0.41008693, 0.13071685, 0.18236822, 0.43306109, -0.10742717],
[0.41488791, -0.10255218, 0.10218169, 0.21971215, -0.05527666, -0.50265622, 0.06767768, -0.09040122, 0.16871217, -0.02748547, 0.21738021, 0.21068999, 0.10562737, -0.71913630],
[0.09367306, -0.14113051, -0.44151428, -0.05189204, 0.22411002, -0.09538609, 0.17464676, -0.30709952, 0.21021855, -0.27705607, 0.17645715, 0.19070518, 0.18094100, 0.10115600],
[-0.11084171, -0.60070217, 0.10072551, -0.09865215, 0.19512057, -0.32474023, 0.14499906, 0.06266983, -0.15383074, 0.10347557, -0.10143858, -0.09821036, -0.19187087, -0.21955618],
[0.09774263, 0.21607652, -0.22068830, -0.73502982, 0.14551027, -0.00246539, -0.32017741, -0.14855191, 0.15684886, -0.21544383, -0.36595181, 1.57917106, 0.45341989, 0.64960009],
[0.03760023, 0.12075356, -0.24193284, 0.16418910, -0.13468136, 0.40612614, 0.44222566, 0.17999728, 0.37591749, 0.67439985, -0.29388478, -0.20486754, 0.20614263, -2.63525987],
[-0.10479005, -0.17017230, -0.42374054, 0.30094361, 0.28561834, 0.40433934, -0.03086211, 1.49869466, -0.41601327, 0.20835553, 1.19875181, -0.00222666, 0.51400107, 0.27829245],
];
pub const B1: [f32; 32] = [
-0.80723554, 0.54879200, 0.01237706, -0.22279924, 0.93692911, 0.12226531, -0.54665250, -0.49958101,
-0.20918398, -0.48646352, -0.58741039, -0.50572610, -0.04772990, -0.62962151, -0.46279392, 1.14840722,
-0.04871057, -0.31787100, 1.13966286, 0.69543558, -0.17798270, 0.66968435, -0.07442535, -0.70557600,
0.79021728, 0.65736526, -0.30761406, 0.63242179, 0.83297908, -0.04573143, -0.18454255, -0.30583009
0.76754266, -0.52365464, -0.07451479, -0.24194083, 0.81372803, -0.14967601, 0.86968440, -0.11282827,
0.82378083, 0.03708726, -0.14121835, -0.33332673, -0.24595253, -0.20005627, 0.80769247, 0.67842513,
0.62225562, 0.55104679, 0.87356585, -0.16369765, 0.83232063, -0.40881905, -0.02851989, -0.04714838,
0.69236869, -0.30938062, 0.87852216, -0.14689557, -0.52630597, -0.22946648, -0.13811214, -0.41019145
];
pub const W2: [f32; 32] = [
1.09615684, -0.57856798, -0.08730038, -0.06425755, -0.96232760, -2.06290460, 0.70097560, 0.85189444,
-0.10077959, 1.94375157, 0.74497795, 0.88425481, 2.11908054, 0.85526127, 0.61624259, -2.93621016,
1.52211487, 0.56318259, -3.15219641, -0.55187315, 1.61819077, -0.76258671, -0.09362544, 0.86861998,
-0.79028755, -0.90605170, 0.33475992, -0.79945564, -1.16680586, 0.15120529, 0.17619221, 1.61664009
-0.84622073, 2.32144451, 0.70330697, 0.89360833, -1.08053613, 0.69213301, -1.07218480, 0.82345659,
-1.11953294, -2.58824420, 0.81520051, 1.19865966, 0.91804677, 1.04554057, -1.03049874, -0.94034135,
-1.66193688, -1.53192282, -1.09629154, -4.07772017, -1.14778209, 1.15202129, 0.42650393, 0.55174673,
-1.28319669, 2.06129408, -1.10220599, 0.09728605, 1.64764512, -0.14975634, 0.79428691, 1.56726408
];
pub const B2: f32 = -0.52729088;
pub const B2: f32 = -0.58103424;
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
(3, 0.30015790, 1, 2),
(255, 0.00000000, 0, 0),
(10, 0.13999981, 1, 2),
(255, 1.00000000, 0, 0),
(255, 0.00000000, 0, 0),
];

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Auto-generated weights for the scanner ensemble.
//! DO NOT EDIT — regenerate with `cargo run --features training -- train-scanner-mlp`.
@@ -14,55 +17,55 @@ pub const NORM_MAXS: [f32; 12] = [
];
pub const W1: [[f32; 12]; 32] = [
[2.25848985, 1.62502551, 1.05068624, -0.23875977, 1.29692984, -1.34665418, -1.29937541, 1.66119707, -1.43897200, 0.07720046, -1.17116165, 1.96821272],
[2.30885172, 0.02477695, -0.23236598, 1.66507626, -1.41407740, 1.88616431, 1.84703696, -1.46395433, 2.03542018, 1.68318951, 2.01550031, 1.94223917],
[2.29420924, 1.86615539, 1.69271469, 1.42137837, 1.43151915, 1.84876072, 1.09228194, 1.73608077, 0.20805965, 0.52542430, -0.02558800, 0.04718366],
[0.36484259, -0.02785611, -0.01155548, 0.08577330, -0.00468449, -0.07848717, 0.05191587, 0.50796396, 0.40799347, -0.14838840, -0.30566201, 0.00758083],
[0.28191370, 0.20945202, 0.07742970, -0.06654347, 0.17395714, 0.00011351, 0.37079588, 0.41817516, 0.56992871, 0.05705916, 0.22339216, 0.11021475],
[0.06522971, 0.64510870, 0.31671444, 0.34980071, 0.03446164, -0.10592904, -0.21302676, -0.04404496, 0.08638768, 0.04217484, 0.43021953, 0.21055792],
[0.31206250, -0.14565454, 0.38078794, 0.00860748, 0.29409558, -0.11273954, -0.02210701, 0.15525217, 0.09696059, 0.13877581, 0.06483351, 0.10946950],
[0.28374705, -0.02963164, 0.27863786, -0.23428085, 0.12715313, 0.09141072, 0.07769041, 0.01915955, -0.20936646, 0.02813511, -0.03910714, 0.30322370],
[-1.19449413, -0.84935474, -0.32267663, -0.08140022, -0.78729230, 1.58759272, 0.88281459, -0.77263606, 1.55394125, 0.10148179, 1.59524822, -0.75499195],
[-0.97152823, -0.12173092, 0.04745778, -0.85466659, 1.57352293, -0.52651149, -0.66270715, 1.32282484, -1.24654925, -0.45822921, -1.10187364, -0.91162699],
[-0.93944395, -0.57891464, -1.12100291, -0.38871467, -0.18780440, -1.11835766, -0.43614236, -1.07918274, -0.09222561, -0.23854440, -0.16720718, 0.03247443],
[0.13319625, 0.87437463, 0.32213065, 0.13902900, 0.64760798, 0.00899744, 0.45325586, -0.14138180, 0.13888212, 0.07780524, -0.12482210, 0.12632932],
[0.57018995, -0.10839911, 0.02787536, 0.16884641, 0.19435850, -0.01189608, 0.13881874, -0.10700739, -0.05463003, 0.01371983, 0.04385772, 0.01100468],
[-0.26600277, -0.11843663, -0.01081531, 0.10785927, -0.18684258, 0.08537511, 0.01054722, -0.01972559, -0.07416820, 0.57192892, 0.37873995, -0.00498434],
[0.72535324, -0.25030360, 0.51470703, -0.16410951, -0.13649474, 0.16246459, -0.27847841, 0.12250750, 0.45576489, -0.18535912, -0.45686084, 0.58293521],
[0.18614589, -0.32835677, -0.08683094, 0.07748202, -0.24785264, -0.16834147, 0.27066526, 0.06058804, 0.01903199, -0.17387865, 0.12752151, -0.03780220],
[-1.22358644, -0.78316134, -0.54068804, -0.07921790, -0.72697675, 1.80127227, 0.14326867, -0.51875746, 1.83125353, -0.02672976, 1.68589675, -0.80162954],
[-0.83690810, -0.12682360, 0.10783038, -0.64648604, 1.50810242, -0.48788729, -0.59418935, 0.94863659, -0.84788662, -0.49779284, -0.96408021, -1.14068258],
[-0.96322638, -0.50503486, -0.87195945, -0.34710455, -0.28645220, -1.10507452, -0.32122782, -0.80753750, -0.00843489, 0.04215550, 0.03197355, 0.05468401],
[-0.17587705, 0.45144933, 0.37954769, -0.15405300, 0.75590396, 0.00346784, 0.62332457, -0.15602241, -0.26471916, -0.19963606, -0.22497311, -0.20784236],
[0.60608941, -0.05316854, 0.03766245, 0.46412235, -0.41121334, -0.01225545, -0.11125158, -0.33533856, -0.04625564, -0.02995013, -0.24979964, -0.35824969],
[0.08163761, 0.04702193, -0.24007457, -0.23439978, 0.27066308, 0.48389259, 0.32692793, -0.23089454, 0.26520243, -0.14099684, 0.06713670, 0.14434725],
[-0.50808382, -0.14518137, -0.23912378, 0.33510539, 0.46566108, 0.09035082, -0.12637842, 0.55245715, -0.19972627, 0.24517706, 0.34291887, 0.01936621],
[0.35826349, 0.21200819, 0.65315312, -0.16792546, 0.41378024, 0.32129642, 0.50814188, 0.48289016, 0.06839173, 0.42079177, 0.52295685, 0.26273951],
[0.24575019, 0.10700949, 0.07041252, -0.09410189, 0.18897925, 0.31616825, -0.01306109, 0.33499330, -0.01866218, 0.06233863, 0.15316568, 0.08370106],
[0.17828286, 0.17363867, -0.10626584, 0.06075979, 0.39465010, 0.19557165, 0.30352867, 0.26720291, 0.40256795, 0.13942246, 0.05869288, 0.08310238],
[-0.04834138, 0.29206491, 0.01330532, 0.07626399, -0.17378819, 0.09515948, 0.02298534, 0.41555724, 0.09492048, 0.39422533, 0.39373979, 0.20463347],
[-0.11641891, -0.06529939, -0.18899654, -0.02157970, -0.03554495, 0.10956290, -0.11688691, 0.04077352, 0.34220406, -0.09558969, 0.16150762, 0.25759667],
[-0.17313123, 0.00591523, 0.29443163, 0.08298909, 0.07761172, 0.19023541, 0.23826212, -0.07167042, 0.08753359, 0.17917964, -0.03248737, 0.28516129],
[0.13091524, 0.21435370, 0.15093684, 0.30902347, 0.44151527, 0.55901742, 0.19933179, 0.06438518, 0.30585650, -0.34089112, 0.26879075, 0.12928906],
[-0.25311065, -0.09963353, -0.50099874, 0.57481062, 0.38744658, -0.13065037, 0.18897361, 0.49376330, -0.15626629, 0.19911517, 0.06437352, -0.09104283],
[0.35787049, -0.04814727, 0.45446551, -0.15264697, 0.36565515, 0.22795495, 0.24630190, 0.16362202, 0.21044184, 0.53882843, 0.42343852, 0.18454899],
[-0.24775992, -1.12687504, -0.84061003, 1.35276508, 1.04677176, 1.57832956, 1.47995067, 1.38580477, -0.99564040, -1.20309269, -0.24385734, 1.32367671],
[-1.31796920, -0.12229957, 0.89794689, 1.14832735, 1.17210162, 1.32387733, 1.37799740, 1.22984815, 0.92816162, 1.45189691, 0.97822803, -0.89625973],
[-1.46790743, -1.43995631, 1.44276273, 1.30967343, 1.37424576, -1.16613543, 1.46063673, -1.29701447, 0.20172349, -0.05374112, -0.02462341, 0.25295240],
[0.57722956, 0.39603359, -0.03453349, -0.06384056, 0.08056496, 0.30015573, -0.19045275, 0.39737019, 0.41070747, 0.03414693, -0.25483453, 0.24018581],
[0.02473495, 0.05047939, 0.59600842, 0.35709319, -0.32843903, 0.17655721, 0.20441811, 0.05118980, 0.24312067, 0.12677322, 0.12109834, -0.09230414],
[-0.03397392, -0.08393703, 0.42762470, 0.21609549, -0.28864771, -0.11998425, -0.05842599, 0.87331069, 0.52526826, 0.35321006, 0.31293729, 0.51823896],
[-0.26906690, -0.04246650, 0.05114691, 0.27629042, -0.37589836, -0.07396606, -0.46509269, 0.26079515, 0.33588448, 0.55919188, 0.58768439, 0.81638134],
[-0.48464304, 0.76111192, 0.06296005, -0.13527128, -0.41344830, -0.19461812, 0.52815431, 0.96815300, 0.47175986, -0.13977771, 0.41216326, -0.03041369],
[-0.05156134, 0.90191734, 0.78513384, -1.25017786, -0.54637259, -0.36098620, -0.59882820, -0.90374511, 0.91336489, 1.05806208, 0.04994302, -0.91028821],
[0.68376511, 0.07305489, 0.44790089, -0.56647295, -0.66570538, -0.93897015, -0.40558589, -1.51070845, 0.46759781, -1.36738360, -0.05270236, 0.98130196],
[0.86788160, 0.69321704, -0.53778958, -1.54257190, -0.44623125, 0.72615588, -0.75269628, 0.81946337, 0.17503875, 0.63745797, 0.48478079, -0.31573632],
[-0.01361719, 0.21524477, -0.10345778, -0.38488832, 0.42967409, 0.75472528, -0.07410870, -0.65231675, 0.42633417, -0.10289414, 0.09583388, -0.29391766],
[-0.06778818, -0.44469842, 0.05952910, -0.55139810, 0.14308600, -0.53731138, -0.07426350, 0.28065708, 0.29584157, 0.47813708, 0.02095048, -0.36458421],
[-0.21983556, 0.55435538, -0.13939659, 0.58281261, -0.20551582, 0.30075905, 0.13396217, -0.18145087, -0.43283740, 0.18541494, 0.07530790, -0.04916608],
[0.46939296, 0.57935077, -0.05478116, -0.01144989, 0.54106784, -0.18313073, 0.12232503, -0.32802504, -0.01167463, -0.13702804, -0.19521871, 0.09115479],
[0.25263065, -0.27172634, -0.12802953, 0.50027740, 0.05213343, 0.49081728, 0.15367918, 0.20471051, -0.22081012, 0.51709008, 0.01776243, 0.22513707],
[-0.20733139, 0.60041994, 0.05273124, 0.07473211, 0.14580894, -0.72007078, -0.52350652, -0.15482022, 0.19132918, 0.52586436, -0.04793828, -0.00479114],
[0.19872986, -0.19177110, -0.22340146, -0.48786804, -0.51010352, -0.55363113, -0.29520389, -0.21378680, -0.40099174, -0.09184421, -0.08521358, 0.61833692],
[0.21346046, 0.53319895, -0.44765636, -0.04764151, -0.30569363, 0.19765340, -0.41479719, 0.34292534, -0.29234713, 0.54341668, 0.60121793, -0.00226344],
[-0.29598647, -0.37357926, -0.25650844, -0.05165816, 0.55829030, 0.21028350, -0.28581545, -0.37299931, 0.57590896, -0.01573592, -0.19411144, 0.13814686],
[0.10028259, 0.02526089, -0.20488358, 0.25667843, 0.17100072, 0.01034015, -0.32994771, 0.53425753, 0.64935833, 0.30769956, -0.26756367, -0.03389005],
[-0.18916467, 0.38340616, -0.16475976, 0.59811211, 0.12739281, -0.16611671, -0.31913927, 0.07577144, 0.28552490, 0.54843456, 0.40937552, 0.38236183],
[-0.20519450, -0.04122134, -0.20013523, 0.42193425, -0.27304563, -0.21811043, 0.13115846, 0.16724831, 0.13073303, 0.20491999, 0.31806493, 0.13444173],
[0.01762132, 0.32608625, 0.19381267, -0.33404192, -0.46299583, -0.28042898, 0.20772585, 0.20139317, 0.41952321, -0.30363685, 0.20015827, -0.03338646],
[0.13760759, 0.07168494, 0.26161709, 0.41468662, -0.03778528, 0.38290465, 0.48780030, 0.39562985, 0.24758396, -0.05975538, -0.22738078, 0.27877593],
[0.07016940, -0.03804595, -0.08812129, 0.19664441, 0.13347355, 0.50309300, 0.26076415, 0.19044210, -0.20414594, 0.64333421, 0.15160090, 0.16449226],
[0.31039700, -0.01906084, 0.25622010, 0.10707659, 0.54883337, 0.19277412, 0.42004701, -0.09319381, 0.19968294, 0.07109389, -0.28979829, 0.12353907],
[0.28500485, 0.01991569, 0.05190456, 0.29366553, 0.01045146, -0.02013574, -0.01796320, 0.13775185, 0.11095868, -0.25678155, 0.10733776, -0.07584792],
[0.12738188, 0.07762879, -0.06429479, 0.39944342, 0.07958066, 0.46697047, -0.10674930, -0.12212183, -0.01540831, 0.08788434, 0.17299946, 0.25846422],
[0.26692817, 0.00930361, 0.24862845, 0.02167275, -0.09902105, -0.35391217, -0.41734406, 0.44949567, 0.46330830, 0.40603620, 0.08397861, 0.39809385],
[-0.30756459, -0.43368185, -0.00478506, 0.45611116, -0.05069341, 0.21090019, 0.28219289, 0.07687758, 0.54915971, 0.46933413, 0.35599890, 0.17573997],
[-0.19320646, 0.44751191, -0.14140815, 0.00427075, -0.19792002, -0.19400074, 0.19292155, 0.39845818, 0.21028778, -0.10284913, 0.31191504, -0.36995885],
];
pub const B1: [f32; 32] = [
1.12135851, 0.64268047, 0.44761124, -0.28471574, 0.70866716, -0.25293177, -0.19119856, 0.39284116,
-0.20628852, -0.29301032, -0.08837436, 0.92048728, 0.91167349, -0.33615190, -0.06016272, 0.79141164,
-0.43257964, 0.48180589, 0.70891160, -0.24290052, 0.83115542, 0.69964927, 0.97887653, 1.34517038,
1.10292709, 0.42009205, 1.07155228, 0.61349720, 0.46157768, 1.01911950, 0.51159418, 0.60460496
-0.24357778, -0.24826238, -0.03415382, 0.00968227, 0.51550633, 0.45242083, 0.60654080, 0.25456131,
-0.36509025, -0.22825000, 0.03829522, 0.65561563, -0.19379658, -0.25716159, 0.45115772, 0.73442084,
0.61352992, 0.59502298, 0.32757106, 0.28512844, 0.26663530, 0.27169749, 0.33571365, -0.34503689,
-0.08054741, -0.06313029, 0.43629149, 0.35936099, 0.39375633, -0.19984132, 0.49092621, -0.27418151
];
pub const W2: [f32; 32] = [
1.55191231, 1.27754235, 0.43588921, 0.10868450, 0.55931729, -1.46911597, -0.54461092, 0.78240824,
-1.25938582, -0.06287600, -1.02053738, 1.07076716, 1.58776867, -0.03168033, -0.11393511, 1.30535436,
-1.46621227, 0.62925971, 0.76781118, -0.74480098, 1.29669034, 0.62078375, 1.64134884, 2.09736991,
1.52834618, 0.87368065, 1.80090642, 0.89230227, 0.38757962, 1.80718291, 0.64923352, 1.18709576
-0.04975564, -0.54667705, -0.39323062, 0.72362727, 0.86801738, 1.93621075, 1.01259410, 0.75978750,
-0.67997259, -0.63063931, -0.07149173, 0.81899148, -0.69025612, -0.12359849, 1.09533453, 0.88092262,
0.89678788, 0.87908030, 1.12460852, 0.76745653, 0.85632098, 0.72992527, 0.93983871, -0.55915666,
-0.61104172, -0.56369978, 1.43480921, 0.71174467, 1.03119624, -0.57950914, 0.81188917, -0.78019017
];
pub const B2: f32 = 0.23270580;
pub const B2: f32 = 0.24626280;
pub const TREE_NODES: [(u8, f32, u16, u16); 3] = [
(3, 0.50000000, 1, 2),

View File

@@ -1,6 +1,10 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Replay audit logs through the ensemble models (scanner + DDoS).
use crate::ddos::audit_log::{self, AuditLog};
use crate::audit::AuditLogLine;
use crate::ddos::audit_log;
use crate::ddos::features::{method_to_u8, LogIpState};
use crate::ddos::model::DDoSAction;
use crate::ensemble::ddos::{ddos_ensemble_predict, DDoSEnsemblePath};
@@ -26,21 +30,31 @@ pub fn run(args: ReplayEnsembleArgs) -> Result<()> {
std::fs::File::open(&args.input).with_context(|| format!("opening {}", args.input))?;
let reader = std::io::BufReader::new(file);
// --- Parse all entries ---
let mut entries: Vec<AuditLog> = Vec::new();
let mut parse_errors = 0u64;
// --- Parse all entries, filtering for audit logs only ---
let mut entries: Vec<AuditLogLine> = Vec::new();
let mut skipped_non_audit = 0u64;
let mut schema_errors = 0u64;
for line in reader.lines() {
let line = line?;
if line.trim().is_empty() {
continue;
}
match serde_json::from_str::<AuditLog>(&line) {
Ok(e) => entries.push(e),
Err(_) => parse_errors += 1,
match AuditLogLine::try_parse(&line) {
Ok(Some(entry)) => entries.push(entry),
Ok(None) => skipped_non_audit += 1,
Err(e) => {
schema_errors += 1;
if schema_errors <= 3 {
eprintln!(" schema error: {e}");
}
}
}
}
let total = entries.len() as u64;
eprintln!("parsed {} entries ({} parse errors)\n", total, parse_errors);
eprintln!(
"parsed {} audit entries ({} non-audit skipped, {} schema errors)\n",
total, skipped_non_audit, schema_errors,
);
// --- Scanner replay ---
eprintln!("═══ Scanner Ensemble ═════════════════════════════════════");
@@ -54,7 +68,7 @@ pub fn run(args: ReplayEnsembleArgs) -> Result<()> {
Ok(())
}
fn replay_scanner(entries: &[AuditLog]) {
fn replay_scanner(entries: &[AuditLogLine]) {
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
.iter()
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
@@ -73,23 +87,15 @@ fn replay_scanner(entries: &[AuditLog]) {
let mut blocked = 0u64;
let mut allowed = 0u64;
let mut path_counts = [0u64; 3]; // TreeBlock, TreeAllow, Mlp
let mut blocked_examples: Vec<(String, String, f64)> = Vec::new(); // (path, reason, score)
let mut fp_candidates: Vec<(String, u16, f64)> = Vec::new(); // blocked but had 2xx status
let mut blocked_examples: Vec<(String, String, String, f64)> = Vec::new(); // (path, ua, reason, score)
let mut fp_candidates: Vec<(String, String, u16, f64)> = Vec::new(); // blocked but had 2xx status
for e in entries {
let f = &e.fields;
let host_prefix = f.host.split('.').next().unwrap_or("");
let has_cookies = f.has_cookies.unwrap_or(false);
let has_referer = f
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = f
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let has_cookies = f.has_cookies;
let has_referer = !f.referer.is_empty() && f.referer != "-";
let has_accept_language = !f.accept_language.is_empty() && f.accept_language != "-";
let feats = features::extract_features_f32(
&f.method,
@@ -98,7 +104,7 @@ fn replay_scanner(entries: &[AuditLog]) {
has_cookies,
has_referer,
has_accept_language,
"-",
&f.accept,
&f.user_agent,
f.content_length,
&fragment_hashes,
@@ -121,12 +127,13 @@ fn replay_scanner(entries: &[AuditLog]) {
if blocked_examples.len() < 20 {
blocked_examples.push((
f.path.clone(),
f.user_agent.clone(),
verdict.reason.to_string(),
verdict.score,
));
}
if (200..400).contains(&f.status) {
fp_candidates.push((f.path.clone(), f.status, verdict.score));
fp_candidates.push((f.path.clone(), f.user_agent.clone(), f.status, verdict.score));
}
}
ScannerAction::Allow => allowed += 1,
@@ -159,8 +166,9 @@ fn replay_scanner(entries: &[AuditLog]) {
if !blocked_examples.is_empty() {
eprintln!("\n blocked examples (first 20):");
for (path, reason, score) in &blocked_examples {
for (path, ua, reason, score) in &blocked_examples {
eprintln!(" {:<50} {reason} (score={score:.3})", truncate(path, 50));
eprintln!(" ua: {}", truncate(ua, 72));
}
}
@@ -170,16 +178,17 @@ fn replay_scanner(entries: &[AuditLog]) {
"\n potential false positives (blocked but had 2xx/3xx): {}",
fp_count
);
for (path, status, score) in fp_candidates.iter().take(10) {
for (path, ua, status, score) in fp_candidates.iter().take(10) {
eprintln!(
" {:<50} status={status} score={score:.3}",
truncate(path, 50)
);
eprintln!(" ua: {}", truncate(ua, 72));
}
}
}
fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) {
fn replay_ddos(entries: &[AuditLogLine], window_secs: f64, min_events: usize) {
fn fx_hash(s: &str) -> u64 {
let mut h = rustc_hash::FxHasher::default();
s.hash(&mut h);
@@ -208,18 +217,12 @@ fn replay_ddos(entries: &[AuditLog], window_secs: f64, min_events: usize) {
.push(f.content_length.min(u32::MAX as u64) as u32);
state
.has_cookies
.push(f.has_cookies.unwrap_or(false));
.push(f.has_cookies);
state.has_referer.push(
f.referer
.as_deref()
.map(|r| r != "-")
.unwrap_or(false),
!f.referer.is_empty() && f.referer != "-",
);
state.has_accept_language.push(
f.accept_language
.as_deref()
.map(|a| a != "-")
.unwrap_or(false),
!f.accept_language.is_empty() && f.accept_language != "-",
);
state
.suspicious_paths

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use crate::scanner::model::{ScannerAction, ScannerVerdict};
use super::gen::scanner_weights;
use super::mlp::mlp_predict_32;
@@ -92,27 +95,11 @@ impl From<EnsembleVerdict> for ScannerVerdict {
mod tests {
use super::*;
#[test]
fn test_tree_allow_path() {
// All features at zero → feature 3 (suspicious_ua) = 0.0 <= 0.65 → left (node 1)
// feature 0 (path_depth) = 0.0 <= 0.40 → left (node 3) → Allow leaf
let raw = [0.0f32; 12];
let v = scanner_ensemble_predict(&raw);
assert_eq!(v.action, ScannerAction::Allow);
assert_eq!(v.path, EnsemblePath::TreeAllow);
assert_eq!(v.reason, "ensemble:tree_allow");
assert!((v.score - 0.0).abs() < f64::EPSILON);
}
#[test]
fn test_tree_block_path() {
// Need: feature 3 (suspicious_ua) > 0.65 (normalized) → right (node 2)
// feature 7 (payload_entropy) > 0.72 (normalized) → right (node 6) → Block
// feature 3 max = 1.0, so raw 0.8 → normalized 0.8 > 0.65 ✓
// feature 7 max = 8.0, so raw 6.0 → normalized 0.75 > 0.72 ✓
let mut raw = [0.0f32; 12];
raw[3] = 0.8; // suspicious_ua: normalized = 0.8/1.0 = 0.8 > 0.65
raw[7] = 6.0; // payload_entropy: normalized = 6.0/8.0 = 0.75 > 0.72
// Tree: root splits on feature 7 (ua_category) at 0.75.
// All zeros → ua_category normalized = 0.0 <= 0.75 → Block (node 1)
let raw = [0.0f32; 12];
let v = scanner_ensemble_predict(&raw);
assert_eq!(v.action, ScannerAction::Block);
assert_eq!(v.path, EnsemblePath::TreeBlock);
@@ -120,28 +107,38 @@ mod tests {
}
#[test]
fn test_mlp_path() {
// Need: feature 3 > 0.65 normalized → right (node 2)
// feature 7 <= 0.72 normalized → left (node 5) → Defer
// Then MLP runs on the normalized input.
fn test_tree_allow_path() {
// Tree: root feature 7 > 0.75 → node 2, checks feature 3 (has_cookies) at 0.25.
// raw[7] = 1.0 → normalized 1.0 > 0.75 → right.
// raw[3] = 1.0 → normalized ~0.7 > 0.25 → right child node 6 → Allow leaf.
let mut raw = [0.0f32; 12];
raw[3] = 0.8; // normalized = 0.8 > 0.65
raw[7] = 4.0; // normalized = 4.0/8.0 = 0.5 <= 0.72
// Also need feature 2 (query_param_count) to navigate node 5 correctly
// node 5: split on feature 2, threshold 0.55 → left=9(Defer), right=10
// normalized feature 2 = 0.0/20.0 = 0.0 <= 0.55 → left (node 9) → Defer
raw[7] = 1.0; // ua_category = browser
raw[3] = 1.0; // has_cookies = yes
let v = scanner_ensemble_predict(&raw);
assert_eq!(v.path, EnsemblePath::Mlp);
assert_eq!(v.reason, "ensemble:mlp");
// MLP output is deterministic for these inputs
assert!(v.score >= 0.0 && v.score <= 1.0);
assert_eq!(v.action, ScannerAction::Allow);
assert_eq!(v.path, EnsemblePath::TreeAllow);
assert_eq!(v.reason, "ensemble:tree_allow");
}
#[test]
fn test_mlp_direct() {
// Current tree has no Defer leaves, so test MLP inference directly.
let input = [0.5f32; 12];
let score = mlp_predict_32::<12>(
&scanner_weights::W1,
&scanner_weights::B1,
&scanner_weights::W2,
scanner_weights::B2,
&input,
);
assert!(score >= 0.0 && score <= 1.0);
}
#[test]
fn test_normalize_clamps() {
// Values beyond max should be clamped to 1.0
let mut raw = [0.0f32; 12];
raw[0] = 100.0; // max is 10.0
raw[0] = 100.0;
let normed = normalize(&raw);
assert!((normed[0] - 1.0).abs() < f64::EPSILON as f32);
}

View File

@@ -1,7 +1,12 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
// Library crate root — exports the proxy/config/acme modules so that
// integration tests in tests/ can construct and drive a SunbeamProxy
// without going through the binary entry point.
#![recursion_limit = "256"]
pub mod acme;
pub mod audit;
pub mod autotune;
pub mod cache;
pub mod cluster;

View File

@@ -1,10 +1,12 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
mod cert;
mod telemetry;
mod watcher;
use sunbeam_proxy::{acme, autotune, config};
use sunbeam_proxy::{acme, config};
use sunbeam_proxy::proxy::SunbeamProxy;
use sunbeam_proxy::ddos;
use sunbeam_proxy::rate_limit;
use sunbeam_proxy::scanner;
@@ -32,77 +34,18 @@ enum Commands {
#[arg(long)]
upgrade: bool,
},
/// Replay audit logs through detection models
/// Replay audit logs through ensemble models (scanner + DDoS)
Replay {
#[command(subcommand)]
mode: ReplayMode,
},
/// Train a DDoS detection model from audit logs
TrainDdos {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output model file path
#[arg(short, long)]
output: String,
/// File with known-attack IPs (one per line)
#[arg(long)]
attack_ips: Option<String>,
/// File with known-normal IPs (one per line)
#[arg(long)]
normal_ips: Option<String>,
/// TOML file with heuristic auto-labeling thresholds
#[arg(long)]
heuristics: Option<String>,
/// KNN k parameter
#[arg(long, default_value = "5")]
k: usize,
/// Attack threshold (fraction of k neighbors)
#[arg(long, default_value = "0.6")]
threshold: f64,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP to include in training
#[arg(long, default_value = "10")]
/// Minimum events per IP before DDoS classification
#[arg(long, default_value = "5")]
min_events: usize,
},
/// Train a per-request scanner detection model from audit logs
TrainScanner {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output model file path
#[arg(short, long, default_value = "scanner_model.bin")]
output: String,
/// Directory (or file) containing .txt wordlists of scanner paths
#[arg(long)]
wordlists: Option<String>,
/// Classification threshold
#[arg(long, default_value = "0.5")]
threshold: f64,
/// Include CSIC 2010 dataset as base training data (downloaded from GitHub, cached locally)
#[arg(long)]
csic: bool,
},
/// Bayesian hyperparameter optimization for DDoS model
AutotuneDdos {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output best model file path
#[arg(short, long, default_value = "ddos_model_best.bin")]
output: String,
/// Number of optimization trials
#[arg(long, default_value = "200")]
trials: usize,
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
#[arg(long, default_value = "1.0")]
beta: f64,
/// JSONL file to log each trial's parameters and results
#[arg(long)]
trial_log: Option<String>,
},
/// Download and cache upstream datasets (CIC-IDS2017)
DownloadDatasets,
/// Prepare a unified training dataset from multiple sources
@@ -125,6 +68,12 @@ enum Commands {
/// Path to heuristics.toml for auto-labeling production logs
#[arg(long)]
heuristics: Option<String>,
/// Inject CSIC 2010 dataset as labeled audit log entries
#[arg(long)]
inject_csic: bool,
/// Inject OWASP ModSec audit log entries (path to .log file)
#[arg(long)]
inject_modsec: Option<String>,
},
#[cfg(feature = "training")]
/// Train scanner ensemble (decision tree + MLP) from prepared dataset
@@ -142,7 +91,7 @@ enum Commands {
#[arg(long, default_value = "100")]
epochs: usize,
/// Learning rate
#[arg(long, default_value = "0.001")]
#[arg(long, default_value = "0.0001")]
learning_rate: f64,
/// Batch size
#[arg(long, default_value = "64")]
@@ -153,6 +102,12 @@ enum Commands {
/// Min purity for tree leaves (below -> Defer)
#[arg(long, default_value = "0.90")]
tree_min_purity: f32,
/// Min samples required in a leaf node (higher = less overfitting)
#[arg(long, default_value = "2")]
min_samples_leaf: usize,
/// Weight for cookie feature (0.0=ignore, 1.0=full). Controls has_cookies influence.
#[arg(long, default_value = "1.0")]
cookie_weight: f32,
},
#[cfg(feature = "training")]
/// Train DDoS ensemble (decision tree + MLP) from prepared dataset
@@ -165,184 +120,82 @@ enum Commands {
hidden_dim: usize,
#[arg(long, default_value = "100")]
epochs: usize,
#[arg(long, default_value = "0.001")]
#[arg(long, default_value = "0.0001")]
learning_rate: f64,
#[arg(long, default_value = "64")]
batch_size: usize,
#[arg(long, default_value = "6")]
tree_max_depth: usize,
/// Min purity for tree leaves (below -> Defer)
#[arg(long, default_value = "0.90")]
tree_min_purity: f32,
},
/// Bayesian hyperparameter optimization for scanner model
AutotuneScanner {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output best model file path
#[arg(short, long, default_value = "scanner_model_best.bin")]
output: String,
/// Directory (or file) containing .txt wordlists of scanner paths
#[arg(long)]
wordlists: Option<String>,
/// Include CSIC 2010 dataset as base training data
#[arg(long)]
csic: bool,
/// Number of optimization trials
#[arg(long, default_value = "200")]
trials: usize,
/// F-beta parameter (1.0 = F1, 2.0 = recall-weighted)
/// Min samples required in a leaf node (higher = less overfitting)
#[arg(long, default_value = "2")]
min_samples_leaf: usize,
/// Weight for cookie feature (0.0=ignore, 1.0=full). Controls cookie_ratio influence.
#[arg(long, default_value = "1.0")]
beta: f64,
/// JSONL file to log each trial's parameters and results
cookie_weight: f32,
},
#[cfg(feature = "training")]
/// Sweep cookie_weight values and report tree structure + validation accuracy for each
SweepCookieWeight {
/// Path to prepared dataset (.bin)
#[arg(short = 'd', long)]
dataset: String,
/// Which detector to sweep: "scanner" or "ddos"
#[arg(long, default_value = "scanner")]
detector: String,
/// Comma-separated cookie_weight values to try (default: 0.0,0.1,0.2,...,1.0)
#[arg(long)]
trial_log: Option<String>,
weights: Option<String>,
/// Max tree depth
#[arg(long, default_value = "6")]
tree_max_depth: usize,
/// Min purity for tree leaves
#[arg(long, default_value = "0.90")]
tree_min_purity: f32,
/// Min samples required in a leaf node
#[arg(long, default_value = "2")]
min_samples_leaf: usize,
},
}
#[derive(Subcommand)]
enum ReplayMode {
/// Replay through ensemble models (scanner + DDoS)
Ensemble {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP before DDoS classification
#[arg(long, default_value = "5")]
min_events: usize,
},
/// Replay through legacy KNN DDoS detector
Ddos {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Path to trained model file
#[arg(short, long, default_value = "ddos_model.bin")]
model: String,
/// Optional config file (for rate limit settings)
#[arg(short, long)]
config: Option<String>,
/// KNN k parameter
#[arg(long, default_value = "5")]
k: usize,
/// Attack threshold
#[arg(long, default_value = "0.6")]
threshold: f64,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP before classification
#[arg(long, default_value = "10")]
min_events: usize,
/// Also run rate limiter during replay
#[arg(long)]
rate_limit: bool,
},
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
Commands::Serve { upgrade } => run_serve(upgrade),
Commands::Replay { mode } => match mode {
ReplayMode::Ensemble { input, window_secs, min_events } => {
Commands::Replay { input, window_secs, min_events } => {
sunbeam_proxy::ensemble::replay::run(sunbeam_proxy::ensemble::replay::ReplayEnsembleArgs {
input, window_secs, min_events,
})
}
ReplayMode::Ddos { input, model, config, k, threshold, window_secs, min_events, rate_limit } => {
ddos::replay::run(ddos::replay::ReplayArgs {
input, model_path: model, config_path: config, k, threshold, window_secs, min_events, rate_limit,
})
}
},
Commands::TrainDdos {
input,
output,
attack_ips,
normal_ips,
heuristics,
k,
threshold,
window_secs,
min_events,
} => ddos::train::run(ddos::train::TrainArgs {
input,
output,
attack_ips,
normal_ips,
heuristics,
k,
threshold,
window_secs,
min_events,
}),
Commands::TrainScanner {
input,
output,
wordlists,
threshold,
csic,
} => scanner::train::run(scanner::train::TrainScannerArgs {
input,
output,
wordlists,
threshold,
csic,
}),
Commands::DownloadDatasets => {
sunbeam_proxy::dataset::download::download_all()
},
Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics } => {
Commands::PrepareDataset { input, owasp, wordlists, output, seed, heuristics, inject_csic, inject_modsec } => {
sunbeam_proxy::dataset::prepare::run(sunbeam_proxy::dataset::prepare::PrepareDatasetArgs {
input, owasp, wordlists, output, seed, heuristics,
input, owasp, wordlists, output, seed, heuristics, inject_csic, inject_modsec,
})
},
#[cfg(feature = "training")]
Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
Commands::TrainMlpScanner { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight } => {
sunbeam_proxy::training::train_scanner::run(sunbeam_proxy::training::train_scanner::TrainScannerMlpArgs {
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight,
})
},
#[cfg(feature = "training")]
Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity } => {
Commands::TrainMlpDdos { dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight } => {
sunbeam_proxy::training::train_ddos::run(sunbeam_proxy::training::train_ddos::TrainDdosMlpArgs {
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity,
dataset_path: dataset, output_dir, hidden_dim, epochs, learning_rate, batch_size, tree_max_depth, tree_min_purity, min_samples_leaf, cookie_weight,
})
},
Commands::AutotuneDdos {
input,
output,
trials,
beta,
trial_log,
} => autotune::ddos::run_autotune(autotune::ddos::AutotuneDdosArgs {
input,
output,
trials,
beta,
trial_log,
}),
Commands::AutotuneScanner {
input,
output,
wordlists,
csic,
trials,
beta,
trial_log,
} => autotune::scanner::run_autotune(autotune::scanner::AutotuneScannerArgs {
input,
output,
wordlists,
csic,
trials,
beta,
trial_log,
}),
#[cfg(feature = "training")]
Commands::SweepCookieWeight { dataset, detector, weights, tree_max_depth, tree_min_purity, min_samples_leaf } => {
sunbeam_proxy::training::sweep::run_cookie_sweep(
&dataset, &detector, weights.as_deref(), tree_max_depth, tree_min_purity, min_samples_leaf,
)
},
}
}
@@ -363,46 +216,19 @@ fn run_serve(upgrade: bool) -> Result<()> {
// 1b. Spawn metrics HTTP server (needs a tokio runtime for the TCP listener).
let metrics_port = cfg.telemetry.metrics_port;
// 2. Load DDoS detection model if configured.
// 2. Init DDoS detector if configured (ensemble: compiled-in weights).
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
if ddos_cfg.enabled {
if ddos_cfg.use_ensemble {
// Ensemble path: compiled-in weights, no model file needed.
// We still need a TrainedModel for the struct, but it won't be used.
let dummy_model = ddos::model::TrainedModel::empty(ddos_cfg.k, ddos_cfg.threshold);
let detector = Arc::new(ddos::detector::DDoSDetector::new_ensemble(dummy_model, ddos_cfg));
let detector = Arc::new(sunbeam_proxy::ddos::detector::DDoSDetector::new(ddos_cfg));
tracing::info!(
k = ddos_cfg.k,
threshold = ddos_cfg.threshold,
observe_only = ddos_cfg.observe_only,
"DDoS ensemble detector enabled"
);
if ddos_cfg.observe_only {
tracing::warn!("DDoS detector in OBSERVE-ONLY mode — decisions are logged but traffic is never blocked");
}
Some(detector)
} else if let Some(ref model_path) = ddos_cfg.model_path {
match ddos::model::TrainedModel::load(
std::path::Path::new(model_path),
Some(ddos_cfg.k),
Some(ddos_cfg.threshold),
) {
Ok(model) => {
let point_count = model.point_count();
let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg));
tracing::info!(
points = point_count,
k = ddos_cfg.k,
threshold = ddos_cfg.threshold,
"DDoS detector loaded"
);
Some(detector)
}
Err(e) => {
tracing::warn!(error = %e, "failed to load DDoS model; detection disabled");
None
}
}
} else {
tracing::warn!("DDoS enabled but no model_path and use_ensemble=false; detection disabled");
None
}
} else {
None
}
@@ -435,15 +261,12 @@ fn run_serve(upgrade: bool) -> Result<()> {
None
};
// 2c. Load scanner model if configured.
// 2c. Init scanner detector if configured (ensemble: compiled-in weights).
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
if scanner_cfg.enabled {
if scanner_cfg.use_ensemble {
// Ensemble path: compiled-in weights, no model file needed.
let detector = scanner::detector::ScannerDetector::new_ensemble(&cfg.routes);
let detector = scanner::detector::ScannerDetector::new(&cfg.routes);
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
// Start bot allowlist if rules are configured.
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
let al = scanner::allowlist::BotAllowlist::spawn(
&scanner_cfg.allowlist,
@@ -460,63 +283,13 @@ fn run_serve(upgrade: bool) -> Result<()> {
tracing::info!(
threshold = scanner_cfg.threshold,
observe_only = scanner_cfg.observe_only,
"scanner ensemble detector enabled"
);
if scanner_cfg.observe_only {
tracing::warn!("scanner detector in OBSERVE-ONLY mode — decisions are logged but traffic is never blocked");
}
(Some(handle), bot_allowlist)
} else if let Some(ref model_path) = scanner_cfg.model_path {
match scanner::model::ScannerModel::load(std::path::Path::new(model_path)) {
Ok(mut model) => {
let fragment_count = model.fragments.len();
model.threshold = scanner_cfg.threshold;
let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes);
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
// Start bot allowlist if rules are configured.
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
let al = scanner::allowlist::BotAllowlist::spawn(
&scanner_cfg.allowlist,
scanner_cfg.bot_cache_ttl_secs,
);
tracing::info!(
rules = scanner_cfg.allowlist.len(),
"bot allowlist enabled"
);
Some(al)
} else {
None
};
// Start background file watcher for hot-reload.
if scanner_cfg.poll_interval_secs > 0 {
let watcher_handle = handle.clone();
let watcher_model_path = std::path::PathBuf::from(model_path);
let threshold = scanner_cfg.threshold;
let routes = cfg.routes.clone();
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
std::thread::spawn(move || {
scanner::watcher::watch_scanner_model(
watcher_handle, watcher_model_path, threshold, routes, interval,
);
});
}
tracing::info!(
fragments = fragment_count,
threshold = scanner_cfg.threshold,
poll_interval_secs = scanner_cfg.poll_interval_secs,
"scanner detector loaded"
);
(Some(handle), bot_allowlist)
}
Err(e) => {
tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled");
(None, None)
}
}
} else {
tracing::warn!("scanner enabled but no model_path and use_ensemble=false; scanner detection disabled");
(None, None)
}
} else {
(None, None)
}
@@ -617,6 +390,8 @@ fn run_serve(upgrade: bool) -> Result<()> {
&cfg.rate_limit.as_ref().map(|rl| rl.bypass_cidrs.clone()).unwrap_or_default(),
),
cluster: cluster_handle,
ddos_observe_only: cfg.ddos.as_ref().map(|d| d.observe_only).unwrap_or(false),
scanner_observe_only: cfg.scanner.as_ref().map(|s| s.observe_only).unwrap_or(false),
};
let mut svc = http_proxy_service(&server.configuration, proxy);

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use crate::acme::AcmeRoutes;
use crate::cluster::ClusterHandle;
use crate::config::RouteConfig;
@@ -32,9 +35,9 @@ pub struct SunbeamProxy {
pub routes: Vec<RouteConfig>,
/// Per-challenge route table populated by the Ingress watcher.
pub acme_routes: AcmeRoutes,
/// Optional KNN-based DDoS detector.
/// Optional DDoS detector (ensemble: decision tree + MLP).
pub ddos_detector: Option<Arc<DDoSDetector>>,
/// Optional per-request scanner detector (hot-reloadable via ArcSwap).
/// Optional per-request scanner detector (ensemble: decision tree + MLP).
pub scanner_detector: Option<Arc<ArcSwap<ScannerDetector>>>,
/// Optional verified-bot allowlist (bypasses scanner for known crawlers/agents).
pub bot_allowlist: Option<Arc<BotAllowlist>>,
@@ -48,6 +51,10 @@ pub struct SunbeamProxy {
pub pipeline_bypass_cidrs: Vec<crate::rate_limit::cidr::CidrBlock>,
/// Optional cluster handle for multi-node bandwidth tracking.
pub cluster: Option<Arc<ClusterHandle>>,
/// When true, DDoS detector logs decisions but never blocks traffic.
pub ddos_observe_only: bool,
/// When true, scanner detector logs decisions but never blocks traffic.
pub scanner_observe_only: bool,
}
pub struct RequestCtx {
@@ -341,7 +348,7 @@ impl ProxyHttp for SunbeamProxy {
metrics::DDOS_DECISIONS.with_label_values(&[decision]).inc();
if matches!(ddos_action, DDoSAction::Block) {
if matches!(ddos_action, DDoSAction::Block) && !self.ddos_observe_only {
let mut resp = ResponseHeader::build(429, None)?;
resp.insert_header("Retry-After", "60")?;
resp.insert_header("Content-Length", "0")?;
@@ -426,7 +433,7 @@ impl ProxyHttp for SunbeamProxy {
.with_label_values(&[decision, reason])
.inc();
if decision == "block" {
if decision == "block" && !self.scanner_observe_only {
let mut resp = ResponseHeader::build(403, None)?;
resp.insert_header("Content-Length", "0")?;
session.write_response_header(Box::new(resp), true).await?;
@@ -1150,6 +1157,21 @@ impl ProxyHttp for SunbeamProxy {
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
let query = session.req_header().uri.query().unwrap_or("");
let response_bytes = session.body_bytes_sent();
let http_version = format!("{:?}", session.req_header().version);
let header_count = session.req_header().headers.len() as u16;
let accept_encoding = session
.req_header()
.headers
.get("accept-encoding")
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
let connection = session
.req_header()
.headers
.get("connection")
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
tracing::info!(
target = "audit",
@@ -1162,14 +1184,19 @@ impl ProxyHttp for SunbeamProxy {
status,
duration_ms,
content_length,
response_bytes,
user_agent,
referer,
accept_language,
accept,
accept_encoding,
has_cookies,
cf_country,
backend,
error = error_str,
http_version,
header_count,
connection,
"request"
);

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Fetch and convert the CSIC 2010 HTTP dataset into labeled training samples.
//!
//! The CSIC 2010 dataset contains raw HTTP/1.1 requests (normal + anomalous)
@@ -65,6 +68,7 @@ struct ParsedRequest {
content_length: u64,
referer: String,
accept_language: String,
accept: String,
}
fn parse_csic_content(content: &str) -> Vec<ParsedRequest> {
@@ -158,6 +162,7 @@ fn parse_single_request(lines: &[&str]) -> Option<ParsedRequest> {
content_length,
referer: get_header("Referer").unwrap_or("-").to_string(),
accept_language: get_header("Accept-Language").unwrap_or("-").to_string(),
accept: get_header("Accept").unwrap_or("-").to_string(),
})
}
@@ -219,11 +224,12 @@ fn to_audit_fields(
// For anomalous samples, simulate real scanner behavior:
// strip cookies/referer/accept-language that CSIC attacks have from their session.
let (has_cookies, referer, accept_language, user_agent) = if label != "normal" {
let referer = None;
let referer = "-".to_string();
let accept_language = if rng.next_f64() < 0.8 {
None
"-".to_string()
} else {
Some(req.accept_language.clone()).filter(|a| a != "-")
let al = req.accept_language.clone();
if al == "-" { "-".to_string() } else { al }
};
let r = rng.next_f64();
let user_agent = if r < 0.15 {
@@ -241,12 +247,26 @@ fn to_audit_fields(
} else {
(
req.has_cookies,
Some(req.referer.clone()).filter(|r| r != "-"),
Some(req.accept_language.clone()).filter(|a| a != "-"),
if req.referer == "-" { "-".to_string() } else { req.referer.clone() },
if req.accept_language == "-" { "-".to_string() } else { req.accept_language.clone() },
req.user_agent.clone(),
)
};
// For normal traffic, preserve Accept header from CSIC request.
// For attacks, degrade it to simulate scanner behavior.
let accept = if label == "normal" {
if req.accept == "-" || req.accept.is_empty() {
"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8".to_string()
} else {
req.accept.clone()
}
} else if rng.next_f64() < 0.6 {
"*/*".to_string()
} else {
req.accept.clone()
};
AuditFields {
method: req.method.clone(),
host,
@@ -263,9 +283,10 @@ fn to_audit_fields(
duration_ms: rng.next_usize(50) as u64 + 1,
content_length: req.content_length,
user_agent,
has_cookies: Some(has_cookies),
has_cookies,
referer,
accept_language,
accept,
backend: if label == "normal" {
format!("{host_prefix}-svc:8080")
} else {
@@ -274,6 +295,7 @@ fn to_audit_fields(
label: Some(
if label == "normal" { "normal" } else { "attack" }.to_string(),
),
..AuditFields::default()
}
}
@@ -343,6 +365,7 @@ mod tests {
assert_eq!(req.path, "/index.html");
assert!(req.has_cookies);
assert_eq!(req.user_agent, "Mozilla/5.0");
assert_eq!(req.accept, "text/html");
}
#[test]
@@ -374,11 +397,12 @@ mod tests {
content_length: 100,
referer: "https://example.com".to_string(),
accept_language: "en-US".to_string(),
accept: "text/html".to_string(),
};
let mut rng = Rng::new(42);
let fields = to_audit_fields(&req, "normal", DEFAULT_HOSTS, &mut rng);
assert_eq!(fields.label.as_deref(), Some("normal"));
assert!(fields.has_cookies.unwrap_or(false));
assert!(fields.has_cookies);
assert!(fields.host.ends_with(".sunbeam.pt"));
}
@@ -393,11 +417,12 @@ mod tests {
content_length: 0,
referer: "https://example.com".to_string(),
accept_language: "en-US".to_string(),
accept: "text/html".to_string(),
};
let mut rng = Rng::new(42);
let fields = to_audit_fields(&req, "anomalous", DEFAULT_HOSTS, &mut rng);
assert_eq!(fields.label.as_deref(), Some("attack"));
assert!(!fields.has_cookies.unwrap_or(true));
assert!(!fields.has_cookies);
}
#[test]

View File

@@ -1,9 +1,9 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use crate::config::RouteConfig;
use crate::scanner::features::{
self, fx_hash_bytes, ScannerNormParams, SUSPICIOUS_EXTENSIONS_LIST, NUM_SCANNER_FEATURES,
NUM_SCANNER_WEIGHTS,
};
use crate::scanner::model::{ScannerAction, ScannerModel, ScannerVerdict};
use crate::scanner::features::{self, fx_hash_bytes, SUSPICIOUS_EXTENSIONS_LIST};
use crate::scanner::model::{ScannerAction, ScannerVerdict};
use rustc_hash::FxHashSet;
/// Immutable, zero-state per-request scanner detector.
@@ -12,44 +12,10 @@ pub struct ScannerDetector {
fragment_hashes: FxHashSet<u64>,
extension_hashes: FxHashSet<u64>,
configured_hosts: FxHashSet<u64>,
weights: [f64; NUM_SCANNER_WEIGHTS],
threshold: f64,
norm_params: ScannerNormParams,
use_ensemble: bool,
}
impl ScannerDetector {
pub fn new(model: &ScannerModel, routes: &[RouteConfig]) -> Self {
let fragment_hashes: FxHashSet<u64> = model
.fragments
.iter()
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
.collect();
let extension_hashes: FxHashSet<u64> = SUSPICIOUS_EXTENSIONS_LIST
.iter()
.map(|e| fx_hash_bytes(e.as_bytes()))
.collect();
let configured_hosts: FxHashSet<u64> = routes
.iter()
.map(|r| fx_hash_bytes(r.host_prefix.as_bytes()))
.collect();
Self {
fragment_hashes,
extension_hashes,
configured_hosts,
weights: model.weights,
threshold: model.threshold,
norm_params: model.norm_params.clone(),
use_ensemble: false,
}
}
/// Create a detector that uses the ensemble (decision tree + MLP) path
/// instead of the linear model. No model file needed — weights are compiled in.
pub fn new_ensemble(routes: &[RouteConfig]) -> Self {
pub fn new(routes: &[RouteConfig]) -> Self {
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
.iter()
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
@@ -69,13 +35,6 @@ impl ScannerDetector {
fragment_hashes,
extension_hashes,
configured_hosts,
weights: [0.0; NUM_SCANNER_WEIGHTS],
threshold: 0.5,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
use_ensemble: true,
}
}
@@ -98,8 +57,6 @@ impl ScannerDetector {
content_length: u64,
) -> ScannerVerdict {
// Hard allowlist: obviously legitimate traffic bypasses the model.
// This prevents model drift from ever blocking real users and ensures
// the training pipeline always has clean positive labels.
let host_known = {
let hash = features::fx_hash_bytes(host_prefix.as_bytes());
self.configured_hosts.contains(&hash)
@@ -121,7 +78,6 @@ impl ScannerDetector {
};
}
if self.use_ensemble {
// Ensemble path: extract f32 features → decision tree + MLP.
let raw_f32 = features::extract_features_f32(
method, path, host_prefix,
@@ -137,79 +93,17 @@ impl ScannerDetector {
crate::ensemble::scanner::EnsemblePath::Mlp => "mlp",
}])
.inc();
return ev.into();
}
// 1. Extract 12 features
let raw = features::extract_features(
method,
path,
host_prefix,
has_cookies,
has_referer,
has_accept_language,
accept,
user_agent,
content_length,
&self.fragment_hashes,
&self.extension_hashes,
&self.configured_hosts,
);
// 2. Normalize
let f = self.norm_params.normalize(&raw);
// 3. Compute score = bias + dot(weights, features) + interaction terms
let mut score = self.weights[NUM_SCANNER_FEATURES + 2]; // bias (index 14)
for (i, &fi) in f.iter().enumerate().take(NUM_SCANNER_FEATURES) {
score += self.weights[i] * fi;
}
// Interaction: suspicious_path AND no_cookies
score += self.weights[12] * f[0] * (1.0 - f[3]);
// Interaction: unknown_host AND no_accept_language
score += self.weights[13] * (1.0 - f[9]) * (1.0 - f[5]);
// 4. Threshold
let action = if score > self.threshold {
ScannerAction::Block
} else {
ScannerAction::Allow
};
ScannerVerdict {
action,
score,
reason: "model",
}
ev.into()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::scanner::features::NUM_SCANNER_FEATURES;
use crate::config::RouteConfig;
fn make_detector(weights: [f64; NUM_SCANNER_WEIGHTS], threshold: f64) -> ScannerDetector {
let model = ScannerModel {
weights,
threshold,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
fragments: vec![
".env".into(),
"wp-admin".into(),
"wp-login".into(),
"phpinfo".into(),
"phpmyadmin".into(),
".git".into(),
"cgi-bin".into(),
".htaccess".into(),
".htpasswd".into(),
],
};
let routes = vec![RouteConfig {
fn test_routes() -> Vec<RouteConfig> {
vec![RouteConfig {
host_prefix: "app".into(),
backend: "http://127.0.0.1:8080".into(),
websocket: false,
@@ -221,35 +115,12 @@ mod tests {
body_rewrites: vec![],
response_headers: vec![],
cache: None,
}];
ScannerDetector::new(&model, &routes)
}
/// Weights tuned to block scanner-like requests:
/// High weight on suspicious_path (w[0]), no_cookies interaction (w[12]),
/// has_suspicious_extension (w[2]), traversal (w[11]).
/// Negative weight on has_cookies (w[3]), has_referer (w[4]),
/// accept_quality (w[6]), ua_category (w[7]), host_is_configured (w[9]).
fn attack_tuned_weights() -> [f64; NUM_SCANNER_WEIGHTS] {
let mut w = [0.0; NUM_SCANNER_WEIGHTS];
w[0] = 2.0; // suspicious_path_score
w[2] = 2.0; // has_suspicious_extension
w[3] = -2.0; // has_cookies (negative = good)
w[4] = -1.0; // has_referer (negative = good)
w[5] = -1.0; // has_accept_language (negative = good)
w[6] = -0.5; // accept_quality (negative = good)
w[7] = -1.0; // ua_category (negative = browser is good)
w[9] = -1.5; // host_is_configured (negative = known host is good)
w[11] = 2.0; // path_has_traversal
w[12] = 1.5; // interaction: suspicious_path AND no_cookies
w[13] = 1.0; // interaction: unknown_host AND no_accept_lang
w[14] = 0.5; // bias
w
}]
}
#[test]
fn test_normal_browser_request_allowed() {
let detector = make_detector(attack_tuned_weights(), 0.5);
let detector = ScannerDetector::new(&test_routes());
let verdict = detector.check(
"GET",
"/blog/hello-world",
@@ -267,7 +138,7 @@ mod tests {
#[test]
fn test_api_client_with_auth_allowed() {
let detector = make_detector(attack_tuned_weights(), 0.5);
let detector = ScannerDetector::new(&test_routes());
let verdict = detector.check(
"POST",
"/api/v1/data",
@@ -285,81 +156,24 @@ mod tests {
#[test]
fn test_env_probe_blocked() {
let detector = make_detector(attack_tuned_weights(), 0.5);
let detector = ScannerDetector::new(&test_routes());
let verdict = detector.check(
"GET",
"/.env",
"unknown",
false, // no cookies
false, // no referer
false, // no accept-language
false,
false,
false,
"*/*",
"curl/7.0",
0,
);
assert_eq!(verdict.action, ScannerAction::Block);
assert_eq!(verdict.reason, "model");
}
#[test]
fn test_wordpress_scan_blocked() {
let detector = make_detector(attack_tuned_weights(), 0.5);
let verdict = detector.check(
"GET",
"/wp-admin/install.php",
"unknown",
false,
false,
false,
"*/*",
"",
0,
);
assert_eq!(verdict.action, ScannerAction::Block);
assert_eq!(verdict.reason, "model");
}
#[test]
fn test_path_traversal_blocked() {
let detector = make_detector(attack_tuned_weights(), 0.5);
let verdict = detector.check(
"GET",
"/etc/../../../passwd",
"unknown",
false,
false,
false,
"*/*",
"python-requests/2.28",
0,
);
assert_eq!(verdict.action, ScannerAction::Block);
assert_eq!(verdict.reason, "model");
}
#[test]
fn test_legitimate_php_path_allowed() {
let detector = make_detector(attack_tuned_weights(), 0.5);
// "/blog/php-is-dead" — "php-is-dead" is not a known fragment
// has_cookies=true + known host "app" → hits allowlist
let verdict = detector.check(
"GET",
"/blog/php-is-dead",
"app",
true,
true,
true,
"text/html",
"Mozilla/5.0 Chrome/120",
0,
);
assert_eq!(verdict.action, ScannerAction::Allow);
}
#[test]
fn test_allowlist_browser_on_known_host() {
let detector = make_detector(attack_tuned_weights(), 0.5);
// No cookies but browser UA + accept-language + known host → allowlist
let detector = ScannerDetector::new(&test_routes());
let verdict = detector.check(
"GET",
"/",
@@ -374,22 +188,4 @@ mod tests {
assert_eq!(verdict.action, ScannerAction::Allow);
assert_eq!(verdict.reason, "allowlist:host+browser");
}
#[test]
fn test_model_path_for_non_allowlisted() {
let detector = make_detector(attack_tuned_weights(), 0.5);
// Unknown host, no cookies, curl UA → goes through model
let verdict = detector.check(
"GET",
"/robots.txt",
"unknown",
false,
false,
false,
"*/*",
"curl/7.0",
0,
);
assert_eq!(verdict.reason, "model");
}
}

View File

@@ -1,7 +1,9 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
pub mod allowlist;
pub mod csic;
pub mod detector;
pub mod features;
pub mod model;
pub mod train;
pub mod watcher;

View File

@@ -1,7 +1,5 @@
use crate::scanner::features::{ScannerNormParams, NUM_SCANNER_WEIGHTS};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScannerAction {
@@ -16,74 +14,3 @@ pub struct ScannerVerdict {
/// Why this decision was made: "model", "allowlist", etc.
pub reason: &'static str,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScannerModel {
pub weights: [f64; NUM_SCANNER_WEIGHTS],
pub threshold: f64,
pub norm_params: ScannerNormParams,
/// Suspicious path fragments used during training — kept for reproducibility.
pub fragments: Vec<String>,
}
impl ScannerModel {
pub fn save(&self, path: &Path) -> Result<()> {
let data = bincode::serialize(self).context("serializing scanner model")?;
std::fs::write(path, data)
.with_context(|| format!("writing scanner model to {}", path.display()))?;
Ok(())
}
pub fn load(path: &Path) -> Result<Self> {
let data = std::fs::read(path)
.with_context(|| format!("reading scanner model from {}", path.display()))?;
bincode::deserialize(&data).context("deserializing scanner model")
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::scanner::features::NUM_SCANNER_FEATURES;
#[test]
fn test_serialization_roundtrip() {
let model = ScannerModel {
weights: [0.1; NUM_SCANNER_WEIGHTS],
threshold: 0.5,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
fragments: vec![".env".into(), "wp-admin".into()],
};
let data = bincode::serialize(&model).unwrap();
let loaded: ScannerModel = bincode::deserialize(&data).unwrap();
assert_eq!(loaded.weights, model.weights);
assert_eq!(loaded.threshold, model.threshold);
assert_eq!(loaded.fragments, model.fragments);
}
#[test]
fn test_save_load_file() {
let dir = std::env::temp_dir().join("scanner_model_test");
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("test_model.bin");
let model = ScannerModel {
weights: [0.5; NUM_SCANNER_WEIGHTS],
threshold: 0.42,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
fragments: vec!["phpinfo".into()],
};
model.save(&path).unwrap();
let loaded = ScannerModel::load(&path).unwrap();
assert_eq!(loaded.threshold, 0.42);
assert_eq!(loaded.fragments, vec!["phpinfo"]);
let _ = std::fs::remove_dir_all(&dir);
}
}

View File

@@ -1,9 +1,30 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use crate::ddos::audit_log::{AuditLog, AuditFields};
use crate::scanner::features::{
self, fx_hash_bytes, ScannerFeatureVector, ScannerNormParams, NUM_SCANNER_FEATURES,
NUM_SCANNER_WEIGHTS,
};
use crate::scanner::model::ScannerModel;
use serde::{Deserialize, Serialize};
/// Legacy linear scanner model — kept for the `train-scanner` CLI command.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScannerModel {
pub weights: [f64; NUM_SCANNER_WEIGHTS],
pub threshold: f64,
pub norm_params: ScannerNormParams,
pub fragments: Vec<String>,
}
impl ScannerModel {
pub fn save(&self, path: &Path) -> Result<()> {
let data = bincode::serialize(self).context("serializing scanner model")?;
std::fs::write(path, data)
.with_context(|| format!("writing scanner model to {}", path.display()))?;
Ok(())
}
}
use anyhow::{Context, Result};
use rustc_hash::FxHashSet;
use std::io::BufRead;
@@ -88,17 +109,9 @@ pub fn train_and_evaluate(
}
for (fields, host_prefix) in &parsed_entries {
let has_cookies = fields.has_cookies.unwrap_or(false);
let has_referer = fields
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = fields
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let has_cookies = fields.has_cookies;
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
let feats = features::extract_features(
&fields.method,
@@ -149,17 +162,9 @@ pub fn train_and_evaluate(
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
}
for (fields, host_prefix) in &csic_entries {
let has_cookies = fields.has_cookies.unwrap_or(false);
let has_referer = fields
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = fields
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let has_cookies = fields.has_cookies;
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
let feats = features::extract_features(
&fields.method,
@@ -288,17 +293,9 @@ pub fn run(args: TrainScannerArgs) -> Result<()> {
}
for (fields, host_prefix) in &parsed_entries {
let has_cookies = fields.has_cookies.unwrap_or(false);
let has_referer = fields
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = fields
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let has_cookies = fields.has_cookies;
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
let feats = features::extract_features(
&fields.method,
@@ -352,17 +349,9 @@ pub fn run(args: TrainScannerArgs) -> Result<()> {
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
}
for (fields, host_prefix) in &csic_entries {
let has_cookies = fields.has_cookies.unwrap_or(false);
let has_referer = fields
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = fields
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let has_cookies = fields.has_cookies;
let has_referer = !fields.referer.is_empty() && fields.referer != "-";
let has_accept_language = !fields.accept_language.is_empty() && fields.accept_language != "-";
let feats = features::extract_features(
&fields.method,

View File

@@ -1,51 +0,0 @@
use crate::config::RouteConfig;
use crate::scanner::detector::ScannerDetector;
use crate::scanner::model::ScannerModel;
use arc_swap::ArcSwap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
/// Poll the scanner model file for mtime changes and hot-swap the detector.
/// Runs forever on a dedicated OS thread — never returns.
pub fn watch_scanner_model(
handle: Arc<ArcSwap<ScannerDetector>>,
model_path: PathBuf,
threshold: f64,
routes: Vec<RouteConfig>,
poll_interval: Duration,
) {
let mut last_mtime = std::fs::metadata(&model_path)
.and_then(|m| m.modified())
.ok();
loop {
std::thread::sleep(poll_interval);
let current_mtime = match std::fs::metadata(&model_path).and_then(|m| m.modified()) {
Ok(t) => t,
Err(_) => continue,
};
if Some(current_mtime) == last_mtime {
continue;
}
match ScannerModel::load(&model_path) {
Ok(mut model) => {
model.threshold = threshold;
let fragment_count = model.fragments.len();
let detector = ScannerDetector::new(&model, &routes);
handle.store(Arc::new(detector));
last_mtime = Some(current_mtime);
tracing::info!(
fragments = fragment_count,
"scanner model hot-reloaded"
);
}
Err(e) => {
tracing::warn!(error = %e, "failed to reload scanner model; keeping current");
}
}
}
}

112
src/training/batch.rs Normal file
View File

@@ -0,0 +1,112 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Shared training infrastructure: Dataset adapter, Batcher, and batch types
//! for use with burn's SupervisedTraining.
use crate::dataset::sample::TrainingSample;
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::Dataset;
use burn::prelude::*;
/// A single normalized training item ready for batching.
#[derive(Clone, Debug)]
pub struct TrainingItem {
pub features: Vec<f32>,
pub label: i32,
pub weight: f32,
}
/// A batch of training items as tensors.
#[derive(Clone, Debug)]
pub struct TrainingBatch<B: Backend> {
pub features: Tensor<B, 2>,
pub labels: Tensor<B, 1, Int>,
pub weights: Tensor<B, 2>,
}
/// Wraps a `Vec<TrainingSample>` as a burn `Dataset`, applying min-max
/// normalization to features at construction time.
#[derive(Clone)]
pub struct SampleDataset {
items: Vec<TrainingItem>,
}
impl SampleDataset {
pub fn new(samples: &[TrainingSample], mins: &[f32], maxs: &[f32]) -> Self {
let items = samples
.iter()
.map(|s| {
let features: Vec<f32> = s
.features
.iter()
.enumerate()
.map(|(i, &v)| {
let range = maxs[i] - mins[i];
if range > f32::EPSILON {
((v - mins[i]) / range).clamp(0.0, 1.0)
} else {
0.0
}
})
.collect();
TrainingItem {
features,
label: if s.label >= 0.5 { 1 } else { 0 },
weight: s.weight,
}
})
.collect();
Self { items }
}
}
impl Dataset<TrainingItem> for SampleDataset {
fn get(&self, index: usize) -> Option<TrainingItem> {
self.items.get(index).cloned()
}
fn len(&self) -> usize {
self.items.len()
}
}
/// Converts a `Vec<TrainingItem>` into a `TrainingBatch` of tensors.
#[derive(Clone)]
pub struct SampleBatcher;
impl SampleBatcher {
pub fn new() -> Self {
Self
}
}
impl<B: Backend> Batcher<B, TrainingItem, TrainingBatch<B>> for SampleBatcher {
fn batch(&self, items: Vec<TrainingItem>, device: &B::Device) -> TrainingBatch<B> {
let batch_size = items.len();
let num_features = items[0].features.len();
let flat_features: Vec<f32> = items
.iter()
.flat_map(|item| item.features.iter().copied())
.collect();
let labels: Vec<i32> = items.iter().map(|item| item.label).collect();
let weights: Vec<f32> = items.iter().map(|item| item.weight).collect();
let features = Tensor::<B, 1>::from_floats(flat_features.as_slice(), device)
.reshape([batch_size, num_features]);
let labels = Tensor::<B, 1, Int>::from_ints(labels.as_slice(), device);
let weights = Tensor::<B, 1>::from_floats(weights.as_slice(), device)
.reshape([batch_size, 1]);
TrainingBatch {
features,
labels,
weights,
}
}
}

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Weight export: converts trained models into standalone Rust `const` arrays
//! and optionally Lean 4 definitions.
//!
@@ -54,7 +57,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
writeln!(s).unwrap();
// Threshold.
writeln!(s, "pub const THRESHOLD: f32 = {:.8};", model.threshold).unwrap();
writeln!(s, "pub const THRESHOLD: f32 = {:.8};", sanitize(model.threshold)).unwrap();
writeln!(s).unwrap();
// Normalization params.
@@ -74,7 +77,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
if i > 0 {
write!(s, ", ").unwrap();
}
write!(s, "{:.8}", v).unwrap();
write!(s, "{:.8}", sanitize(*v)).unwrap();
}
writeln!(s, "],").unwrap();
}
@@ -88,7 +91,7 @@ pub fn generate_rust_source(model: &ExportedModel) -> String {
write_f32_array(&mut s, "W2", &model.w2);
// B2.
writeln!(s, "pub const B2: f32 = {:.8};", model.b2).unwrap();
writeln!(s, "pub const B2: f32 = {:.8};", sanitize(model.b2)).unwrap();
writeln!(s).unwrap();
// Tree nodes.
@@ -207,6 +210,11 @@ pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> {
// Helpers
// ---------------------------------------------------------------------------
/// Sanitize a float for Rust source: replace NaN/Inf with 0.0.
fn sanitize(v: f32) -> f32 {
if v.is_finite() { v } else { 0.0 }
}
fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap();
write!(s, " ").unwrap();
@@ -218,7 +226,7 @@ fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
if i > 0 && i % 8 == 0 {
write!(s, "\n ").unwrap();
}
write!(s, "{:.8}", v).unwrap();
write!(s, "{:.8}", sanitize(*v)).unwrap();
}
writeln!(s, "\n];").unwrap();
writeln!(s).unwrap();

View File

@@ -1,11 +1,18 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! burn-rs MLP model definition for ensemble training.
//!
//! A two-layer network (linear -> ReLU -> linear -> sigmoid) used as the
//! "uncertain region" classifier in the tree+MLP ensemble.
use crate::training::batch::TrainingBatch;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep};
/// Two-layer MLP: input -> hidden (ReLU) -> output (sigmoid).
#[derive(Module, Debug)]
@@ -34,24 +41,79 @@ impl MlpConfig {
}
impl<B: Backend> MlpModel<B> {
/// Forward pass: ReLU hidden activation, sigmoid output.
/// Forward pass returning raw logits (pre-sigmoid).
///
/// Input shape: `[batch, input_dim]`
/// Output shape: `[batch, 1]`
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
pub fn forward_logits(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
let h = self.linear1.forward(x);
let h = burn::tensor::activation::relu(h);
let out = self.linear2.forward(h);
burn::tensor::activation::sigmoid(out)
self.linear2.forward(h)
}
/// Forward pass with sigmoid activation for inference/export.
///
/// Input shape: `[batch, input_dim]`
/// Output shape: `[batch, 1]` (values in [0, 1])
pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
burn::tensor::activation::sigmoid(self.forward_logits(x))
}
/// Forward pass returning a `ClassificationOutput` for burn's training loop.
///
/// Uses raw logits for BCE (which applies sigmoid internally) and converts
/// to two-column format `[1-p, p]` for AccuracyMetric (which uses argmax).
pub fn forward_classification(
&self,
batch: TrainingBatch<B>,
) -> ClassificationOutput<B> {
let logits = self.forward_logits(batch.features); // [batch, 1]
let logits_1d = logits.clone().squeeze::<1>(); // [batch]
// Numerically stable BCE from logits:
// loss = max(logits, 0) - logits * targets + log(1 + exp(-|logits|))
// This avoids log(0) and exp(large) overflow.
let targets_float = batch.labels.clone().float(); // [batch]
let zeros = Tensor::zeros_like(&logits_1d);
let relu_logits = logits_1d.clone().max_pair(zeros); // max(logits, 0)
let neg_abs = logits_1d.clone().abs().neg(); // -|logits|
let log_term = neg_abs.exp().log1p(); // log(1 + exp(-|logits|))
let per_sample = relu_logits - logits_1d.clone() * targets_float + log_term;
let loss = per_sample.mean(); // scalar [1]
// AccuracyMetric expects [batch, num_classes] and uses argmax.
let neg_logits = logits.clone().neg();
let output_2col = Tensor::cat(vec![neg_logits, logits], 1); // [batch, 2]
ClassificationOutput::new(loss, output_2col, batch.labels)
}
}
impl<B: AutodiffBackend> TrainStep for MlpModel<B> {
type Input = TrainingBatch<B>;
type Output = ClassificationOutput<B>;
fn step(&self, batch: Self::Input) -> TrainOutput<Self::Output> {
let item = self.forward_classification(batch);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> InferenceStep for MlpModel<B> {
type Input = TrainingBatch<B>;
type Output = ClassificationOutput<B>;
fn step(&self, batch: Self::Input) -> Self::Output {
self.forward_classification(batch)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
use burn::backend::Wgpu;
type TestBackend = NdArray<f32>;
type TestBackend = Wgpu<f32, i32>;
#[test]
fn test_forward_pass_shape() {
@@ -80,7 +142,6 @@ mod tests {
};
let model = config.init::<TestBackend>(&device);
// Random-ish input values.
let input = Tensor::<TestBackend, 2>::from_data(
[[1.0, -2.0, 0.5, 3.0], [0.0, 0.0, 0.0, 0.0]],
&device,

View File

@@ -1,5 +1,10 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
pub mod tree;
pub mod mlp;
pub mod batch;
pub mod export;
pub mod train_scanner;
pub mod train_ddos;
pub mod sweep;

103
src/training/sweep.rs Normal file
View File

@@ -0,0 +1,103 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Cookie weight sweep: trains full tree+MLP ensembles (GPU via wgpu) across a
//! range of cookie_weight values and reports accuracy metrics for each.
use anyhow::{Context, Result};
use std::path::Path;
use crate::dataset::sample::load_dataset;
use crate::training::train_scanner::TrainScannerMlpArgs;
use crate::training::train_ddos::TrainDdosMlpArgs;
/// Run a sweep across cookie_weight values for either scanner or ddos.
///
/// Each trial does a full GPU training run (tree + MLP) with the specified
/// cookie_weight, writing artifacts to a temp directory.
pub fn run_cookie_sweep(
dataset_path: &str,
detector: &str,
weights_csv: Option<&str>,
tree_max_depth: usize,
tree_min_purity: f32,
min_samples_leaf: usize,
) -> Result<()> {
// Validate dataset exists and has samples.
let manifest = load_dataset(Path::new(dataset_path))
.context("loading dataset manifest")?;
let (cookie_idx, sample_count) = match detector {
"scanner" => (3usize, manifest.scanner_samples.len()),
"ddos" => (10usize, manifest.ddos_samples.len()),
other => anyhow::bail!("unknown detector '{}', expected 'scanner' or 'ddos'", other),
};
anyhow::ensure!(sample_count > 0, "no {} samples in dataset", detector);
drop(manifest); // Free memory before training loop.
let weights: Vec<f32> = if let Some(csv) = weights_csv {
csv.split(',')
.map(|s| s.trim().parse::<f32>())
.collect::<std::result::Result<Vec<_>, _>>()
.context("parsing --weights as comma-separated floats")?
} else {
(0..=10).map(|i| i as f32 / 10.0).collect()
};
println!(
"[sweep] {} detector, {} samples, cookie feature index: {}",
detector, sample_count, cookie_idx,
);
println!("[sweep] training {} trials with full tree+MLP (wgpu)\n", weights.len());
let sweep_dir = tempfile::tempdir().context("creating temp dir for sweep")?;
for (trial, &cw) in weights.iter().enumerate() {
let trial_dir = sweep_dir.path().join(format!("trial_{}", trial));
std::fs::create_dir_all(&trial_dir)?;
let trial_dir_str = trial_dir.to_string_lossy().to_string();
println!("━━━ Trial {}/{}: cookie_weight={:.2} ━━━", trial + 1, weights.len(), cw);
match detector {
"scanner" => {
crate::training::train_scanner::run(TrainScannerMlpArgs {
dataset_path: dataset_path.to_string(),
output_dir: trial_dir_str,
hidden_dim: 32,
epochs: 100,
learning_rate: 0.0001,
batch_size: 64,
tree_max_depth,
tree_min_purity,
min_samples_leaf,
cookie_weight: cw,
})?;
}
"ddos" => {
crate::training::train_ddos::run(TrainDdosMlpArgs {
dataset_path: dataset_path.to_string(),
output_dir: trial_dir_str,
hidden_dim: 32,
epochs: 100,
learning_rate: 0.0001,
batch_size: 64,
tree_max_depth,
tree_min_purity,
min_samples_leaf,
cookie_weight: cw,
})?;
}
_ => unreachable!(),
}
println!();
}
println!("[sweep] All {} trials complete.", weights.len());
println!("[sweep] Tip: compare tree structures and validation accuracy above.");
println!("[sweep] Look for a cookie_weight where FP rate drops without FN rate spiking.");
Ok(())
}

View File

@@ -1,19 +1,27 @@
//! DDoS MLP+tree training loop.
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! DDoS MLP+tree training loop using burn's SupervisedTraining.
//!
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
//! then exports the combined ensemble weights as a Rust source file that can
//! be dropped into `src/ensemble/gen/ddos_weights.rs`.
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP
//! with cosine annealing + early stopping, then exports the combined ensemble
//! weights as a Rust source file for `src/ensemble/gen/ddos_weights.rs`.
use anyhow::{Context, Result};
use std::path::Path;
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::module::AutodiffModule;
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::backend::Wgpu;
use burn::data::dataloader::DataLoaderBuilder;
use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
use burn::optim::AdamConfig;
use burn::prelude::*;
use burn::record::CompactRecorder;
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::train::{Learner, SupervisedTraining};
use crate::dataset::sample::{load_dataset, TrainingSample};
use crate::training::batch::{SampleBatcher, SampleDataset};
use crate::training::export::{export_to_file, ExportedModel};
use crate::training::mlp::MlpConfig;
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
@@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
/// Number of DDoS features (matches `crate::ddos::features::NUM_FEATURES`).
const NUM_FEATURES: usize = 14;
type TrainBackend = Autodiff<NdArray<f32>>;
type TrainBackend = Autodiff<Wgpu<f32, i32>>;
/// Arguments for the DDoS MLP training command.
pub struct TrainDdosMlpArgs {
@@ -37,10 +45,14 @@ pub struct TrainDdosMlpArgs {
pub learning_rate: f64,
/// Mini-batch size (default 64).
pub batch_size: usize,
/// CART max depth (default 6).
/// CART max depth (default 8).
pub tree_max_depth: usize,
/// CART leaf purity threshold (default 0.90).
/// CART leaf purity threshold (default 0.98).
pub tree_min_purity: f32,
/// Minimum samples in a leaf node (default 2).
pub min_samples_leaf: usize,
/// Weight for cookie feature (feature 10: cookie_ratio). 0.0 = ignore, 1.0 = full weight.
pub cookie_weight: f32,
}
impl Default for TrainDdosMlpArgs {
@@ -50,14 +62,19 @@ impl Default for TrainDdosMlpArgs {
output_dir: ".".into(),
hidden_dim: 32,
epochs: 100,
learning_rate: 0.001,
learning_rate: 0.0001,
batch_size: 64,
tree_max_depth: 6,
tree_min_purity: 0.90,
tree_max_depth: 8,
tree_min_purity: 0.98,
min_samples_leaf: 2,
cookie_weight: 1.0,
}
}
}
/// Index of the cookie_ratio feature in the DDoS feature vector.
const COOKIE_FEATURE_IDX: usize = 10;
/// Entry point: train DDoS ensemble and export weights.
pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
// 1. Load dataset.
@@ -86,6 +103,23 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
// 2. Compute normalization params from training data.
let (norm_mins, norm_maxs) = compute_norm_params(samples);
if args.cookie_weight < 1.0 - f32::EPSILON {
println!(
"[ddos] cookie_weight={:.2} (feature {} influence reduced)",
args.cookie_weight, COOKIE_FEATURE_IDX,
);
}
// MLP norm adjustment: scale cookie feature's normalization range.
let mut mlp_norm_maxs = norm_maxs.clone();
if args.cookie_weight < 1.0 - f32::EPSILON {
let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX];
if range > f32::EPSILON && args.cookie_weight > f32::EPSILON {
mlp_norm_maxs[COOKIE_FEATURE_IDX] =
range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX];
}
}
// 3. Stratified 80/20 split.
let (train_set, val_set) = stratified_split(samples, 0.8);
println!(
@@ -94,15 +128,16 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
val_set.len()
);
// 4. Train CART tree.
// 4. Train CART tree (with cookie feature masking for reduced weight).
let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight);
let tree_config = TreeConfig {
max_depth: args.tree_max_depth,
min_samples_leaf: 5,
min_samples_leaf: args.min_samples_leaf,
min_purity: args.tree_min_purity,
num_features: NUM_FEATURES,
};
let tree_nodes = train_tree(&train_set, &tree_config);
println!("[ddos] CART tree: {} nodes", tree_nodes.len());
let tree_nodes = train_tree(&tree_train_set, &tree_config);
println!("[ddos] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth);
// Evaluate tree on validation set.
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
@@ -112,23 +147,27 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
tree_deferred * 100.0,
);
// 5. Train MLP on the full training set.
// 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling).
let device = Default::default();
let mlp_config = MlpConfig {
input_dim: NUM_FEATURES,
hidden_dim: args.hidden_dim,
};
let artifact_dir = Path::new(&args.output_dir).join("ddos_artifacts");
std::fs::create_dir_all(&artifact_dir).ok();
let model = train_mlp(
&train_set,
&val_set,
&mlp_config,
&norm_mins,
&norm_maxs,
&mlp_norm_maxs,
args.epochs,
args.learning_rate,
args.batch_size,
&device,
&artifact_dir,
);
// 6. Extract weights from trained model.
@@ -136,9 +175,9 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
&model,
"ddos",
&tree_nodes,
0.5, // threshold
0.5,
&norm_mins,
&norm_maxs,
&mlp_norm_maxs,
&device,
);
@@ -153,6 +192,37 @@ pub fn run(args: TrainDdosMlpArgs) -> Result<()> {
Ok(())
}
// ---------------------------------------------------------------------------
// Cookie feature masking for CART trees
// ---------------------------------------------------------------------------
fn mask_cookie_feature(
samples: &[TrainingSample],
cookie_idx: usize,
cookie_weight: f32,
) -> Vec<TrainingSample> {
if cookie_weight >= 1.0 - f32::EPSILON {
return samples.to_vec();
}
samples
.iter()
.enumerate()
.map(|(i, s)| {
let mut s2 = s.clone();
if cookie_weight < f32::EPSILON {
s2.features[cookie_idx] = 0.5;
} else {
let hash = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42);
let r = (hash >> 33) as f32 / (u32::MAX >> 1) as f32;
if r > cookie_weight {
s2.features[cookie_idx] = 0.5;
}
}
s2
})
.collect()
}
// ---------------------------------------------------------------------------
// Normalization
// ---------------------------------------------------------------------------
@@ -170,21 +240,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
(mins, maxs)
}
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
features
.iter()
.enumerate()
.map(|(i, &v)| {
let range = maxs[i] - mins[i];
if range > f32::EPSILON {
((v - mins[i]) / range).clamp(0.0, 1.0)
} else {
0.0
}
})
.collect()
}
// ---------------------------------------------------------------------------
// Stratified split
// ---------------------------------------------------------------------------
@@ -272,8 +327,23 @@ fn eval_tree(
(accuracy, defer_rate)
}
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
features
.iter()
.enumerate()
.map(|(i, &v)| {
let range = maxs[i] - mins[i];
if range > f32::EPSILON {
((v - mins[i]) / range).clamp(0.0, 1.0)
} else {
0.0
}
})
.collect()
}
// ---------------------------------------------------------------------------
// MLP training
// MLP training via SupervisedTraining
// ---------------------------------------------------------------------------
fn train_mlp(
@@ -286,117 +356,47 @@ fn train_mlp(
learning_rate: f64,
batch_size: usize,
device: &<TrainBackend as Backend>::Device,
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
let mut model = config.init::<TrainBackend>(device);
let mut optim = AdamConfig::new().init();
artifact_dir: &Path,
) -> crate::training::mlp::MlpModel<Wgpu<f32, i32>> {
let model = config.init::<TrainBackend>(device);
// Pre-normalize all training data.
let train_features: Vec<Vec<f32>> = train_set
.iter()
.map(|s| normalize_features(&s.features, mins, maxs))
.collect();
let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();
let train_dataset = SampleDataset::new(train_set, mins, maxs);
let val_dataset = SampleDataset::new(val_set, mins, maxs);
let n = train_features.len();
let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new())
.batch_size(batch_size)
.shuffle(42)
.num_workers(1)
.build(train_dataset);
for epoch in 0..epochs {
let mut epoch_loss = 0.0f32;
let mut batches = 0usize;
let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new())
.batch_size(batch_size)
.num_workers(1)
.build(val_dataset);
let mut offset = 0;
while offset < n {
let end = (offset + batch_size).min(n);
let batch_n = end - offset;
// Cosine annealing: initial_lr must be in (0.0, 1.0].
let lr = learning_rate.min(1.0);
let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs)
.init()
.expect("valid cosine annealing config");
// Build input tensor [batch, features].
let flat: Vec<f32> = train_features[offset..end]
.iter()
.flat_map(|f| f.iter().copied())
.collect();
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
.reshape([batch_n, NUM_FEATURES]);
// Labels [batch, 1].
let y = Tensor::<TrainBackend, 1>::from_floats(
&train_labels[offset..end],
device,
)
.reshape([batch_n, 1]);
// Sample weights [batch, 1].
let w = Tensor::<TrainBackend, 1>::from_floats(
&train_weights[offset..end],
device,
)
.reshape([batch_n, 1]);
// Forward pass.
let pred = model.forward(x);
// Binary cross-entropy with sample weights.
let eps = 1e-7;
let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
let bce = (y.clone() * pred_clamped.clone().log()
+ (y.clone().neg().add_scalar(1.0))
* pred_clamped.neg().add_scalar(1.0).log())
.neg();
let weighted_bce = bce * w;
let loss = weighted_bce.mean();
epoch_loss += loss.clone().into_scalar().elem::<f32>();
batches += 1;
// Backward + optimizer step.
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(learning_rate, model, grads);
offset = end;
}
if (epoch + 1) % 10 == 0 || epoch == 0 {
let avg_loss = epoch_loss / batches as f32;
let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
println!(
"[ddos] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
epoch + 1,
epochs,
avg_loss,
val_acc,
let learner = Learner::new(
model,
AdamConfig::new().init(),
lr_scheduler,
);
}
}
model.valid()
}
let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid)
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.num_epochs(epochs)
.summary()
.launch(learner);
fn eval_mlp_accuracy(
model: &crate::training::mlp::MlpModel<TrainBackend>,
val_set: &[TrainingSample],
mins: &[f32],
maxs: &[f32],
device: &<TrainBackend as Backend>::Device,
) -> f64 {
let flat: Vec<f32> = val_set
.iter()
.flat_map(|s| normalize_features(&s.features, mins, maxs))
.collect();
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
.reshape([val_set.len(), NUM_FEATURES]);
let pred = model.forward(x);
let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");
let mut correct = 0usize;
for (i, s) in val_set.iter().enumerate() {
let p = pred_data[i];
let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
if (predicted_label - s.label).abs() < 0.1 {
correct += 1;
}
}
correct as f64 / val_set.len() as f64
result.model
}
// ---------------------------------------------------------------------------
@@ -404,13 +404,13 @@ fn eval_mlp_accuracy(
// ---------------------------------------------------------------------------
fn extract_weights(
model: &crate::training::mlp::MlpModel<NdArray<f32>>,
model: &crate::training::mlp::MlpModel<Wgpu<f32, i32>>,
name: &str,
tree_nodes: &[(u8, f32, u16, u16)],
threshold: f32,
norm_mins: &[f32],
norm_maxs: &[f32],
_device: &<NdArray<f32> as Backend>::Device,
_device: &<Wgpu<f32, i32> as Backend>::Device,
) -> ExportedModel {
let w1_tensor = model.linear1.weight.val();
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();

View File

@@ -1,19 +1,27 @@
//! Scanner MLP+tree training loop.
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Scanner MLP+tree training loop using burn's SupervisedTraining.
//!
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP,
//! then exports the combined ensemble weights as a Rust source file that can
//! be dropped into `src/ensemble/gen/scanner_weights.rs`.
//! Loads a `DatasetManifest`, trains a CART decision tree and a burn-rs MLP
//! with cosine annealing + early stopping, then exports the combined ensemble
//! weights as a Rust source file for `src/ensemble/gen/scanner_weights.rs`.
use anyhow::{Context, Result};
use std::path::Path;
use burn::backend::ndarray::NdArray;
use burn::backend::Autodiff;
use burn::module::AutodiffModule;
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::backend::Wgpu;
use burn::data::dataloader::DataLoaderBuilder;
use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
use burn::optim::AdamConfig;
use burn::prelude::*;
use burn::record::CompactRecorder;
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::train::{Learner, SupervisedTraining};
use crate::dataset::sample::{load_dataset, TrainingSample};
use crate::training::batch::{SampleBatcher, SampleDataset};
use crate::training::export::{export_to_file, ExportedModel};
use crate::training::mlp::MlpConfig;
use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
@@ -21,7 +29,7 @@ use crate::training::tree::{train_tree, tree_predict, TreeConfig, TreeDecision};
/// Number of scanner features (matches `crate::scanner::features::NUM_SCANNER_FEATURES`).
const NUM_FEATURES: usize = 12;
type TrainBackend = Autodiff<NdArray<f32>>;
type TrainBackend = Autodiff<Wgpu<f32, i32>>;
/// Arguments for the scanner MLP training command.
pub struct TrainScannerMlpArgs {
@@ -37,10 +45,14 @@ pub struct TrainScannerMlpArgs {
pub learning_rate: f64,
/// Mini-batch size (default 64).
pub batch_size: usize,
/// CART max depth (default 6).
/// CART max depth (default 8).
pub tree_max_depth: usize,
/// CART leaf purity threshold (default 0.90).
/// CART leaf purity threshold (default 0.98).
pub tree_min_purity: f32,
/// Minimum samples in a leaf node (default 2).
pub min_samples_leaf: usize,
/// Weight for cookie feature (feature 3: has_cookies). 0.0 = ignore, 1.0 = full weight.
pub cookie_weight: f32,
}
impl Default for TrainScannerMlpArgs {
@@ -50,14 +62,19 @@ impl Default for TrainScannerMlpArgs {
output_dir: ".".into(),
hidden_dim: 32,
epochs: 100,
learning_rate: 0.001,
learning_rate: 0.0001,
batch_size: 64,
tree_max_depth: 6,
tree_min_purity: 0.90,
tree_max_depth: 8,
tree_min_purity: 0.98,
min_samples_leaf: 2,
cookie_weight: 1.0,
}
}
}
/// Index of the has_cookies feature in the scanner feature vector.
const COOKIE_FEATURE_IDX: usize = 3;
/// Entry point: train scanner ensemble and export weights.
pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
// 1. Load dataset.
@@ -86,6 +103,27 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
// 2. Compute normalization params from training data.
let (norm_mins, norm_maxs) = compute_norm_params(samples);
// Apply cookie_weight: for the MLP, we scale the normalization range so
// the feature contributes less gradient signal. For the CART tree, scaling
// doesn't help (the tree just adjusts its threshold), so we mask the feature
// to a constant on a fraction of training samples to degrade its Gini gain.
if args.cookie_weight < 1.0 - f32::EPSILON {
println!(
"[scanner] cookie_weight={:.2} (feature {} influence reduced)",
args.cookie_weight, COOKIE_FEATURE_IDX,
);
}
// MLP norm adjustment: scale the cookie feature's normalization range.
let mut mlp_norm_maxs = norm_maxs.clone();
if args.cookie_weight < 1.0 - f32::EPSILON {
let range = mlp_norm_maxs[COOKIE_FEATURE_IDX] - norm_mins[COOKIE_FEATURE_IDX];
if range > f32::EPSILON && args.cookie_weight > f32::EPSILON {
mlp_norm_maxs[COOKIE_FEATURE_IDX] =
range / args.cookie_weight + norm_mins[COOKIE_FEATURE_IDX];
}
}
// 3. Stratified 80/20 split.
let (train_set, val_set) = stratified_split(samples, 0.8);
println!(
@@ -94,17 +132,18 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
val_set.len()
);
// 4. Train CART tree.
// 4. Train CART tree (with cookie feature masking for reduced weight).
let tree_train_set = mask_cookie_feature(&train_set, COOKIE_FEATURE_IDX, args.cookie_weight);
let tree_config = TreeConfig {
max_depth: args.tree_max_depth,
min_samples_leaf: 5,
min_samples_leaf: args.min_samples_leaf,
min_purity: args.tree_min_purity,
num_features: NUM_FEATURES,
};
let tree_nodes = train_tree(&train_set, &tree_config);
println!("[scanner] CART tree: {} nodes", tree_nodes.len());
let tree_nodes = train_tree(&tree_train_set, &tree_config);
println!("[scanner] CART tree: {} nodes (max_depth={})", tree_nodes.len(), args.tree_max_depth);
// Evaluate tree on validation set.
// Evaluate tree on validation set (use original norms — tree learned on masked features).
let (tree_correct, tree_deferred) = eval_tree(&tree_nodes, &val_set, &norm_mins, &norm_maxs);
println!(
"[scanner] tree validation: {:.2}% correct (of decided), {:.1}% deferred",
@@ -112,35 +151,38 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
tree_deferred * 100.0,
);
// 5. Train MLP on the full training set (the MLP only fires on Defer
// at inference time, but we train it on all data so it learns the
// full decision boundary).
// 5. Train MLP with SupervisedTraining (uses mlp_norm_maxs for cookie scaling).
let device = Default::default();
let mlp_config = MlpConfig {
input_dim: NUM_FEATURES,
hidden_dim: args.hidden_dim,
};
let artifact_dir = Path::new(&args.output_dir).join("scanner_artifacts");
std::fs::create_dir_all(&artifact_dir).ok();
let model = train_mlp(
&train_set,
&val_set,
&mlp_config,
&norm_mins,
&norm_maxs,
&mlp_norm_maxs,
args.epochs,
args.learning_rate,
args.batch_size,
&device,
&artifact_dir,
);
// 6. Extract weights from trained model.
// 6. Extract weights from trained model (export mlp_norm_maxs so inference
// automatically applies the same cookie scaling).
let exported = extract_weights(
&model,
"scanner",
&tree_nodes,
0.5, // threshold
0.5,
&norm_mins,
&norm_maxs,
&mlp_norm_maxs,
&device,
);
@@ -155,6 +197,46 @@ pub fn run(args: TrainScannerMlpArgs) -> Result<()> {
Ok(())
}
// ---------------------------------------------------------------------------
// Cookie feature masking for CART trees
// ---------------------------------------------------------------------------
/// Degrade the cookie feature's usefulness to the CART tree.
///
/// A binary feature keeps its full Gini gain under any linear scaling (the
/// tree simply moves its split threshold), so instead of scaling we overwrite
/// the feature with the uninformative constant 0.5 on a deterministic subset
/// of samples:
///
/// - `cookie_weight = 0.0` → every sample masked (feature carries no signal)
/// - `cookie_weight = 0.5` → roughly half the samples masked (noisy signal)
/// - `cookie_weight = 1.0` → nothing masked (original data returned verbatim)
fn mask_cookie_feature(
    samples: &[TrainingSample],
    cookie_idx: usize,
    cookie_weight: f32,
) -> Vec<TrainingSample> {
    // Full weight: no masking required, hand back a plain copy.
    if cookie_weight >= 1.0 - f32::EPSILON {
        return samples.to_vec();
    }
    let mask_everything = cookie_weight < f32::EPSILON;
    let mut masked_samples = Vec::with_capacity(samples.len());
    for (i, sample) in samples.iter().enumerate() {
        let mut masked = sample.clone();
        if mask_everything {
            masked.features[cookie_idx] = 0.5;
        } else {
            // Deterministic per-index pseudo-random draw in [0, 1] (LCG-style
            // multiply/add then take the high bits), so repeated training runs
            // always mask the same subset of samples.
            let mixed = (i as u64).wrapping_mul(6364136223846793005).wrapping_add(42);
            let draw = (mixed >> 33) as f32 / (u32::MAX >> 1) as f32;
            if draw > cookie_weight {
                masked.features[cookie_idx] = 0.5;
            }
        }
        masked_samples.push(masked);
    }
    masked_samples
}
// ---------------------------------------------------------------------------
// Normalization
// ---------------------------------------------------------------------------
@@ -172,21 +254,6 @@ fn compute_norm_params(samples: &[TrainingSample]) -> (Vec<f32>, Vec<f32>) {
(mins, maxs)
}
/// Min-max scale each feature into the unit interval.
///
/// A feature whose observed range is (near-)zero normalizes to 0.0; all other
/// values are mapped to `(v - min) / (max - min)` and clamped to [0, 1].
/// Panics if `mins`/`maxs` are shorter than `features` (indexing).
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(features.len());
    for (i, &raw) in features.iter().enumerate() {
        let span = maxs[i] - mins[i];
        let value = if span > f32::EPSILON {
            ((raw - mins[i]) / span).clamp(0.0, 1.0)
        } else {
            0.0
        };
        scaled.push(value);
    }
    scaled
}
// ---------------------------------------------------------------------------
// Stratified split
// ---------------------------------------------------------------------------
@@ -195,7 +262,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<Traini
let mut attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label >= 0.5).collect();
let mut normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label < 0.5).collect();
// Deterministic shuffle using a simple index permutation seeded by length.
deterministic_shuffle(&mut attacks);
deterministic_shuffle(&mut normals);
@@ -224,7 +290,6 @@ fn stratified_split(samples: &[TrainingSample], train_ratio: f64) -> (Vec<Traini
}
fn deterministic_shuffle<T>(items: &mut [T]) {
// Simple Fisher-Yates with a fixed LCG seed for reproducibility.
let mut rng = 42u64;
for i in (1..items.len()).rev() {
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
@@ -276,8 +341,23 @@ fn eval_tree(
(accuracy, defer_rate)
}
/// Scale `features` into [0, 1] using per-feature min/max bounds.
///
/// Each component becomes `(v - min) / (max - min)`, clamped to the unit
/// interval; a degenerate (near-zero) range collapses to 0.0. Out-of-bounds
/// indexing panics when `mins`/`maxs` are shorter than `features`.
fn normalize_features(features: &[f32], mins: &[f32], maxs: &[f32]) -> Vec<f32> {
    (0..features.len())
        .map(|i| {
            let width = maxs[i] - mins[i];
            if width > f32::EPSILON {
                ((features[i] - mins[i]) / width).clamp(0.0, 1.0)
            } else {
                0.0
            }
        })
        .collect()
}
// ---------------------------------------------------------------------------
// MLP training
// MLP training via SupervisedTraining
// ---------------------------------------------------------------------------
fn train_mlp(
@@ -290,119 +370,47 @@ fn train_mlp(
learning_rate: f64,
batch_size: usize,
device: &<TrainBackend as Backend>::Device,
) -> crate::training::mlp::MlpModel<NdArray<f32>> {
let mut model = config.init::<TrainBackend>(device);
let mut optim = AdamConfig::new().init();
artifact_dir: &Path,
) -> crate::training::mlp::MlpModel<Wgpu<f32, i32>> {
let model = config.init::<TrainBackend>(device);
// Pre-normalize all training data.
let train_features: Vec<Vec<f32>> = train_set
.iter()
.map(|s| normalize_features(&s.features, mins, maxs))
.collect();
let train_labels: Vec<f32> = train_set.iter().map(|s| s.label).collect();
let train_weights: Vec<f32> = train_set.iter().map(|s| s.weight).collect();
let train_dataset = SampleDataset::new(train_set, mins, maxs);
let val_dataset = SampleDataset::new(val_set, mins, maxs);
let n = train_features.len();
let dataloader_train = DataLoaderBuilder::new(SampleBatcher::new())
.batch_size(batch_size)
.shuffle(42)
.num_workers(1)
.build(train_dataset);
for epoch in 0..epochs {
let mut epoch_loss = 0.0f32;
let mut batches = 0usize;
let dataloader_valid = DataLoaderBuilder::new(SampleBatcher::new())
.batch_size(batch_size)
.num_workers(1)
.build(val_dataset);
let mut offset = 0;
while offset < n {
let end = (offset + batch_size).min(n);
let batch_n = end - offset;
// Cosine annealing: initial_lr must be in (0.0, 1.0].
let lr = learning_rate.min(1.0);
let lr_scheduler = CosineAnnealingLrSchedulerConfig::new(lr, epochs)
.init()
.expect("valid cosine annealing config");
// Build input tensor [batch, features].
let flat: Vec<f32> = train_features[offset..end]
.iter()
.flat_map(|f| f.iter().copied())
.collect();
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
.reshape([batch_n, NUM_FEATURES]);
// Labels [batch, 1].
let y = Tensor::<TrainBackend, 1>::from_floats(
&train_labels[offset..end],
device,
)
.reshape([batch_n, 1]);
// Sample weights [batch, 1].
let w = Tensor::<TrainBackend, 1>::from_floats(
&train_weights[offset..end],
device,
)
.reshape([batch_n, 1]);
// Forward pass.
let pred = model.forward(x);
// Binary cross-entropy with sample weights:
// loss = -w * [y * log(p) + (1-y) * log(1-p)]
let eps = 1e-7;
let pred_clamped = pred.clone().clamp(eps, 1.0 - eps);
let bce = (y.clone() * pred_clamped.clone().log()
+ (y.clone().neg().add_scalar(1.0))
* pred_clamped.neg().add_scalar(1.0).log())
.neg();
let weighted_bce = bce * w;
let loss = weighted_bce.mean();
epoch_loss += loss.clone().into_scalar().elem::<f32>();
batches += 1;
// Backward + optimizer step.
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(learning_rate, model, grads);
offset = end;
}
if (epoch + 1) % 10 == 0 || epoch == 0 {
let avg_loss = epoch_loss / batches as f32;
let val_acc = eval_mlp_accuracy(&model, val_set, mins, maxs, device);
println!(
"[scanner] epoch {:>4}/{}: loss={:.6}, val_acc={:.4}",
epoch + 1,
epochs,
avg_loss,
val_acc,
let learner = Learner::new(
model,
AdamConfig::new().init(),
lr_scheduler,
);
}
}
// Return the inner (non-autodiff) model for weight extraction.
model.valid()
}
let result = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_valid)
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.num_epochs(epochs)
.summary()
.launch(learner);
fn eval_mlp_accuracy(
model: &crate::training::mlp::MlpModel<TrainBackend>,
val_set: &[TrainingSample],
mins: &[f32],
maxs: &[f32],
device: &<TrainBackend as Backend>::Device,
) -> f64 {
let flat: Vec<f32> = val_set
.iter()
.flat_map(|s| normalize_features(&s.features, mins, maxs))
.collect();
let x = Tensor::<TrainBackend, 1>::from_floats(flat.as_slice(), device)
.reshape([val_set.len(), NUM_FEATURES]);
let pred = model.forward(x);
let pred_data: Vec<f32> = pred.to_data().to_vec().expect("flat vec");
let mut correct = 0usize;
for (i, s) in val_set.iter().enumerate() {
let p = pred_data[i];
let predicted_label = if p >= 0.5 { 1.0 } else { 0.0 };
if (predicted_label - s.label).abs() < 0.1 {
correct += 1;
}
}
correct as f64 / val_set.len() as f64
result.model
}
// ---------------------------------------------------------------------------
@@ -410,19 +418,14 @@ fn eval_mlp_accuracy(
// ---------------------------------------------------------------------------
fn extract_weights(
model: &crate::training::mlp::MlpModel<NdArray<f32>>,
model: &crate::training::mlp::MlpModel<Wgpu<f32, i32>>,
name: &str,
tree_nodes: &[(u8, f32, u16, u16)],
threshold: f32,
norm_mins: &[f32],
norm_maxs: &[f32],
_device: &<NdArray<f32> as Backend>::Device,
_device: &<Wgpu<f32, i32> as Backend>::Device,
) -> ExportedModel {
// Extract weight tensors from the model.
// linear1.weight: [hidden_dim, input_dim]
// linear1.bias: [hidden_dim]
// linear2.weight: [1, hidden_dim]
// linear2.bias: [1]
let w1_tensor = model.linear1.weight.val();
let b1_tensor = model.linear1.bias.as_ref().expect("linear1 has bias").val();
let w2_tensor = model.linear2.weight.val();
@@ -436,7 +439,6 @@ fn extract_weights(
let hidden_dim = b1_data.len();
let input_dim = w1_data.len() / hidden_dim;
// Reshape W1 into [hidden_dim][input_dim].
let w1: Vec<Vec<f32>> = (0..hidden_dim)
.map(|h| w1_data[h * input_dim..(h + 1) * input_dim].to_vec())
.collect();
@@ -485,9 +487,8 @@ mod tests {
let train_attacks = train.iter().filter(|s| s.label >= 0.5).count();
let val_attacks = val.iter().filter(|s| s.label >= 0.5).count();
// Should preserve the 80/20 attack ratio approximately.
assert_eq!(train_attacks, 16); // 80% of 20
assert_eq!(val_attacks, 4); // 20% of 20
assert_eq!(train_attacks, 16);
assert_eq!(val_attacks, 4);
assert_eq!(train.len() + val.len(), 100);
}
@@ -503,13 +504,4 @@ mod tests {
assert_eq!(mins[1], 10.0);
assert_eq!(maxs[1], 20.0);
}
#[test]
fn test_normalize_features() {
    // Midpoint of each range should normalize to exactly one half.
    let lower = vec![0.0, 10.0];
    let upper = vec![1.0, 20.0];
    let scaled = normalize_features(&[0.5, 15.0], &lower, &upper);
    for component in scaled {
        assert!((component - 0.5).abs() < 1e-6);
    }
}
}