feat(scanner): add logistic regression training pipeline

JSONL audit log ingestion with ground-truth label support for external
datasets (CSIC 2010), SecLists wordlist ingestion for synthetic attack
samples, class-weighted gradient descent, stratified 80/20 train/test
split with held-out evaluation metrics (precision, recall, F1).

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:19 +00:00
parent b7c8243955
commit 273a203c41

657
src/scanner/train.rs Normal file
View File

@@ -0,0 +1,657 @@
use crate::ddos::audit_log::{AuditLog, AuditFields};
use crate::scanner::features::{
self, fx_hash_bytes, ScannerFeatureVector, ScannerNormParams, NUM_SCANNER_FEATURES,
NUM_SCANNER_WEIGHTS,
};
use crate::scanner::model::ScannerModel;
use anyhow::{Context, Result};
use rustc_hash::FxHashSet;
use std::io::BufRead;
use std::path::Path;
/// CLI arguments for the scanner-model training subcommand.
pub struct TrainScannerArgs {
    /// Path to the JSONL audit log providing training samples.
    pub input: String,
    /// Destination path for the serialized model.
    pub output: String,
    /// Optional wordlist file or directory of `.txt` files used to
    /// synthesize additional attack samples.
    pub wordlists: Option<String>,
    /// Decision threshold stored in the model and used during evaluation.
    pub threshold: f64,
}
/// Default suspicious fragments — matches the DDoS feature list plus extras.
const DEFAULT_FRAGMENTS: &[&str] = &[
    ".env", ".git", ".bak", ".sql", ".tar", ".zip",
    "wp-admin", "wp-login", "wp-includes", "wp-content", "xmlrpc",
    "phpinfo", "phpmyadmin", "php-info",
    "cgi-bin", "shell", "eval-stdin",
    ".htaccess", ".htpasswd",
    "config.", "admin",
    "yarn.lock", "package.json", "composer.json",
    "telescope", "actuator", "debug",
];
/// Path suffixes that mark a cookieless request as an attack outright
/// (see `label_request`).
const ATTACK_EXTENSIONS: &[&str] = &[".env", ".sql", ".bak", ".git/config"];
/// Substrings whose presence anywhere in the (lowercased) path is treated
/// as a traversal/injection attempt.
const TRAVERSAL_MARKERS: &[&str] = &["..", "%00", "%0a"];
/// A feature vector paired with its training label.
struct LabeledSample {
    features: ScannerFeatureVector,
    label: f64, // 1.0 = attack, 0.0 = normal
}
pub fn run(args: TrainScannerArgs) -> Result<()> {
let mut fragments: Vec<String> = DEFAULT_FRAGMENTS.iter().map(|s| s.to_string()).collect();
let fragment_hashes: FxHashSet<u64> = fragments
.iter()
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
.collect();
let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
.iter()
.map(|e| fx_hash_bytes(e.as_bytes()))
.collect();
// Populated below from observed log hosts.
let _configured_hosts: FxHashSet<u64> = FxHashSet::default();
// 1. Parse JSONL audit logs and label each request
let mut samples: Vec<LabeledSample> = Vec::new();
let file = std::fs::File::open(&args.input)
.with_context(|| format!("opening {}", args.input))?;
let reader = std::io::BufReader::new(file);
let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
let mut parsed_entries: Vec<(AuditFields, String)> = Vec::new();
for line in reader.lines() {
let line = line?;
if line.trim().is_empty() {
continue;
}
let entry: AuditLog = match serde_json::from_str(&line) {
Ok(e) => e,
Err(_) => continue,
};
let host_prefix = entry
.fields
.host
.split('.')
.next()
.unwrap_or("")
.to_string();
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
parsed_entries.push((entry.fields, host_prefix));
}
for (fields, host_prefix) in &parsed_entries {
let has_cookies = fields.has_cookies.unwrap_or(false);
let has_referer = fields
.referer
.as_ref()
.map(|r| r != "-" && !r.is_empty())
.unwrap_or(false);
let has_accept_language = fields
.accept_language
.as_ref()
.map(|a| a != "-" && !a.is_empty())
.unwrap_or(false);
let feats = features::extract_features(
&fields.method,
&fields.path,
host_prefix,
has_cookies,
has_referer,
has_accept_language,
"-", // accept not in audit log, use default
&fields.user_agent,
fields.content_length,
&fragment_hashes,
&extension_hashes,
&log_hosts, // treat all observed hosts as configured during training
);
// Use ground-truth label if present (external datasets like CSIC 2010),
// otherwise fall back to heuristic labeling.
let label = if let Some(ref gt) = fields.label {
match gt.as_str() {
"attack" | "anomalous" => Some(1.0),
"normal" => Some(0.0),
_ => None,
}
} else {
label_request(
&fields.path,
has_cookies,
has_referer,
has_accept_language,
&fields.user_agent,
host_prefix,
fields.status,
&log_hosts,
&fragment_hashes,
)
};
if let Some(l) = label {
samples.push(LabeledSample {
features: feats,
label: l,
});
}
}
let log_sample_count = samples.len();
let log_attack_count = samples.iter().filter(|s| s.label > 0.5).count();
eprintln!(
"log samples: {} ({} attack, {} normal)",
log_sample_count,
log_attack_count,
log_sample_count - log_attack_count,
);
// 2. Ingest wordlist files for synthetic attack samples
if let Some(wordlist_dir) = &args.wordlists {
let wordlist_count = ingest_wordlists(
wordlist_dir,
&mut samples,
&mut fragments,
&fragment_hashes,
&extension_hashes,
&log_hosts,
)?;
eprintln!("synthetic attack samples from wordlists: {wordlist_count}");
}
if samples.is_empty() {
anyhow::bail!("no training samples found");
}
let attack_total = samples.iter().filter(|s| s.label > 0.5).count();
let normal_total = samples.len() - attack_total;
eprintln!(
"total samples: {} ({} attack, {} normal)",
samples.len(),
attack_total,
normal_total,
);
// 3. Stratified 80/20 train/test split (deterministic shuffle via simple LCG)
let (train_samples, test_samples) = stratified_split(&mut samples, 0.8, 42);
eprintln!(
"train: {} ({} attack, {} normal)",
train_samples.len(),
train_samples.iter().filter(|s| s.label > 0.5).count(),
train_samples.iter().filter(|s| s.label <= 0.5).count(),
);
eprintln!(
"test: {} ({} attack, {} normal)",
test_samples.len(),
test_samples.iter().filter(|s| s.label > 0.5).count(),
test_samples.iter().filter(|s| s.label <= 0.5).count(),
);
// 4. Compute normalization params from training set only
let train_feature_vecs: Vec<ScannerFeatureVector> =
train_samples.iter().map(|s| s.features).collect();
let norm_params = ScannerNormParams::from_data(&train_feature_vecs);
// 5. Normalize training features
let train_normalized: Vec<ScannerFeatureVector> =
train_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect();
// 6. Train logistic regression
let weights = train_logistic_regression(&train_normalized, &train_samples, 1000, 0.01);
eprintln!("\nlearned weights:");
let feature_names = [
"suspicious_path_score",
"path_depth",
"has_suspicious_extension",
"has_cookies",
"has_referer",
"has_accept_language",
"accept_quality",
"ua_category",
"method_is_unusual",
"host_is_configured",
"content_length_mismatch",
"path_has_traversal",
"interaction:path*no_cookies",
"interaction:no_host*no_lang",
"bias",
];
for (i, name) in feature_names.iter().enumerate() {
eprintln!(" w[{i:2}] {name:>35} = {:.4}", weights[i]);
}
// 7. Save model
let model = ScannerModel {
weights,
threshold: args.threshold,
norm_params: norm_params.clone(),
fragments,
};
model.save(Path::new(&args.output))?;
eprintln!("\nmodel saved to {}", args.output);
// 8. Evaluate on training set
let train_metrics = evaluate(&train_normalized, &train_samples, &weights, args.threshold);
eprintln!("\n--- training set ---");
train_metrics.print(train_samples.len());
// 9. Evaluate on held-out test set
let test_feature_vecs: Vec<ScannerFeatureVector> =
test_samples.iter().map(|s| s.features).collect();
let test_normalized: Vec<ScannerFeatureVector> =
test_feature_vecs.iter().map(|v| norm_params.normalize(v)).collect();
let test_metrics = evaluate(&test_normalized, &test_samples, &weights, args.threshold);
eprintln!("\n--- test set (held-out 20%) ---");
test_metrics.print(test_samples.len());
Ok(())
}
/// Confusion-matrix counters for binary classification.
struct Metrics {
    tp: u32,
    fp: u32,
    tn: u32,
    fn_: u32,
}

impl Metrics {
    /// tp / (tp + fp); 0.0 when nothing was predicted positive.
    fn precision(&self) -> f64 {
        match self.tp + self.fp {
            0 => 0.0,
            denom => f64::from(self.tp) / f64::from(denom),
        }
    }

    /// tp / (tp + fn); 0.0 when there are no actual positives.
    fn recall(&self) -> f64 {
        match self.tp + self.fn_ {
            0 => 0.0,
            denom => f64::from(self.tp) / f64::from(denom),
        }
    }

    /// Harmonic mean of precision and recall; 0.0 when both are zero.
    fn f1(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        if p + r > 0.0 {
            2.0 * p * r / (p + r)
        } else {
            0.0
        }
    }

    /// Dump accuracy, raw counts, and derived metrics to stderr.
    fn print(&self, total: usize) {
        let correct = f64::from(self.tp + self.tn);
        let acc = correct / total as f64 * 100.0;
        eprintln!(
            "accuracy: {acc:.1}% (tp={} fp={} tn={} fn={})",
            self.tp, self.fp, self.tn, self.fn_,
        );
        eprintln!(
            "precision={:.3} recall={:.3} f1={:.3}",
            self.precision(), self.recall(), self.f1(),
        );
    }
}
/// Score every sample with the learned weights and tally a confusion matrix.
///
/// The scoring rule mirrors `train_logistic_regression`: bias + linear terms
/// + the two hand-crafted interaction terms.
/// NOTE(review): the raw logit is compared against `threshold` (not
/// sigmoid(logit)) — confirm ScannerModel inference uses the same convention.
fn evaluate(
    normalized: &[ScannerFeatureVector],
    samples: &[LabeledSample],
    weights: &[f64; NUM_SCANNER_WEIGHTS],
    threshold: f64,
) -> Metrics {
    // Weight layout: [0..NUM_SCANNER_FEATURES) linear terms, then the two
    // interaction weights, then the bias. Derive the indices instead of
    // hard-coding 12/13/14 so the layout stays consistent with the constant.
    const IX_PATH_NO_COOKIES: usize = NUM_SCANNER_FEATURES;
    const IX_NO_HOST_NO_LANG: usize = NUM_SCANNER_FEATURES + 1;
    const IX_BIAS: usize = NUM_SCANNER_FEATURES + 2;
    let mut m = Metrics { tp: 0, fp: 0, tn: 0, fn_: 0 };
    for (f, sample) in normalized.iter().zip(samples) {
        let mut score = weights[IX_BIAS];
        for j in 0..NUM_SCANNER_FEATURES {
            score += weights[j] * f[j];
        }
        // interaction: suspicious-path score * absence of cookies
        score += weights[IX_PATH_NO_COOKIES] * f[0] * (1.0 - f[3]);
        // interaction: unknown host * missing accept-language
        score += weights[IX_NO_HOST_NO_LANG] * (1.0 - f[9]) * (1.0 - f[5]);
        let predicted = score > threshold;
        let actual = sample.label > 0.5;
        match (predicted, actual) {
            (true, true) => m.tp += 1,
            (true, false) => m.fp += 1,
            (false, false) => m.tn += 1,
            (false, true) => m.fn_ += 1,
        }
    }
    m
}
/// Deterministically split `samples` into (train, test), keeping the
/// attack/normal ratio the same in both halves. `samples` is drained.
/// Shuffling uses a fixed-seed LCG so results are reproducible.
fn stratified_split(
    samples: &mut Vec<LabeledSample>,
    train_frac: f64,
    seed: u64,
) -> (Vec<LabeledSample>, Vec<LabeledSample>) {
    // Deterministic Fisher-Yates driven by a 64-bit LCG; the top bits of the
    // state are used because they have the longest period.
    fn lcg_shuffle(v: &mut [LabeledSample], seed: u64) {
        let mut state = seed;
        for i in (1..v.len()).rev() {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let j = (state >> 33) as usize % (i + 1);
            v.swap(i, j);
        }
    }
    // Split by class so each side of the cut preserves the class balance.
    let (mut attacks, mut normals): (Vec<_>, Vec<_>) =
        samples.drain(..).partition(|s| s.label > 0.5);
    lcg_shuffle(&mut attacks, seed);
    lcg_shuffle(&mut normals, seed.wrapping_add(1));
    // The leading train_frac of each shuffled class goes to training.
    let keep_attacks = (attacks.len() as f64 * train_frac) as usize;
    let keep_normals = (normals.len() as f64 * train_frac) as usize;
    let mut test = attacks.split_off(keep_attacks);
    test.extend(normals.split_off(keep_normals));
    let mut train = attacks;
    train.extend(normals);
    (train, test)
}
/// Standard logistic function, mapping any real input into (0, 1).
fn sigmoid(x: f64) -> f64 {
    let neg_exp = (-x).exp();
    (1.0 + neg_exp).recip()
}
/// Full-batch gradient descent for class-weighted logistic regression.
///
/// Each class's gradient contributions are scaled by n / (2 * class_count),
/// so the minority class pushes proportionally harder and the model does not
/// collapse to always predicting the majority class.
///
/// Returns the weight vector laid out as: linear terms for each feature,
/// then the two interaction weights, then the bias (same layout `evaluate`
/// reads).
fn train_logistic_regression(
    normalized: &[ScannerFeatureVector],
    samples: &[LabeledSample],
    epochs: usize,
    learning_rate: f64,
) -> [f64; NUM_SCANNER_WEIGHTS] {
    // Derive the interaction/bias slots from NUM_SCANNER_FEATURES instead of
    // hard-coding 12/13/14, keeping the layout in one place.
    const IX_PATH_NO_COOKIES: usize = NUM_SCANNER_FEATURES;
    const IX_NO_HOST_NO_LANG: usize = NUM_SCANNER_FEATURES + 1;
    const IX_BIAS: usize = NUM_SCANNER_FEATURES + 2;
    let mut weights = [0.0f64; NUM_SCANNER_WEIGHTS];
    let n = samples.len() as f64;
    // Class weighting: give the minority class proportionally stronger gradients
    // so the model doesn't collapse to always-predict-majority.
    let n_attack = samples.iter().filter(|s| s.label > 0.5).count() as f64;
    let n_normal = n - n_attack;
    let (w_attack, w_normal) = if n_attack > 0.0 && n_normal > 0.0 {
        (n / (2.0 * n_attack), n / (2.0 * n_normal))
    } else {
        (1.0, 1.0) // degenerate single-class input: fall back to unweighted
    };
    eprintln!("class weights: attack={w_attack:.2} normal={w_normal:.2}");
    for _epoch in 0..epochs {
        let mut gradients = [0.0f64; NUM_SCANNER_WEIGHTS];
        for (i, sample) in samples.iter().enumerate() {
            let f = &normalized[i];
            // Forward pass: logit = bias + linear terms + interactions.
            let mut z = weights[IX_BIAS];
            for j in 0..NUM_SCANNER_FEATURES {
                z += weights[j] * f[j];
            }
            z += weights[IX_PATH_NO_COOKIES] * f[0] * (1.0 - f[3]);
            z += weights[IX_NO_HOST_NO_LANG] * (1.0 - f[9]) * (1.0 - f[5]);
            let prediction = sigmoid(z);
            // d(log-loss)/dz for logistic regression, scaled by class weight.
            let error = prediction - sample.label;
            let cw = if sample.label > 0.5 { w_attack } else { w_normal };
            let weighted_error = error * cw;
            // Accumulate gradients for linear, interaction, and bias slots.
            for j in 0..NUM_SCANNER_FEATURES {
                gradients[j] += weighted_error * f[j];
            }
            gradients[IX_PATH_NO_COOKIES] += weighted_error * f[0] * (1.0 - f[3]);
            gradients[IX_NO_HOST_NO_LANG] += weighted_error * (1.0 - f[9]) * (1.0 - f[5]);
            gradients[IX_BIAS] += weighted_error;
        }
        // Full-batch update with the mean gradient.
        for j in 0..NUM_SCANNER_WEIGHTS {
            weights[j] -= learning_rate * gradients[j] / n;
        }
    }
    weights
}
/// Heuristically label a request: `Some(1.0)` attack, `Some(0.0)` normal,
/// `None` when ambiguous (the sample is then dropped from training).
/// Rules are checked in order; the first match wins.
fn label_request(
    path: &str,
    has_cookies: bool,
    has_referer: bool,
    has_accept_language: bool,
    user_agent: &str,
    host_prefix: &str,
    status: u16,
    configured_hosts: &FxHashSet<u64>,
    fragment_hashes: &FxHashSet<u64>,
) -> Option<f64> {
    let lower_path = path.to_ascii_lowercase();
    let host_known = configured_hosts.contains(&fx_hash_bytes(host_prefix.as_bytes()));
    let suspicious_hits = is_path_suspicious(&lower_path, fragment_hashes);

    // Attack: 2+ suspicious fragments combined with a missing browser signal.
    if suspicious_hits >= 2 && (!has_cookies || !has_referer) {
        return Some(1.0);
    }
    // Attack: cookieless request for a known-bad extension.
    if !has_cookies && ATTACK_EXTENSIONS.iter().any(|ext| lower_path.ends_with(ext)) {
        return Some(1.0);
    }
    // Attack: traversal markers anywhere in the path.
    if TRAVERSAL_MARKERS.iter().any(|m| lower_path.contains(m)) {
        return Some(1.0);
    }
    // Attack: unknown host + no cookies + no accept-language, but only when
    // the user agent also looks like a script/tool.
    if !host_known && !has_cookies && !has_accept_language {
        let ua = user_agent.to_ascii_lowercase();
        let scripted = ua.is_empty()
            || ua == "-"
            || ua.starts_with("curl/")
            || ua.starts_with("python")
            || ua.starts_with("go-http");
        if scripted {
            return Some(1.0);
        }
    }
    // Attack: empty/placeholder UA hitting any suspicious fragment.
    if (user_agent.is_empty() || user_agent == "-") && suspicious_hits >= 1 {
        return Some(1.0);
    }
    // Normal: configured host with a cookie-bearing session.
    if host_known && has_cookies {
        return Some(0.0);
    }
    // Normal: successful response carrying both cookies and a referer.
    if (200..300).contains(&status) && has_cookies && has_referer {
        return Some(0.0);
    }
    // Ambiguous — skip.
    None
}
/// Count how many non-empty '/'-separated segments of `lower_path` hash into
/// the suspicious-fragment set.
fn is_path_suspicious(lower_path: &str, fragment_hashes: &FxHashSet<u64>) -> u32 {
    lower_path
        .split('/')
        .filter(|segment| !segment.is_empty())
        .map(|segment| fx_hash_bytes(segment.as_bytes()))
        .filter(|hash| fragment_hashes.contains(hash))
        .count() as u32
}
/// Turn every wordlist entry under `dir_path` (a single file, or a directory
/// scanned for `.txt` files) into a synthetic attack sample, growing the
/// fragment vocabulary with each new path segment along the way.
///
/// Returns the number of samples appended to `samples`.
///
/// # Errors
/// Fails when `dir_path` does not exist or a wordlist cannot be read.
fn ingest_wordlists(
    dir_path: &str,
    samples: &mut Vec<LabeledSample>,
    fragments: &mut Vec<String>,
    fragment_hashes: &FxHashSet<u64>,
    extension_hashes: &FxHashSet<u64>,
    configured_hosts: &FxHashSet<u64>,
) -> Result<usize> {
    let dir = Path::new(dir_path);
    if !dir.exists() {
        anyhow::bail!("wordlist directory not found: {dir_path}");
    }
    // Local working copy of the fragment-hash set so vocabulary growth is
    // visible to feature extraction for subsequent lines.
    let mut known_hashes = fragment_hashes.clone();
    // Accept either one wordlist file or a directory of *.txt files.
    let entries: Vec<_> = if dir.is_file() {
        vec![dir.to_path_buf()]
    } else {
        std::fs::read_dir(dir)?
            .filter_map(|e| e.ok())
            .map(|e| e.path())
            .filter(|p| p.extension().map(|e| e == "txt").unwrap_or(false))
            .collect()
    };
    let mut added = 0usize;
    for path in &entries {
        let file = std::fs::File::open(path)
            .with_context(|| format!("opening wordlist {}", path.display()))?;
        for line in std::io::BufReader::new(file).lines() {
            let raw = line?;
            let entry = raw.trim();
            // Skip blanks and comment lines.
            if entry.is_empty() || entry.starts_with('#') {
                continue;
            }
            // Normalize: ensure leading /
            let normalized_path = if entry.starts_with('/') {
                entry.to_string()
            } else {
                format!("/{entry}")
            };
            // Register every segment of at least 3 chars as a fragment
            // (insert() returns true only for hashes not seen before).
            for segment in normalized_path.split('/') {
                if segment.len() < 3 {
                    continue;
                }
                let lower = segment.to_ascii_lowercase();
                if known_hashes.insert(fx_hash_bytes(lower.as_bytes())) {
                    fragments.push(lower);
                }
            }
            // Synthesize a scanner-style request: scripted client, no browser
            // headers, unknown host.
            let feats = features::extract_features(
                "GET",
                &normalized_path,
                "unknown-host",
                false, // no cookies
                false, // no referer
                false, // no accept-language
                "*/*",
                "curl/7.0",
                0,
                &known_hashes,
                extension_hashes,
                configured_hosts,
            );
            samples.push(LabeledSample { features: feats, label: 1.0 });
            added += 1;
        }
    }
    eprintln!("fragment vocabulary size: {}", fragments.len());
    Ok(added)
}
#[cfg(test)]
mod tests {
    use super::*;

    // The logistic function is 0.5 at the origin and saturates toward 0/1.
    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-10);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }

    // A cookieless, headerless ".env" probe from an unknown host must be
    // labeled as an attack by the heuristic.
    #[test]
    fn test_label_request_env_probe() {
        let hosts = FxHashSet::default();
        let frags: FxHashSet<u64> = DEFAULT_FRAGMENTS
            .iter()
            .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
            .collect();
        let label = label_request(
            "/.env",
            false, false, false,
            "curl/7.0", "unknown", 404,
            &hosts, &frags,
        );
        assert_eq!(label, Some(1.0));
    }

    // A browser-like request (cookies, referer, accept-language) against a
    // configured host with a 200 response is labeled normal.
    #[test]
    fn test_label_request_normal_browser() {
        let mut hosts = FxHashSet::default();
        hosts.insert(fx_hash_bytes(b"app"));
        let frags: FxHashSet<u64> = DEFAULT_FRAGMENTS
            .iter()
            .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
            .collect();
        let label = label_request(
            "/blog/hello",
            true, true, true,
            "Mozilla/5.0", "app", 200,
            &hosts, &frags,
        );
        assert_eq!(label, Some(0.0));
    }

    // Traversal markers alone are enough to label a request as an attack,
    // even with an empty fragment set.
    #[test]
    fn test_label_request_traversal() {
        let hosts = FxHashSet::default();
        let frags = FxHashSet::default();
        let label = label_request(
            "/../../etc/passwd",
            false, false, false,
            "", "unknown", 404,
            &hosts, &frags,
        );
        assert_eq!(label, Some(1.0));
    }

    // On a linearly separable toy dataset, the learned weights should point
    // in the expected directions after training.
    #[test]
    fn test_logistic_regression_converges() {
        // Simple separable dataset: attack features have high f[0], normal have low
        let mut samples = Vec::new();
        let mut features = Vec::new();
        for _ in 0..50 {
            let f = [0.9, 0.5, 0.8, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.8];
            features.push(f);
            samples.push(LabeledSample { features: f, label: 1.0 });
        }
        for _ in 0..50 {
            let f = [0.1, 0.2, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0];
            features.push(f);
            samples.push(LabeledSample { features: f, label: 0.0 });
        }
        let weights = train_logistic_regression(&features, &samples, 500, 0.1);
        // Verify attack weight is positive, cookie weight is negative
        assert!(weights[0] > 0.0, "suspicious_path weight should be positive");
        assert!(weights[3] < 0.0, "has_cookies weight should be negative");
    }
}