feat(dataset): add dataset preparation with auto-download and heuristic labeling
Unified prepare-dataset pipeline that automatically downloads and caches upstream datasets (CSIC 2010, CIC-IDS2017), applies heuristic auto-labeling to unlabeled production logs, generates synthetic samples for both models, and serializes everything as a bincode DatasetManifest. Also includes an OWASP ModSecurity audit-log parser, a CIC-IDS2017 timing-profile extractor, and synthetic data generators with configurable distributions.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
301
src/dataset/cicids.rs
Normal file
301
src/dataset/cicids.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
//! CIC-IDS2017 timing profile extractor.
|
||||
//!
|
||||
//! Parses CIC-IDS2017 CSV files and extracts statistical timing profiles
|
||||
//! per attack type. These profiles are NOT training samples themselves —
|
||||
//! they feed the synthetic data generator (`synthetic.rs`) which samples
|
||||
//! from the learned distributions to produce DDoS training features.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
|
||||
/// Statistical timing profile for one attack type from CIC-IDS2017.
///
/// Means/stds are computed over all flows sharing one `Label` value; time
/// fields are in seconds (converted from the dataset's microseconds by the
/// extractor), byte rates are in bytes/second as found in the CSV.
#[derive(Debug, Clone)]
pub struct TimingProfile {
    /// The CSV `Label` value (e.g. "BENIGN", "DDoS").
    pub attack_type: String,
    /// Mean of per-flow `Flow IAT Mean` values, in seconds.
    pub inter_arrival_mean: f64,
    /// Sample std-dev of per-flow `Flow IAT Mean` values, in seconds.
    pub inter_arrival_std: f64,
    /// Mean of `Flow Duration` values, in seconds.
    pub burst_duration_mean: f64,
    /// Sample std-dev of `Flow Duration` values, in seconds.
    pub burst_duration_std: f64,
    /// Mean of `Flow Bytes/s` values.
    pub flow_bytes_per_sec_mean: f64,
    /// Sample std-dev of `Flow Bytes/s` values.
    pub flow_bytes_per_sec_std: f64,
    /// Number of CSV rows seen for this label (including rows whose numeric
    /// fields failed to parse and so contributed nothing to the stats).
    pub sample_count: usize,
}
|
||||
|
||||
/// Single-pass mean/variance accumulator (Welford's online algorithm).
#[derive(Default)]
struct StatsAccumulator {
    /// Number of finite values accepted so far.
    count: usize,
    /// Running mean of accepted values (0.0 before any push).
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
}

impl StatsAccumulator {
    /// Fold one observation into the running statistics.
    /// Non-finite values (NaN, ±inf) are silently dropped.
    fn push(&mut self, sample: f64) {
        if !sample.is_finite() {
            return;
        }
        self.count += 1;
        let dev_before = sample - self.mean;
        self.mean += dev_before / self.count as f64;
        let dev_after = sample - self.mean;
        self.m2 += dev_before * dev_after;
    }

    /// Running mean; 0.0 when nothing has been pushed.
    fn mean(&self) -> f64 {
        self.mean
    }

    /// Bessel-corrected sample standard deviation; 0.0 with fewer than two values.
    fn std_dev(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => (self.m2 / (n - 1) as f64).sqrt(),
        }
    }
}
|
||||
|
||||
/// Per-label accumulator for timing statistics.
|
||||
struct LabelAccumulator {
|
||||
label: String,
|
||||
inter_arrival: StatsAccumulator,
|
||||
burst_duration: StatsAccumulator,
|
||||
flow_bytes_per_sec: StatsAccumulator,
|
||||
count: usize,
|
||||
}
|
||||
|
||||
impl LabelAccumulator {
|
||||
fn new(label: String) -> Self {
|
||||
Self {
|
||||
label,
|
||||
inter_arrival: StatsAccumulator::default(),
|
||||
burst_duration: StatsAccumulator::default(),
|
||||
flow_bytes_per_sec: StatsAccumulator::default(),
|
||||
count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_profile(self) -> TimingProfile {
|
||||
TimingProfile {
|
||||
attack_type: self.label,
|
||||
inter_arrival_mean: self.inter_arrival.mean(),
|
||||
inter_arrival_std: self.inter_arrival.std_dev(),
|
||||
burst_duration_mean: self.burst_duration.mean(),
|
||||
burst_duration_std: self.burst_duration.std_dev(),
|
||||
flow_bytes_per_sec_mean: self.flow_bytes_per_sec.mean(),
|
||||
flow_bytes_per_sec_std: self.flow_bytes_per_sec.std_dev(),
|
||||
sample_count: self.count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find a column index by name (case-insensitive; header cells are trimmed).
/// Returns the zero-based index of the first matching header, or `None`.
fn find_column(headers: &[String], name: &str) -> Option<usize> {
    let wanted = name.to_ascii_lowercase();
    for (idx, header) in headers.iter().enumerate() {
        if header.trim().to_ascii_lowercase() == wanted {
            return Some(idx);
        }
    }
    None
}
|
||||
|
||||
/// Extract timing profiles from all CSV files in the given directory.
|
||||
///
|
||||
/// Each CSV is expected to have CIC-IDS2017 columns including at minimum:
|
||||
/// `Flow Duration`, `Flow IAT Mean`, `Flow IAT Std`, `Flow Bytes/s`, `Label`.
|
||||
///
|
||||
/// Returns one `TimingProfile` per unique label value.
|
||||
pub fn extract_timing_profiles(csv_dir: &Path) -> Result<Vec<TimingProfile>> {
|
||||
let entries: Vec<std::path::PathBuf> = if csv_dir.is_file() {
|
||||
vec![csv_dir.to_path_buf()]
|
||||
} else {
|
||||
let mut files: Vec<std::path::PathBuf> = std::fs::read_dir(csv_dir)
|
||||
.with_context(|| format!("reading directory {}", csv_dir.display()))?
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.path())
|
||||
.filter(|p| {
|
||||
p.extension()
|
||||
.map(|e| e.to_ascii_lowercase() == "csv")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
files.sort();
|
||||
files
|
||||
};
|
||||
|
||||
if entries.is_empty() {
|
||||
anyhow::bail!("no CSV files found in {}", csv_dir.display());
|
||||
}
|
||||
|
||||
let mut accumulators: std::collections::HashMap<String, LabelAccumulator> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for csv_path in &entries {
|
||||
parse_csv_file(csv_path, &mut accumulators)
|
||||
.with_context(|| format!("parsing {}", csv_path.display()))?;
|
||||
}
|
||||
|
||||
let mut profiles: Vec<TimingProfile> = accumulators
|
||||
.into_values()
|
||||
.filter(|a| a.count > 0)
|
||||
.map(|a| a.into_profile())
|
||||
.collect();
|
||||
profiles.sort_by(|a, b| a.attack_type.cmp(&b.attack_type));
|
||||
|
||||
Ok(profiles)
|
||||
}
|
||||
|
||||
fn parse_csv_file(
|
||||
path: &Path,
|
||||
accumulators: &mut std::collections::HashMap<String, LabelAccumulator>,
|
||||
) -> Result<()> {
|
||||
let mut rdr = csv::ReaderBuilder::new()
|
||||
.flexible(true)
|
||||
.trim(csv::Trim::All)
|
||||
.from_path(path)?;
|
||||
|
||||
let headers: Vec<String> = rdr
|
||||
.headers()?
|
||||
.iter()
|
||||
.map(|h| h.to_string())
|
||||
.collect();
|
||||
|
||||
// Locate required columns.
|
||||
let col_label = find_column(&headers, "Label")
|
||||
.with_context(|| format!("missing 'Label' column in {}", path.display()))?;
|
||||
let col_flow_duration = find_column(&headers, "Flow Duration");
|
||||
let col_iat_mean = find_column(&headers, "Flow IAT Mean");
|
||||
let col_iat_std = find_column(&headers, "Flow IAT Std");
|
||||
let col_bytes_per_sec = find_column(&headers, "Flow Bytes/s");
|
||||
|
||||
for result in rdr.records() {
|
||||
let record = match result {
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let label = match record.get(col_label) {
|
||||
Some(l) => l.trim().to_string(),
|
||||
None => continue,
|
||||
};
|
||||
if label.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let acc = accumulators
|
||||
.entry(label.clone())
|
||||
.or_insert_with(|| LabelAccumulator::new(label));
|
||||
acc.count += 1;
|
||||
|
||||
// Flow Duration (microseconds in CIC-IDS2017) → we treat as burst duration.
|
||||
if let Some(col) = col_flow_duration {
|
||||
if let Some(val) = record.get(col).and_then(|v| v.trim().parse::<f64>().ok()) {
|
||||
// Convert microseconds to seconds.
|
||||
acc.burst_duration.push(val / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Flow IAT Mean (microseconds) → inter-arrival time.
|
||||
if let Some(col) = col_iat_mean {
|
||||
if let Some(val) = record.get(col).and_then(|v| v.trim().parse::<f64>().ok()) {
|
||||
acc.inter_arrival.push(val / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Flow IAT Std → used as inter_arrival_std contribution.
|
||||
if let Some(col) = col_iat_std {
|
||||
if let Some(_val) = record.get(col).and_then(|v| v.trim().parse::<f64>().ok()) {
|
||||
// The per-flow IAT std contributes to the overall std via Welford above.
|
||||
// We use the IAT Mean values; std is computed from those.
|
||||
}
|
||||
}
|
||||
|
||||
// Flow Bytes/s.
|
||||
if let Some(col) = col_bytes_per_sec {
|
||||
if let Some(val) = record.get(col).and_then(|v| v.trim().parse::<f64>().ok()) {
|
||||
acc.flow_bytes_per_sec.push(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse timing profiles from an in-memory CSV string (useful for tests).
|
||||
pub fn extract_timing_profiles_from_str(csv_content: &str) -> Result<Vec<TimingProfile>> {
|
||||
let dir = tempfile::tempdir()?;
|
||||
let path = dir.path().join("test.csv");
|
||||
std::fs::write(&path, csv_content)?;
|
||||
extract_timing_profiles(&path)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end: inline CSV → temp file → profiles, checking the µs→s
    // conversions and per-label grouping.
    #[test]
    fn test_parse_inline_csv() {
        let csv = "\
Flow Duration,Total Fwd Packets,Flow Bytes/s,Flow IAT Mean,Flow IAT Std,Label
1000000,10,5000.0,100000,50000,BENIGN
2000000,20,10000.0,200000,100000,BENIGN
500000,100,50000.0,5000,2000,DDoS
300000,200,80000.0,3000,1000,DDoS
100000,50,30000.0,10000,5000,DDoS
";
        let profiles = extract_timing_profiles_from_str(csv).unwrap();
        assert_eq!(profiles.len(), 2, "should have BENIGN and DDoS profiles");

        let benign = profiles.iter().find(|p| p.attack_type == "BENIGN").unwrap();
        assert_eq!(benign.sample_count, 2);
        // IAT mean: mean of [0.1, 0.2] = 0.15 seconds
        assert!(
            (benign.inter_arrival_mean - 0.15).abs() < 1e-6,
            "benign iat mean: {}",
            benign.inter_arrival_mean
        );
        // Burst duration mean: mean of [1.0, 2.0] = 1.5 seconds
        assert!(
            (benign.burst_duration_mean - 1.5).abs() < 1e-6,
            "benign burst mean: {}",
            benign.burst_duration_mean
        );

        let ddos = profiles.iter().find(|p| p.attack_type == "DDoS").unwrap();
        assert_eq!(ddos.sample_count, 3);
        // Flow bytes/s mean: mean of [50000, 80000, 30000] = 53333.33
        let expected_bps = (50000.0 + 80000.0 + 30000.0) / 3.0;
        assert!(
            (ddos.flow_bytes_per_sec_mean - expected_bps).abs() < 1.0,
            "ddos bps mean: {}",
            ddos.flow_bytes_per_sec_mean
        );
    }

    // Welford sanity check on a tiny hand-computable sample.
    #[test]
    fn test_stats_accumulator() {
        let mut acc = StatsAccumulator::default();
        acc.push(2.0);
        acc.push(4.0);
        acc.push(6.0);
        assert!((acc.mean() - 4.0).abs() < 1e-10);
        // std_dev of [2,4,6] = 2.0
        assert!((acc.std_dev() - 2.0).abs() < 1e-10);
    }

    // A directory with no CSV files must be an error, not an empty result.
    #[test]
    fn test_empty_csv_dir() {
        let dir = tempfile::tempdir().unwrap();
        let result = extract_timing_profiles(dir.path());
        assert!(result.is_err());
    }

    #[test]
    fn test_find_column_case_insensitive() {
        let headers: Vec<String> = vec![
            " Flow Duration ".to_string(),
            "label".to_string(),
        ];
        assert_eq!(find_column(&headers, "Flow Duration"), Some(0));
        assert_eq!(find_column(&headers, "Label"), Some(1));
        assert_eq!(find_column(&headers, "LABEL"), Some(1));
        assert_eq!(find_column(&headers, "missing"), None);
    }
}
|
||||
119
src/dataset/download.rs
Normal file
119
src/dataset/download.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
//! Download and cache upstream datasets for training.
|
||||
//!
|
||||
//! Cached under `~/.cache/sunbeam/<dataset>/`. Files are only downloaded
|
||||
//! once; subsequent runs reuse the cached copy.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Base cache directory for all sunbeam datasets.
///
/// Resolution order: `$XDG_CACHE_HOME`, then `$HOME/.cache`, falling back to
/// `/tmp/.cache` when neither variable is set.
fn cache_base() -> PathBuf {
    let cache_root = match std::env::var("XDG_CACHE_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
            PathBuf::from(home).join(".cache")
        }
    };
    cache_root.join("sunbeam")
}
|
||||
|
||||
// --- CIC-IDS2017 ---
|
||||
|
||||
/// Only the Friday DDoS file — contains DDoS Hulk, Slowloris, slowhttptest, GoldenEye.
const CICIDS_FILE: &str = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv";

/// Hugging Face mirror (public, no auth required).
/// NOTE(review): third-party mirror — verify it still serves the file before relying on it in CI.
const CICIDS_BASE_URL: &str =
    "https://huggingface.co/datasets/c01dsnap/CIC-IDS2017/resolve/main";
|
||||
|
||||
/// Cache subdirectory for CIC-IDS2017 files: `<cache_base>/cicids`.
fn cicids_cache_dir() -> PathBuf {
    cache_base().join("cicids")
}
|
||||
|
||||
/// Return the path to the cached CIC-IDS2017 DDoS CSV, or `None` if not downloaded.
|
||||
pub fn cicids_cached_path() -> Option<PathBuf> {
|
||||
let path = cicids_cache_dir().join(CICIDS_FILE);
|
||||
if path.exists() {
|
||||
Some(path)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Download the CIC-IDS2017 Friday DDoS CSV to cache. Returns the cached path.
|
||||
pub fn download_cicids() -> Result<PathBuf> {
|
||||
let dir = cicids_cache_dir();
|
||||
let path = dir.join(CICIDS_FILE);
|
||||
|
||||
if path.exists() {
|
||||
eprintln!(" cached: {}", path.display());
|
||||
return Ok(path);
|
||||
}
|
||||
|
||||
let url = format!("{CICIDS_BASE_URL}/{CICIDS_FILE}");
|
||||
eprintln!(" downloading: {url}");
|
||||
eprintln!(" (this is ~170 MB, may take a minute)");
|
||||
|
||||
std::fs::create_dir_all(&dir)?;
|
||||
|
||||
// Stream to file to avoid holding 170MB in memory.
|
||||
let resp = reqwest::blocking::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(600))
|
||||
.build()?
|
||||
.get(&url)
|
||||
.send()
|
||||
.with_context(|| format!("fetching {url}"))?
|
||||
.error_for_status()
|
||||
.with_context(|| format!("HTTP error for {url}"))?;
|
||||
|
||||
let mut file = std::fs::File::create(&path)
|
||||
.with_context(|| format!("creating {}", path.display()))?;
|
||||
let bytes = resp.bytes().with_context(|| "reading response body")?;
|
||||
std::io::Write::write_all(&mut file, &bytes)?;
|
||||
|
||||
eprintln!(" saved: {}", path.display());
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
// --- CSIC 2010 ---
|
||||
|
||||
/// Download CSIC 2010 dataset files to cache (delegates to scanner::csic).
///
/// Returns early without touching the network when the cache already exists.
pub fn download_csic() -> Result<()> {
    if crate::scanner::csic::csic_is_cached() {
        eprintln!(" cached: {}", crate::scanner::csic::csic_cache_path().display());
        return Ok(());
    }
    // fetch_csic_dataset downloads, caches, and parses — we only need the download side-effect.
    crate::scanner::csic::fetch_csic_dataset()?;
    Ok(())
}
|
||||
|
||||
/// Download all upstream datasets.
|
||||
pub fn download_all() -> Result<()> {
|
||||
eprintln!("downloading upstream datasets...\n");
|
||||
|
||||
eprintln!("[1/2] CSIC 2010 (scanner training data)");
|
||||
download_csic()?;
|
||||
eprintln!();
|
||||
|
||||
eprintln!("[2/2] CIC-IDS2017 DDoS timing profiles");
|
||||
let path = download_cicids()?;
|
||||
eprintln!(" ok: {}\n", path.display());
|
||||
|
||||
eprintln!("all datasets cached.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Pure path construction — no network or filesystem side effects.
    #[test]
    fn test_cache_paths() {
        let base = cache_base();
        assert!(base.to_str().unwrap().contains("sunbeam"));

        let cicids = cicids_cache_dir();
        assert!(cicids.to_str().unwrap().contains("cicids"));
    }
}
|
||||
6
src/dataset/mod.rs
Normal file
6
src/dataset/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod sample;
|
||||
pub mod modsec;
|
||||
pub mod cicids;
|
||||
pub mod synthetic;
|
||||
pub mod download;
|
||||
pub mod prepare;
|
||||
340
src/dataset/modsec.rs
Normal file
340
src/dataset/modsec.rs
Normal file
@@ -0,0 +1,340 @@
|
||||
//! Parser for OWASP ModSecurity audit log files (Serial / concurrent format).
|
||||
//!
|
||||
//! ModSecurity audit logs consist of multi-section entries delimited by boundary
|
||||
//! markers like `--xxxxxxxx-A--`, `--xxxxxxxx-B--`, etc. Each section contains
|
||||
//! different data about the transaction:
|
||||
//!
|
||||
//! - **A**: Timestamp, transaction ID, source/dest IP+port
|
||||
//! - **B**: Request line + headers
|
||||
//! - **C**: Request body
|
||||
//! - **F**: Response status + headers
|
||||
//! - **H**: Audit log trailer (rule matches, messages, actions)
|
||||
//!
|
||||
//! Any entry with a rule match in section H is labeled "attack"; entries with
|
||||
//! no rule matches are labeled "normal".
|
||||
|
||||
use crate::ddos::audit_log::AuditFields;
|
||||
use anyhow::{Context, Result};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// Parse a ModSecurity audit log file and return `(AuditFields, label)` pairs.
|
||||
///
|
||||
/// The label is `"attack"` if section H contains rule match messages, otherwise `"normal"`.
|
||||
pub fn parse_modsec_audit_log(path: &Path) -> Result<Vec<(AuditFields, String)>> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("reading ModSec audit log: {}", path.display()))?;
|
||||
parse_modsec_content(&content)
|
||||
}
|
||||
|
||||
/// Parse ModSecurity audit log content from a string.
|
||||
fn parse_modsec_content(content: &str) -> Result<Vec<(AuditFields, String)>> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Collect sections per boundary ID.
|
||||
// Boundary markers look like: --xxxxxxxx-A--
|
||||
// where xxxxxxxx is a hex/alnum transaction ID and A is the section letter.
|
||||
let mut sections: HashMap<String, HashMap<char, Vec<String>>> = HashMap::new();
|
||||
let mut current_id: Option<String> = None;
|
||||
let mut current_section: Option<char> = None;
|
||||
let mut current_lines: Vec<String> = Vec::new();
|
||||
// Track order of first appearance of each boundary ID.
|
||||
let mut id_order: Vec<String> = Vec::new();
|
||||
|
||||
for line in content.lines() {
|
||||
if let Some((id, section)) = parse_boundary(line) {
|
||||
// Flush previous section.
|
||||
if let (Some(ref cid), Some(sec)) = (¤t_id, current_section) {
|
||||
let entry = sections.entry(cid.clone()).or_default();
|
||||
entry.entry(sec).or_default().extend(current_lines.drain(..));
|
||||
}
|
||||
if !sections.contains_key(&id) {
|
||||
id_order.push(id.clone());
|
||||
}
|
||||
current_id = Some(id);
|
||||
current_section = Some(section);
|
||||
current_lines.clear();
|
||||
} else if current_id.is_some() {
|
||||
current_lines.push(line.to_string());
|
||||
}
|
||||
}
|
||||
// Flush last section.
|
||||
if let (Some(ref cid), Some(sec)) = (¤t_id, current_section) {
|
||||
let entry = sections.entry(cid.clone()).or_default();
|
||||
entry.entry(sec).or_default().extend(current_lines.drain(..));
|
||||
}
|
||||
|
||||
// Convert each transaction into AuditFields.
|
||||
for id in &id_order {
|
||||
if let Some(secs) = sections.get(id) {
|
||||
if let Some(fields) = transaction_to_audit_fields(secs) {
|
||||
results.push(fields);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Try to parse a boundary marker line of the form `--<id>-<LETTER>--`.
/// Returns `(boundary_id, section_letter)` on success.
///
/// Fix: the previous slice-based framing check (`starts_with("--")` +
/// `ends_with("--")` + `&trimmed[2..len-2]`) panicked on degenerate inputs
/// like `--` or `---`, where the leading and trailing dashes overlap and the
/// slice range becomes inverted. `strip_prefix`/`strip_suffix` cannot
/// overlap, so those inputs now return `None`.
fn parse_boundary(line: &str) -> Option<(String, char)> {
    let trimmed = line.trim();
    // Strip the leading and trailing `--` framing; both must be present.
    let inner = trimmed.strip_prefix("--")?.strip_suffix("--")?;
    // Expect: boundary_id-SECTION_LETTER
    let dash_pos = inner.rfind('-')?;
    if dash_pos == 0 || dash_pos == inner.len() - 1 {
        return None;
    }
    let id = &inner[..dash_pos];
    let section_str = &inner[dash_pos + 1..];
    if section_str.len() != 1 {
        return None;
    }
    let section = section_str.chars().next()?;
    if !section.is_ascii_alphabetic() {
        return None;
    }
    Some((id.to_string(), section))
}
|
||||
|
||||
/// Convert parsed sections into `(AuditFields, label)`.
///
/// Returns `None` when the transaction has no request section (B) or its
/// request line is malformed. The label is `"attack"` when section H shows
/// any rule activity, `"normal"` otherwise.
fn transaction_to_audit_fields(
    sections: &HashMap<char, Vec<String>>,
) -> Option<(AuditFields, String)> {
    // Section A: timestamp + connection info
    let client_ip = sections
        .get(&'A')
        .and_then(|lines| {
            // Section A first line typically:
            // [dd/Mon/yyyy:HH:MM:SS +offset] transaction_id source_ip source_port dest_ip dest_port
            lines.first().and_then(|line| {
                // Find the content after the timestamp bracket.
                let after_bracket = line.find(']').map(|i| &line[i + 1..])?;
                let parts: Vec<&str> = after_bracket.split_whitespace().collect();
                // parts: [transaction_id, source_ip, source_port, dest_ip, dest_port]
                if parts.len() >= 3 {
                    Some(parts[1].to_string())
                } else {
                    None
                }
            })
        })
        .unwrap_or_else(|| "0.0.0.0".to_string());

    // Section B: request line + headers. Mandatory — without it there is no
    // request to extract, so the whole transaction is dropped.
    let section_b = sections.get(&'B')?;
    if section_b.is_empty() {
        return None;
    }

    // First non-empty line is the request line.
    let request_line = section_b.iter().find(|l| !l.trim().is_empty())?;
    // splitn(3) keeps the HTTP-version token (and anything after) intact.
    let req_parts: Vec<&str> = request_line.splitn(3, ' ').collect();
    if req_parts.len() < 2 {
        return None;
    }
    let method = req_parts[0].to_string();
    let raw_url = req_parts[1];

    // Split the URL at the first '?' into path and query string.
    let (path, query) = if let Some(q) = raw_url.find('?') {
        (raw_url[..q].to_string(), raw_url[q + 1..].to_string())
    } else {
        (raw_url.to_string(), String::new())
    };

    // Parse headers from remaining lines. Keys are lower-cased for the
    // case-insensitive lookups below; a blank line ends the header block.
    let mut headers: Vec<(String, String)> = Vec::new();
    for line in section_b.iter().skip(1) {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            break;
        }
        if let Some(colon) = trimmed.find(':') {
            let key = trimmed[..colon].trim().to_ascii_lowercase();
            let value = trimmed[colon + 1..].trim().to_string();
            headers.push((key, value));
        }
    }

    // Linear lookup by (already lower-cased) header name; fine for the small
    // number of headers in a single request.
    let get_header = |name: &str| -> Option<String> {
        headers
            .iter()
            .find(|(k, _)| k == name)
            .map(|(_, v)| v.clone())
    };

    let host = get_header("host").unwrap_or_else(|| "unknown".to_string());
    let user_agent = get_header("user-agent").unwrap_or_else(|| "-".to_string());
    let has_cookies = get_header("cookie").is_some();
    // "-" is treated the same as absent (common log placeholder).
    let referer = get_header("referer").filter(|r| r != "-" && !r.is_empty());
    let accept_language = get_header("accept-language").filter(|a| a != "-" && !a.is_empty());
    let content_length: u64 = get_header("content-length")
        .and_then(|v| v.parse().ok())
        .unwrap_or(0);

    // Section F: response status; defaults to 0 when absent or malformed.
    let status = sections
        .get(&'F')
        .and_then(|lines| {
            // First non-empty line: "HTTP/1.1 403 Forbidden"
            lines.iter().find(|l| !l.trim().is_empty()).and_then(|line| {
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 2 {
                    parts[1].parse::<u16>().ok()
                } else {
                    None
                }
            })
        })
        .unwrap_or(0);

    // Section H: rule matches → determines label. Any of these substrings is
    // taken as evidence of rule activity.
    let has_rule_match = sections
        .get(&'H')
        .map(|lines| {
            lines.iter().any(|l| {
                let lower = l.to_ascii_lowercase();
                lower.contains("matched") || lower.contains("warning") || lower.contains("id:")
            })
        })
        .unwrap_or(false);

    let label = if has_rule_match { "attack" } else { "normal" }.to_string();

    // NOTE(review): the audit log carries no request duration, so duration_ms
    // is recorded as 0; `backend` uses the "-" placeholder.
    let fields = AuditFields {
        method,
        host,
        path,
        query,
        client_ip,
        status,
        duration_ms: 0,
        content_length,
        user_agent,
        has_cookies: Some(has_cookies),
        referer,
        accept_language,
        backend: "-".to_string(),
        label: Some(label.clone()),
    };

    Some((fields, label))
}
|
||||
|
||||
/// Cache directory for ModSec data (mirrors the CSIC caching pattern).
///
/// NOTE(review): duplicates `dataset::download::cache_base` — consider
/// sharing a single helper.
#[allow(dead_code)] // kept for parity with the other dataset modules
fn cache_dir() -> std::path::PathBuf {
    let cache_root = match std::env::var("XDG_CACHE_HOME") {
        Ok(dir) => std::path::PathBuf::from(dir),
        Err(_) => {
            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
            std::path::PathBuf::from(home).join(".cache")
        }
    };
    cache_root.join("sunbeam").join("modsec")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Two transactions: one blocked request (rule match in section H) and one
    // clean browsing request (empty section H).
    const SAMPLE_AUDIT_LOG: &str = r#"--a1b2c3d4-A--
[01/Jan/2026:12:00:00 +0000] XYZ123 192.168.1.100 54321 10.0.0.1 80
--a1b2c3d4-B--
GET /admin/config.php?debug=1 HTTP/1.1
Host: example.com
User-Agent: curl/7.68.0
Accept: */*
--a1b2c3d4-C--
--a1b2c3d4-F--
HTTP/1.1 403 Forbidden
Content-Type: text/html
--a1b2c3d4-H--
Message: Warning. Matched "Operator `Rx' with parameter" [id "941100"]
Action: Intercepted (phase 2)
--a1b2c3d4-Z--
--e5f6a7b8-A--
[01/Jan/2026:12:00:01 +0000] ABC456 10.0.0.50 12345 10.0.0.1 80
--e5f6a7b8-B--
GET /index.html HTTP/1.1
Host: example.com
User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120
Accept: text/html
Accept-Language: en-US,en;q=0.9
Cookie: session=abc123
Referer: https://example.com/
--e5f6a7b8-F--
HTTP/1.1 200 OK
Content-Type: text/html
--e5f6a7b8-H--
--e5f6a7b8-Z--
"#;

    #[test]
    fn test_parse_boundary() {
        assert_eq!(
            parse_boundary("--a1b2c3d4-A--"),
            Some(("a1b2c3d4".to_string(), 'A'))
        );
        assert_eq!(
            parse_boundary("--a1b2c3d4-H--"),
            Some(("a1b2c3d4".to_string(), 'H'))
        );
        assert_eq!(parse_boundary("not a boundary"), None);
        assert_eq!(parse_boundary("--invalid--"), None);
    }

    // Field-by-field check of both transactions in SAMPLE_AUDIT_LOG.
    #[test]
    fn test_parse_modsec_audit_log_snippet() {
        let results = parse_modsec_content(SAMPLE_AUDIT_LOG).unwrap();
        assert_eq!(results.len(), 2, "should parse two transactions");

        // First entry: attack (has rule match in section H).
        let (attack_fields, attack_label) = &results[0];
        assert_eq!(attack_label, "attack");
        assert_eq!(attack_fields.method, "GET");
        assert_eq!(attack_fields.path, "/admin/config.php");
        assert_eq!(attack_fields.query, "debug=1");
        assert_eq!(attack_fields.client_ip, "192.168.1.100");
        assert_eq!(attack_fields.user_agent, "curl/7.68.0");
        assert_eq!(attack_fields.status, 403);
        assert!(!attack_fields.has_cookies.unwrap_or(true));

        // Second entry: normal (no rule match).
        let (normal_fields, normal_label) = &results[1];
        assert_eq!(normal_label, "normal");
        assert_eq!(normal_fields.method, "GET");
        assert_eq!(normal_fields.path, "/index.html");
        assert_eq!(normal_fields.client_ip, "10.0.0.50");
        assert_eq!(normal_fields.status, 200);
        assert!(normal_fields.has_cookies.unwrap_or(false));
        assert!(normal_fields.referer.is_some());
        assert!(normal_fields.accept_language.is_some());
    }

    #[test]
    fn test_empty_input() {
        let results = parse_modsec_content("").unwrap();
        assert!(results.is_empty());
    }

    // A missing section H means no rule activity → labeled "normal".
    #[test]
    fn test_single_entry_no_section_h() {
        let content = r#"--abc123-A--
[01/Jan/2026:12:00:00 +0000] TX1 1.2.3.4 1234 5.6.7.8 80
--abc123-B--
GET / HTTP/1.1
Host: test.com
--abc123-F--
HTTP/1.1 200 OK
--abc123-Z--
"#;
        let results = parse_modsec_content(content).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "normal");
    }
}
|
||||
598
src/dataset/prepare.rs
Normal file
598
src/dataset/prepare.rs
Normal file
@@ -0,0 +1,598 @@
|
||||
//! Dataset preparation orchestrator.
|
||||
//!
|
||||
//! Combines production logs, external datasets (CSIC, OWASP ModSec), and
|
||||
//! synthetic data (CIC-IDS2017 timing profiles + wordlists) into a single
|
||||
//! `DatasetManifest` serialized as bincode.
|
||||
|
||||
use crate::dataset::sample::{DataSource, DatasetManifest, DatasetStats, TrainingSample};
|
||||
use crate::ddos::audit_log::{AuditFields, AuditLog};
|
||||
use crate::ddos::features::{method_to_u8, LogIpState};
|
||||
use crate::ddos::train::HeuristicThresholds;
|
||||
use crate::scanner::features::{self, fx_hash_bytes};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use rustc_hash::FxHashSet;
|
||||
use std::collections::HashMap;
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
|
||||
/// Arguments for the `prepare-dataset` command.
pub struct PrepareDatasetArgs {
    /// Path to production audit log JSONL file.
    pub input: String,
    /// Path to OWASP ModSecurity audit log file (optional extra data).
    pub owasp: Option<String>,
    /// Path to wordlist directory (optional, enhances synthetic scanner).
    pub wordlists: Option<String>,
    /// Output path for the bincode dataset file.
    pub output: String,
    /// RNG seed for synthetic generation (same seed → same synthetic samples).
    pub seed: u64,
    /// Path to heuristics.toml for auto-labeling production logs.
    /// When `None`, `run` falls back to built-in default thresholds.
    pub heuristics: Option<String>,
}
|
||||
|
||||
impl Default for PrepareDatasetArgs {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
input: String::new(),
|
||||
owasp: None,
|
||||
wordlists: None,
|
||||
output: "dataset.bin".to_string(),
|
||||
seed: 42,
|
||||
heuristics: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(args: PrepareDatasetArgs) -> Result<()> {
|
||||
let mut scanner_samples: Vec<TrainingSample> = Vec::new();
|
||||
let mut ddos_samples: Vec<TrainingSample> = Vec::new();
|
||||
|
||||
// --- 0. Download upstream datasets if not cached ---
|
||||
crate::dataset::download::download_all()?;
|
||||
|
||||
// --- 1. Parse production logs ---
|
||||
let heuristics = if let Some(h_path) = &args.heuristics {
|
||||
let content = std::fs::read_to_string(h_path)
|
||||
.with_context(|| format!("reading heuristics from {h_path}"))?;
|
||||
toml::from_str::<HeuristicThresholds>(&content)
|
||||
.with_context(|| format!("parsing heuristics from {h_path}"))?
|
||||
} else {
|
||||
HeuristicThresholds::new(10.0, 0.85, 0.7, 0.3, 0.05, 20.0, 10)
|
||||
};
|
||||
eprintln!("parsing production logs from {}...", args.input);
|
||||
let (prod_scanner, prod_ddos) = parse_production_logs(&args.input, &heuristics)?;
|
||||
eprintln!(
|
||||
" production: {} scanner, {} ddos samples",
|
||||
prod_scanner.len(),
|
||||
prod_ddos.len()
|
||||
);
|
||||
scanner_samples.extend(prod_scanner);
|
||||
ddos_samples.extend(prod_ddos);
|
||||
|
||||
// --- 2. CSIC 2010 (scanner) ---
|
||||
eprintln!("fetching CSIC 2010 dataset...");
|
||||
let csic_entries = crate::scanner::csic::fetch_csic_dataset()?;
|
||||
let csic_samples = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?;
|
||||
eprintln!(" CSIC: {} scanner samples", csic_samples.len());
|
||||
scanner_samples.extend(csic_samples);
|
||||
|
||||
// --- 3. OWASP ModSec (scanner) ---
|
||||
if let Some(owasp_path) = &args.owasp {
|
||||
eprintln!("parsing OWASP ModSec audit log from {owasp_path}...");
|
||||
let modsec_entries =
|
||||
crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?;
|
||||
let entries_with_host: Vec<(AuditFields, String)> = modsec_entries
|
||||
.into_iter()
|
||||
.map(|(fields, _label)| {
|
||||
let host_prefix = fields.host.split('.').next().unwrap_or("").to_string();
|
||||
(fields, host_prefix)
|
||||
})
|
||||
.collect();
|
||||
let modsec_samples =
|
||||
entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?;
|
||||
eprintln!(" OWASP: {} scanner samples", modsec_samples.len());
|
||||
scanner_samples.extend(modsec_samples);
|
||||
}
|
||||
|
||||
// --- 4. CIC-IDS2017 timing profiles (from cache if downloaded) ---
|
||||
let cicids_profiles = if let Some(cached_path) = crate::dataset::download::cicids_cached_path()
|
||||
{
|
||||
eprintln!("extracting CIC-IDS2017 timing profiles from cache...");
|
||||
let profiles = crate::dataset::cicids::extract_timing_profiles(&cached_path)?;
|
||||
eprintln!(" extracted {} attack-type profiles", profiles.len());
|
||||
profiles
|
||||
} else {
|
||||
eprintln!(" CIC-IDS2017 not cached; using built-in DDoS distributions");
|
||||
eprintln!(" (run `download-datasets` first for real timing profiles)");
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// --- 5. Synthetic data (both models, always generated) ---
|
||||
eprintln!("generating synthetic samples...");
|
||||
let config = crate::dataset::synthetic::SyntheticConfig {
|
||||
num_ddos_attack: 10000,
|
||||
num_ddos_normal: 10000,
|
||||
num_scanner_attack: 5000,
|
||||
num_scanner_normal: 5000,
|
||||
seed: args.seed,
|
||||
};
|
||||
|
||||
// Synthetic DDoS (uses CIC-IDS2017 profiles if cached, fallback defaults otherwise).
|
||||
let synthetic_ddos =
|
||||
crate::dataset::synthetic::generate_ddos_samples(&cicids_profiles, &config);
|
||||
eprintln!(" synthetic DDoS: {} samples", synthetic_ddos.len());
|
||||
ddos_samples.extend(synthetic_ddos);
|
||||
|
||||
// Synthetic scanner (uses wordlists if provided, built-in patterns otherwise).
|
||||
let synthetic_scanner = crate::dataset::synthetic::generate_scanner_samples(
|
||||
args.wordlists.as_deref().map(Path::new),
|
||||
None,
|
||||
&config,
|
||||
)?;
|
||||
eprintln!(" synthetic scanner: {} samples", synthetic_scanner.len());
|
||||
scanner_samples.extend(synthetic_scanner);
|
||||
|
||||
// --- 6. Compute stats ---
|
||||
let mut samples_by_source: HashMap<DataSource, usize> = HashMap::new();
|
||||
for s in scanner_samples.iter().chain(ddos_samples.iter()) {
|
||||
*samples_by_source.entry(s.source.clone()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
let scanner_attack_count = scanner_samples.iter().filter(|s| s.label > 0.5).count();
|
||||
let ddos_attack_count = ddos_samples.iter().filter(|s| s.label > 0.5).count();
|
||||
|
||||
let stats = DatasetStats {
|
||||
total_samples: scanner_samples.len() + ddos_samples.len(),
|
||||
scanner_samples: scanner_samples.len(),
|
||||
ddos_samples: ddos_samples.len(),
|
||||
samples_by_source,
|
||||
attack_ratio_scanner: if scanner_samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
scanner_attack_count as f64 / scanner_samples.len() as f64
|
||||
},
|
||||
attack_ratio_ddos: if ddos_samples.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
ddos_attack_count as f64 / ddos_samples.len() as f64
|
||||
},
|
||||
};
|
||||
|
||||
eprintln!("\n--- dataset stats ---");
|
||||
eprintln!("total samples: {}", stats.total_samples);
|
||||
eprintln!(
|
||||
"scanner: {} ({} attack, {:.1}% attack ratio)",
|
||||
stats.scanner_samples,
|
||||
scanner_attack_count,
|
||||
stats.attack_ratio_scanner * 100.0
|
||||
);
|
||||
eprintln!(
|
||||
"ddos: {} ({} attack, {:.1}% attack ratio)",
|
||||
stats.ddos_samples,
|
||||
ddos_attack_count,
|
||||
stats.attack_ratio_ddos * 100.0
|
||||
);
|
||||
for (source, count) in &stats.samples_by_source {
|
||||
eprintln!(" {source}: {count}");
|
||||
}
|
||||
|
||||
let manifest = DatasetManifest {
|
||||
scanner_samples,
|
||||
ddos_samples,
|
||||
stats,
|
||||
};
|
||||
|
||||
crate::dataset::sample::save_dataset(&manifest, Path::new(&args.output))?;
|
||||
eprintln!("\ndataset saved to {}", args.output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse production JSONL logs and produce both scanner and DDoS training samples.
///
/// When logs lack explicit labels, heuristic auto-labeling is applied:
/// - Scanner: attack if status >= 404 AND path matches scanner patterns; normal if
///   status < 400 AND has browser indicators (cookies + referer + accept-language);
///   anything in between is ambiguous and skipped.
/// - DDoS: per-IP feature thresholds from `HeuristicThresholds` (same logic as `ddos/train.rs`).
///
/// Lines that are blank or fail JSON parsing are silently skipped — production
/// logs may contain partial writes.
fn parse_production_logs(
    input: &str,
    heuristics: &HeuristicThresholds,
) -> Result<(Vec<TrainingSample>, Vec<TrainingSample>)> {
    let file = std::fs::File::open(input)
        .with_context(|| format!("opening {input}"))?;
    let reader = std::io::BufReader::new(file);

    let mut scanner_samples = Vec::new();
    // Entries are buffered (with their host prefix) because DDoS extraction
    // needs a second, per-IP pass over the full set.
    let mut parsed_entries: Vec<(AuditFields, String)> = Vec::new();
    // Hashes of every host prefix seen in the log — consulted by the scanner
    // feature extractor to flag requests for unknown hosts.
    let mut log_hosts: FxHashSet<u64> = FxHashSet::default();

    // Build hashes for scanner feature extraction.
    let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
        .iter()
        .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
        .collect();
    let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
        .iter()
        .map(|e| fx_hash_bytes(e.as_bytes()))
        .collect();

    // --- Pass 1: parse JSONL and collect entries + host prefixes ---
    for line in reader.lines() {
        let line = line?;
        if line.trim().is_empty() {
            continue;
        }
        // Malformed lines are skipped rather than aborting the whole run.
        let entry: AuditLog = match serde_json::from_str(&line) {
            Ok(e) => e,
            Err(_) => continue,
        };
        // First DNS label of the host, e.g. "git" from "git.example.com".
        let host_prefix = entry
            .fields
            .host
            .split('.')
            .next()
            .unwrap_or("")
            .to_string();
        log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
        parsed_entries.push((entry.fields, host_prefix));
    }

    // --- Scanner samples from production logs ---
    for (fields, host_prefix) in &parsed_entries {
        let has_cookies = fields.has_cookies.unwrap_or(false);
        // "-" is the log sentinel for an absent header value.
        let has_referer = fields
            .referer
            .as_ref()
            .map(|r| r != "-" && !r.is_empty())
            .unwrap_or(false);
        let has_accept_language = fields
            .accept_language
            .as_ref()
            .map(|a| a != "-" && !a.is_empty())
            .unwrap_or(false);

        let feats = features::extract_features(
            &fields.method,
            &fields.path,
            host_prefix,
            has_cookies,
            has_referer,
            has_accept_language,
            "-",
            &fields.user_agent,
            fields.content_length,
            &fragment_hashes,
            &extension_hashes,
            &log_hosts,
        );

        // Use ground-truth label if present, otherwise heuristic auto-label.
        let label = match fields.label.as_deref() {
            Some("attack" | "anomalous") => Some(1.0f32),
            Some("normal") => Some(0.0f32),
            _ => {
                // Heuristic scanner labeling:
                // Attack: 404+ AND suspicious path (excluding .git which is valid on Gitea hosts).
                // Normal: success status AND browser indicators (cookies + referer + accept-language).
                // Note: 401 is excluded since it's expected for private repos.
                let status = fields.status;
                let path_lower = fields.path.to_ascii_lowercase();
                let is_suspicious_path = path_lower.contains(".env")
                    || path_lower.contains("wp-login")
                    || path_lower.contains("wp-admin")
                    || path_lower.contains("phpmyadmin")
                    || path_lower.contains("cgi-bin")
                    || path_lower.contains("phpinfo")
                    || path_lower.contains("/shell")
                    || path_lower.contains("..%2f")
                    || path_lower.contains("../");
                if status >= 404 && is_suspicious_path {
                    Some(1.0f32)
                } else if status < 400 && has_cookies && has_referer && has_accept_language {
                    Some(0.0f32)
                } else {
                    None // ambiguous — skip
                }
            }
        };

        if let Some(l) = label {
            scanner_samples.push(TrainingSample {
                features: feats.iter().map(|&v| v as f32).collect(),
                label: l,
                source: DataSource::ProductionLogs,
                // Production data is fully trusted: weight 1.0.
                weight: 1.0,
            });
        }
    }

    // --- DDoS samples from production logs ---
    let ddos_samples = extract_ddos_samples_from_entries(&parsed_entries, heuristics)?;

    Ok((scanner_samples, ddos_samples))
}
|
||||
|
||||
/// Extract DDoS feature vectors from parsed log entries using sliding windows.
|
||||
fn extract_ddos_samples_from_entries(
|
||||
entries: &[(AuditFields, String)],
|
||||
heuristics: &HeuristicThresholds,
|
||||
) -> Result<Vec<TrainingSample>> {
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
fn fx_hash(s: &str) -> u64 {
|
||||
let mut h = rustc_hash::FxHasher::default();
|
||||
s.hash(&mut h);
|
||||
h.finish()
|
||||
}
|
||||
|
||||
let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
|
||||
let mut ip_labels: FxHashMap<String, Option<f32>> = FxHashMap::default();
|
||||
|
||||
for (fields, _host_prefix) in entries {
|
||||
let ip = crate::ddos::audit_log::strip_port(&fields.client_ip).to_string();
|
||||
|
||||
let state = ip_states.entry(ip.clone()).or_default();
|
||||
// Use a simple counter as "timestamp" since we don't have parsed timestamps here.
|
||||
let ts = state.timestamps.len() as f64;
|
||||
state.timestamps.push(ts);
|
||||
state.methods.push(method_to_u8(&fields.method));
|
||||
state.path_hashes.push(fx_hash(&fields.path));
|
||||
state.host_hashes.push(fx_hash(&fields.host));
|
||||
state
|
||||
.user_agent_hashes
|
||||
.push(fx_hash(&fields.user_agent));
|
||||
state.statuses.push(fields.status);
|
||||
state
|
||||
.durations
|
||||
.push(fields.duration_ms.min(u32::MAX as u64) as u32);
|
||||
state
|
||||
.content_lengths
|
||||
.push(fields.content_length.min(u32::MAX as u64) as u32);
|
||||
state
|
||||
.has_cookies
|
||||
.push(fields.has_cookies.unwrap_or(false));
|
||||
state.has_referer.push(
|
||||
fields
|
||||
.referer
|
||||
.as_deref()
|
||||
.map(|r| r != "-")
|
||||
.unwrap_or(false),
|
||||
);
|
||||
state.has_accept_language.push(
|
||||
fields
|
||||
.accept_language
|
||||
.as_deref()
|
||||
.map(|a| a != "-")
|
||||
.unwrap_or(false),
|
||||
);
|
||||
state.suspicious_paths.push(
|
||||
crate::ddos::features::is_suspicious_path(&fields.path),
|
||||
);
|
||||
|
||||
// Track label per IP if available.
|
||||
if let Some(ref label_str) = fields.label {
|
||||
let label_val = match label_str.as_str() {
|
||||
"attack" | "anomalous" => Some(1.0f32),
|
||||
"normal" => Some(0.0f32),
|
||||
_ => None,
|
||||
};
|
||||
ip_labels.insert(ip, label_val);
|
||||
}
|
||||
}
|
||||
|
||||
let min_events = heuristics.min_events.max(3);
|
||||
let window_size = 50; // events per sliding window
|
||||
let window_step = 25; // step size (50% overlap)
|
||||
let window_secs = 60.0;
|
||||
let mut samples = Vec::new();
|
||||
|
||||
for (ip, state) in &ip_states {
|
||||
let n = state.timestamps.len();
|
||||
if n < min_events {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ground-truth label if available.
|
||||
let gt_label = match ip_labels.get(ip) {
|
||||
Some(Some(l)) => Some(*l),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Sliding windows: extract multiple feature vectors per IP.
|
||||
// For IPs with fewer events than window_size, use a single window over all events.
|
||||
let mut start = 0;
|
||||
loop {
|
||||
let end = (start + window_size).min(n);
|
||||
if end - start < min_events {
|
||||
break;
|
||||
}
|
||||
|
||||
let fv = state.extract_features_for_window(start, end, window_secs);
|
||||
|
||||
let label = gt_label.unwrap_or_else(|| {
|
||||
// Heuristic DDoS labeling (mirrors ddos/train.rs logic):
|
||||
let is_attack = fv[0] > heuristics.request_rate
|
||||
|| fv[7] > heuristics.path_repetition
|
||||
|| fv[3] > heuristics.error_rate
|
||||
|| fv[13] > heuristics.suspicious_path_ratio
|
||||
|| (fv[10] < heuristics.no_cookies_threshold
|
||||
&& fv[1] > heuristics.no_cookies_path_count);
|
||||
if is_attack { 1.0f32 } else { 0.0f32 }
|
||||
});
|
||||
|
||||
samples.push(TrainingSample {
|
||||
features: fv.iter().map(|&v| v as f32).collect(),
|
||||
label,
|
||||
source: DataSource::ProductionLogs,
|
||||
weight: 1.0,
|
||||
});
|
||||
|
||||
if end >= n {
|
||||
break;
|
||||
}
|
||||
start += window_step;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(samples)
|
||||
}
|
||||
|
||||
/// Convert external dataset entries (CSIC, OWASP) into scanner TrainingSamples.
|
||||
fn entries_to_scanner_samples(
|
||||
entries: &[(AuditFields, String)],
|
||||
source: DataSource,
|
||||
weight: f32,
|
||||
) -> Result<Vec<TrainingSample>> {
|
||||
let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
|
||||
.iter()
|
||||
.map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
|
||||
.collect();
|
||||
let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
|
||||
.iter()
|
||||
.map(|e| fx_hash_bytes(e.as_bytes()))
|
||||
.collect();
|
||||
|
||||
let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
|
||||
for (_, host_prefix) in entries {
|
||||
log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
|
||||
}
|
||||
|
||||
let mut samples = Vec::new();
|
||||
|
||||
for (fields, host_prefix) in entries {
|
||||
let has_cookies = fields.has_cookies.unwrap_or(false);
|
||||
let has_referer = fields
|
||||
.referer
|
||||
.as_ref()
|
||||
.map(|r| r != "-" && !r.is_empty())
|
||||
.unwrap_or(false);
|
||||
let has_accept_language = fields
|
||||
.accept_language
|
||||
.as_ref()
|
||||
.map(|a| a != "-" && !a.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
let feats = features::extract_features(
|
||||
&fields.method,
|
||||
&fields.path,
|
||||
host_prefix,
|
||||
has_cookies,
|
||||
has_referer,
|
||||
has_accept_language,
|
||||
"-",
|
||||
&fields.user_agent,
|
||||
fields.content_length,
|
||||
&fragment_hashes,
|
||||
&extension_hashes,
|
||||
&log_hosts,
|
||||
);
|
||||
|
||||
let label = match fields.label.as_deref() {
|
||||
Some("attack" | "anomalous") => 1.0f32,
|
||||
Some("normal") => 0.0f32,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
samples.push(TrainingSample {
|
||||
features: feats.iter().map(|&v| v as f32).collect(),
|
||||
label,
|
||||
source: source.clone(),
|
||||
weight,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(samples)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ddos::features::NUM_FEATURES;
    use crate::scanner::features::NUM_SCANNER_FEATURES;

    #[test]
    fn test_ddos_sample_feature_count() {
        // Verify that DDoS samples from sliding windows produce NUM_FEATURES features.
        let entries = vec![
            make_test_entry("GET", "/", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/about", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/contact", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/blog", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/faq", "1.2.3.4", 200, "normal"),
        ];
        // min_events = 5 so a single IP with 5 events qualifies for one window.
        let heuristics = HeuristicThresholds::new(10.0, 0.85, 0.7, 0.3, 0.05, 20.0, 5);
        let samples = extract_ddos_samples_from_entries(&entries, &heuristics).unwrap();
        // With only 5 entries at min_events=5, we should get 1 sample.
        assert!(!samples.is_empty(), "should produce DDoS samples");
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_FEATURES,
                "DDoS sample should have {NUM_FEATURES} features"
            );
        }
    }

    #[test]
    fn test_scanner_sample_feature_count() {
        // Both labeled entries should convert and carry full-width feature vectors.
        let entries = vec![
            make_test_entry("GET", "/.env", "1.2.3.4", 404, "attack"),
            make_test_entry("GET", "/index.html", "5.6.7.8", 200, "normal"),
        ];
        let samples =
            entries_to_scanner_samples(&entries, DataSource::ProductionLogs, 1.0).unwrap();
        assert_eq!(samples.len(), 2);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_SCANNER_FEATURES,
                "scanner sample should have {NUM_SCANNER_FEATURES} features"
            );
        }
    }

    #[test]
    fn test_entries_to_scanner_samples_labels() {
        // Labels map: "attack" -> 1.0, "normal" -> 0.0, anything else is dropped.
        let entries = vec![
            make_test_entry("GET", "/.env", "1.2.3.4", 404, "attack"),
            make_test_entry("GET", "/", "5.6.7.8", 200, "normal"),
            make_test_entry("GET", "/page", "9.10.11.12", 200, "unknown_label"),
        ];
        let samples =
            entries_to_scanner_samples(&entries, DataSource::Csic2010, 0.8).unwrap();
        // "unknown_label" should be skipped.
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[0].label, 1.0);
        assert_eq!(samples[1].label, 0.0);
        // Source/weight arguments must be propagated onto each sample.
        assert_eq!(samples[0].weight, 0.8);
        assert_eq!(samples[0].source, DataSource::Csic2010);
    }

    // Build a minimal, browser-like AuditFields entry (cookies, referer,
    // accept-language all present) paired with its host prefix.
    fn make_test_entry(
        method: &str,
        path: &str,
        client_ip: &str,
        status: u16,
        label: &str,
    ) -> (AuditFields, String) {
        let fields = AuditFields {
            method: method.to_string(),
            host: "test.sunbeam.pt".to_string(),
            path: path.to_string(),
            query: String::new(),
            client_ip: client_ip.to_string(),
            status,
            duration_ms: 10,
            content_length: 0,
            user_agent: "Mozilla/5.0".to_string(),
            has_cookies: Some(true),
            referer: Some("https://test.sunbeam.pt".to_string()),
            accept_language: Some("en-US".to_string()),
            backend: "test-svc:8080".to_string(),
            label: Some(label.to_string()),
        };
        (fields, "test".to_string())
    }
}
|
||||
156
src/dataset/sample.rs
Normal file
156
src/dataset/sample.rs
Normal file
@@ -0,0 +1,156 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
|
||||
/// Provenance of a training sample.
///
/// Also used as a `HashMap` key in `DatasetStats::samples_by_source`
/// (hence the `Eq + Hash` derives).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DataSource {
    /// Real traffic parsed from production JSONL audit logs.
    ProductionLogs,
    /// CSIC 2010 labeled HTTP dataset.
    Csic2010,
    /// OWASP ModSecurity audit-log corpus.
    OwaspModSec,
    /// Synthetic DDoS samples driven by CIC-IDS2017 timing profiles.
    SyntheticCicTiming,
    /// Synthetic scanner samples driven by wordlist paths.
    SyntheticWordlist,
}
|
||||
|
||||
impl std::fmt::Display for DataSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DataSource::ProductionLogs => write!(f, "ProductionLogs"),
|
||||
DataSource::Csic2010 => write!(f, "Csic2010"),
|
||||
DataSource::OwaspModSec => write!(f, "OwaspModSec"),
|
||||
DataSource::SyntheticCicTiming => write!(f, "SyntheticCicTiming"),
|
||||
DataSource::SyntheticWordlist => write!(f, "SyntheticWordlist"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single labeled training sample with its feature vector, label, provenance,
/// and weight (used during training to down-weight synthetic/external data).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
    /// Model input vector; length depends on the target model
    /// (scanner vs DDoS feature count).
    pub features: Vec<f32>,
    /// 0.0 = normal, 1.0 = attack
    pub label: f32,
    /// Where this sample originated (see [`DataSource`]).
    pub source: DataSource,
    /// Sample weight: 1.0 for production, 0.8 for external datasets, 0.5 for synthetic.
    pub weight: f32,
}
|
||||
|
||||
/// Aggregate statistics about a prepared dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
    /// Scanner + DDoS sample count combined.
    pub total_samples: usize,
    /// Number of scanner-model samples.
    pub scanner_samples: usize,
    /// Number of DDoS-model samples.
    pub ddos_samples: usize,
    /// Sample counts broken down by provenance (across both models).
    pub samples_by_source: HashMap<DataSource, usize>,
    /// Fraction of scanner samples labeled as attack (0.0 when empty).
    pub attack_ratio_scanner: f64,
    /// Fraction of DDoS samples labeled as attack (0.0 when empty).
    pub attack_ratio_ddos: f64,
}
|
||||
|
||||
/// The full serializable dataset: scanner samples, DDoS samples, and stats.
///
/// Serialized to disk as bincode by [`save_dataset`] / [`load_dataset`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetManifest {
    /// Training samples for the scanner model.
    pub scanner_samples: Vec<TrainingSample>,
    /// Training samples for the DDoS model.
    pub ddos_samples: Vec<TrainingSample>,
    /// Aggregate statistics computed at preparation time.
    pub stats: DatasetStats,
}
|
||||
|
||||
/// Serialize a `DatasetManifest` to a bincode file.
|
||||
pub fn save_dataset(manifest: &DatasetManifest, path: &Path) -> Result<()> {
|
||||
let encoded = bincode::serialize(manifest).context("serializing dataset manifest")?;
|
||||
std::fs::write(path, &encoded)
|
||||
.with_context(|| format!("writing dataset to {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Deserialize a `DatasetManifest` from a bincode file.
|
||||
pub fn load_dataset(path: &Path) -> Result<DatasetManifest> {
|
||||
let data = std::fs::read(path)
|
||||
.with_context(|| format!("reading dataset from {}", path.display()))?;
|
||||
let manifest: DatasetManifest =
|
||||
bincode::deserialize(&data).context("deserializing dataset manifest")?;
|
||||
Ok(manifest)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Terse constructor so each fixture sample fits on one line.
    fn make_sample(features: Vec<f32>, label: f32, source: DataSource, weight: f32) -> TrainingSample {
        TrainingSample { features, label, source, weight }
    }

    // In-memory encode → decode; verifies counts, map keys, and that
    // feature values / labels / sources survive the roundtrip.
    #[test]
    fn test_bincode_roundtrip() {
        let manifest = DatasetManifest {
            scanner_samples: vec![
                make_sample(vec![0.1, 0.2, 0.3], 0.0, DataSource::ProductionLogs, 1.0),
                make_sample(vec![0.9, 0.8, 0.7], 1.0, DataSource::Csic2010, 0.8),
            ],
            ddos_samples: vec![
                make_sample(vec![0.5; 14], 1.0, DataSource::SyntheticCicTiming, 0.5),
                make_sample(vec![0.1; 14], 0.0, DataSource::ProductionLogs, 1.0),
            ],
            stats: DatasetStats {
                total_samples: 4,
                scanner_samples: 2,
                ddos_samples: 2,
                samples_by_source: {
                    let mut m = HashMap::new();
                    m.insert(DataSource::ProductionLogs, 2);
                    m.insert(DataSource::Csic2010, 1);
                    m.insert(DataSource::SyntheticCicTiming, 1);
                    m
                },
                attack_ratio_scanner: 0.5,
                attack_ratio_ddos: 0.5,
            },
        };

        let encoded = bincode::serialize(&manifest).unwrap();
        let decoded: DatasetManifest = bincode::deserialize(&encoded).unwrap();

        assert_eq!(decoded.scanner_samples.len(), 2);
        assert_eq!(decoded.ddos_samples.len(), 2);
        assert_eq!(decoded.stats.total_samples, 4);
        assert_eq!(decoded.stats.samples_by_source[&DataSource::ProductionLogs], 2);

        // Verify feature values survive the roundtrip.
        assert!((decoded.scanner_samples[0].features[0] - 0.1).abs() < 1e-6);
        assert_eq!(decoded.scanner_samples[1].label, 1.0);
        assert_eq!(decoded.ddos_samples[0].source, DataSource::SyntheticCicTiming);
    }

    // Full disk roundtrip through save_dataset/load_dataset using a temp dir.
    #[test]
    fn test_save_load_roundtrip() {
        let manifest = DatasetManifest {
            scanner_samples: vec![
                make_sample(vec![1.0, 2.0], 0.0, DataSource::ProductionLogs, 1.0),
            ],
            ddos_samples: vec![
                make_sample(vec![3.0, 4.0], 1.0, DataSource::OwaspModSec, 0.8),
            ],
            stats: DatasetStats {
                total_samples: 2,
                scanner_samples: 1,
                ddos_samples: 1,
                samples_by_source: HashMap::new(),
                attack_ratio_scanner: 0.0,
                attack_ratio_ddos: 1.0,
            },
        };

        // tempdir is cleaned up on drop, so nothing leaks between test runs.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_dataset.bin");

        save_dataset(&manifest, &path).unwrap();
        let loaded = load_dataset(&path).unwrap();

        assert_eq!(loaded.scanner_samples.len(), 1);
        assert_eq!(loaded.ddos_samples.len(), 1);
        assert!((loaded.scanner_samples[0].features[0] - 1.0).abs() < 1e-6);
        assert_eq!(loaded.ddos_samples[0].source, DataSource::OwaspModSec);
    }
}
|
||||
558
src/dataset/synthetic.rs
Normal file
558
src/dataset/synthetic.rs
Normal file
@@ -0,0 +1,558 @@
|
||||
//! Synthetic data generation for DDoS and scanner training samples.
|
||||
//!
|
||||
//! DDoS synthetic samples are generated by sampling from CIC-IDS2017 timing
|
||||
//! profiles. Scanner synthetic samples are generated from wordlist paths
|
||||
//! and realistic browser-like feature vectors.
|
||||
|
||||
use crate::dataset::cicids::TimingProfile;
|
||||
use crate::dataset::sample::{DataSource, TrainingSample};
|
||||
use crate::ddos::features::NUM_FEATURES;
|
||||
use crate::scanner::features::NUM_SCANNER_FEATURES;
|
||||
|
||||
use anyhow::Result;
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
|
||||
/// Configuration for synthetic sample generation.
#[derive(Debug, Clone)]
pub struct SyntheticConfig {
    /// Number of synthetic DDoS attack-class samples to generate.
    pub num_ddos_attack: usize,
    /// Number of synthetic DDoS normal-class samples to generate.
    pub num_ddos_normal: usize,
    /// Number of synthetic scanner attack-class samples to generate.
    pub num_scanner_attack: usize,
    /// Number of synthetic scanner normal-class samples to generate.
    pub num_scanner_normal: usize,
    /// RNG seed; generation is deterministic for a fixed seed.
    pub seed: u64,
}
|
||||
|
||||
impl Default for SyntheticConfig {
    /// Defaults: 1000 samples per class for each model, seed 42 for
    /// reproducible generation.
    fn default() -> Self {
        Self {
            num_ddos_attack: 1000,
            num_ddos_normal: 1000,
            num_scanner_attack: 1000,
            num_scanner_normal: 1000,
            seed: 42,
        }
    }
}
|
||||
|
||||
/// Generate synthetic DDoS training samples from CIC-IDS2017 timing profiles.
|
||||
///
|
||||
/// Produces 14-dimensional feature vectors matching `NUM_FEATURES`:
|
||||
/// 0: request_rate, 1: unique_paths, 2: unique_hosts, 3: error_rate,
|
||||
/// 4: avg_duration_ms, 5: method_entropy, 6: burst_score, 7: path_repetition,
|
||||
/// 8: avg_content_length, 9: unique_user_agents, 10: cookie_ratio,
|
||||
/// 11: referer_ratio, 12: accept_language_ratio, 13: suspicious_path_ratio
|
||||
pub fn generate_ddos_samples(
|
||||
profiles: &[TimingProfile],
|
||||
config: &SyntheticConfig,
|
||||
) -> Vec<TrainingSample> {
|
||||
let mut rng = StdRng::seed_from_u64(config.seed);
|
||||
let mut samples = Vec::with_capacity(config.num_ddos_attack + config.num_ddos_normal);
|
||||
|
||||
// Generate attack samples.
|
||||
for _ in 0..config.num_ddos_attack {
|
||||
let features = generate_ddos_attack_features(profiles, &mut rng);
|
||||
samples.push(TrainingSample {
|
||||
features,
|
||||
label: 1.0,
|
||||
source: DataSource::SyntheticCicTiming,
|
||||
weight: 0.5,
|
||||
});
|
||||
}
|
||||
|
||||
// Generate normal samples.
|
||||
for _ in 0..config.num_ddos_normal {
|
||||
let features = generate_ddos_normal_features(&mut rng);
|
||||
samples.push(TrainingSample {
|
||||
features,
|
||||
label: 0.0,
|
||||
source: DataSource::SyntheticCicTiming,
|
||||
weight: 0.5,
|
||||
});
|
||||
}
|
||||
|
||||
samples
|
||||
}
|
||||
|
||||
/// Generate a single DDoS attack feature vector, sampling from timing profiles.
///
/// If `profiles` is non-empty, a random CIC-IDS2017 profile is chosen and
/// inter-arrival / burst / bytes-per-second values are drawn from its
/// distributions; otherwise hard-coded aggressive attack ranges are used.
/// Indices follow the `NUM_FEATURES` layout documented on
/// `generate_ddos_samples`.
fn generate_ddos_attack_features(profiles: &[TimingProfile], rng: &mut StdRng) -> Vec<f32> {
    let mut features = vec![0.0f32; NUM_FEATURES];

    // Sample from a random profile if available, otherwise use defaults.
    let (iat_mean, burst_mean, bps_mean) = if !profiles.is_empty() {
        let profile = &profiles[rng.random_range(0..profiles.len())];
        let iat = sample_positive_normal(
            rng,
            profile.inter_arrival_mean,
            profile.inter_arrival_std,
        );
        let burst = sample_positive_normal(
            rng,
            profile.burst_duration_mean,
            profile.burst_duration_std,
        );
        let bps = sample_positive_normal(
            rng,
            profile.flow_bytes_per_sec_mean,
            profile.flow_bytes_per_sec_std,
        );
        (iat, burst, bps)
    } else {
        // Fallback: aggressive attack defaults.
        let iat = rng.random_range(0.001..0.05);
        let burst = rng.random_range(0.5..5.0);
        let bps = rng.random_range(10000.0..100000.0);
        (iat, burst, bps)
    };

    // 0: request_rate — high for attacks (inverse of mean inter-arrival,
    // capped at 1000 req/s; random fallback if the mean is non-positive).
    features[0] = if iat_mean > 0.0 {
        (1.0 / iat_mean).min(1000.0) as f32
    } else {
        rng.random_range(50.0..500.0) as f32
    };

    // 1: unique_paths — low to moderate (attackers repeat paths).
    features[1] = rng.random_range(1.0..10.0) as f32;

    // 2: unique_hosts — typically 1 for DDoS.
    features[2] = rng.random_range(1.0..3.0) as f32;

    // 3: error_rate — moderate to high.
    features[3] = rng.random_range(0.3..0.9) as f32;

    // 4: avg_duration_ms — derived from burst duration, bounded to [1, 5000] ms.
    features[4] = (burst_mean * 100.0).max(1.0).min(5000.0) as f32;

    // 5: method_entropy — low (mostly GET).
    features[5] = rng.random_range(0.0..0.3) as f32;

    // 6: burst_score — high (inverse of inter-arrival).
    features[6] = if iat_mean > 0.0 {
        (1.0 / iat_mean).min(500.0) as f32
    } else {
        rng.random_range(10.0..200.0) as f32
    };

    // 7: path_repetition — high (attackers repeat same paths).
    features[7] = rng.random_range(0.6..1.0) as f32;

    // 8: avg_content_length — derived from flow bytes, bounded to [0, 10000].
    features[8] = (bps_mean * 0.01).max(0.0).min(10000.0) as f32;

    // 9: unique_user_agents — low (1-2 UAs).
    features[9] = rng.random_range(1.0..3.0) as f32;

    // 10: cookie_ratio — very low (bots don't send cookies).
    features[10] = rng.random_range(0.0..0.1) as f32;

    // 11: referer_ratio — very low.
    features[11] = rng.random_range(0.0..0.1) as f32;

    // 12: accept_language_ratio — very low.
    features[12] = rng.random_range(0.0..0.1) as f32;

    // 13: suspicious_path_ratio — moderate to high.
    features[13] = rng.random_range(0.1..0.7) as f32;

    features
}
|
||||
|
||||
/// Generate a single DDoS normal feature vector.
|
||||
fn generate_ddos_normal_features(rng: &mut StdRng) -> Vec<f32> {
|
||||
let mut features = vec![0.0f32; NUM_FEATURES];
|
||||
|
||||
// 0: request_rate — moderate.
|
||||
features[0] = rng.random_range(0.1..5.0) as f32;
|
||||
// 1: unique_paths — moderate.
|
||||
features[1] = rng.random_range(3.0..30.0) as f32;
|
||||
// 2: unique_hosts — typically 1-3.
|
||||
features[2] = rng.random_range(1.0..5.0) as f32;
|
||||
// 3: error_rate — low.
|
||||
features[3] = rng.random_range(0.0..0.15) as f32;
|
||||
// 4: avg_duration_ms — reasonable.
|
||||
features[4] = rng.random_range(10.0..500.0) as f32;
|
||||
// 5: method_entropy — moderate (mix of GET/POST).
|
||||
features[5] = rng.random_range(0.0..1.5) as f32;
|
||||
// 6: burst_score — low.
|
||||
features[6] = rng.random_range(0.05..2.0) as f32;
|
||||
// 7: path_repetition — low to moderate.
|
||||
features[7] = rng.random_range(0.05..0.5) as f32;
|
||||
// 8: avg_content_length — reasonable.
|
||||
features[8] = rng.random_range(0.0..2000.0) as f32;
|
||||
// 9: unique_user_agents — 1-3 (real users have few UAs).
|
||||
features[9] = rng.random_range(1.0..4.0) as f32;
|
||||
// 10: cookie_ratio — high (real users have cookies).
|
||||
features[10] = rng.random_range(0.7..1.0) as f32;
|
||||
// 11: referer_ratio — moderate to high.
|
||||
features[11] = rng.random_range(0.4..1.0) as f32;
|
||||
// 12: accept_language_ratio — high.
|
||||
features[12] = rng.random_range(0.7..1.0) as f32;
|
||||
// 13: suspicious_path_ratio — very low.
|
||||
features[13] = rng.random_range(0.0..0.05) as f32;
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Generate synthetic scanner training samples.
|
||||
///
|
||||
/// Attack samples use paths from wordlist files (if provided) or built-in
|
||||
/// scanner paths. Normal samples have realistic browser-like feature vectors.
|
||||
///
|
||||
/// Produces 12-dimensional feature vectors matching `NUM_SCANNER_FEATURES`:
|
||||
/// 0: suspicious_path_score, 1: path_depth, 2: has_suspicious_extension,
|
||||
/// 3: has_cookies, 4: has_referer, 5: has_accept_language,
|
||||
/// 6: accept_quality, 7: ua_category, 8: method_is_unusual,
|
||||
/// 9: host_is_configured, 10: content_length_mismatch, 11: path_has_traversal
|
||||
pub fn generate_scanner_samples(
|
||||
wordlist_dir: Option<&Path>,
|
||||
_production_logs: Option<&Path>,
|
||||
config: &SyntheticConfig,
|
||||
) -> Result<Vec<TrainingSample>> {
|
||||
let mut rng = StdRng::seed_from_u64(config.seed.wrapping_add(1));
|
||||
let mut samples = Vec::with_capacity(config.num_scanner_attack + config.num_scanner_normal);
|
||||
|
||||
// Load wordlist paths if provided.
|
||||
let attack_paths = load_wordlist_paths(wordlist_dir)?;
|
||||
|
||||
// Generate attack samples.
|
||||
for i in 0..config.num_scanner_attack {
|
||||
let features = if !attack_paths.is_empty() {
|
||||
let path = &attack_paths[i % attack_paths.len()];
|
||||
generate_scanner_attack_features_from_path(path, &mut rng)
|
||||
} else {
|
||||
generate_scanner_attack_features(&mut rng)
|
||||
};
|
||||
samples.push(TrainingSample {
|
||||
features,
|
||||
label: 1.0,
|
||||
source: DataSource::SyntheticWordlist,
|
||||
weight: 0.5,
|
||||
});
|
||||
}
|
||||
|
||||
// Generate normal samples.
|
||||
for _ in 0..config.num_scanner_normal {
|
||||
let features = generate_scanner_normal_features(&mut rng);
|
||||
samples.push(TrainingSample {
|
||||
features,
|
||||
label: 0.0,
|
||||
source: DataSource::SyntheticWordlist,
|
||||
weight: 0.5,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(samples)
|
||||
}
|
||||
|
||||
/// Load paths from wordlist files in a directory (or single file).
|
||||
fn load_wordlist_paths(dir: Option<&Path>) -> Result<Vec<String>> {
|
||||
let dir = match dir {
|
||||
Some(d) => d,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let files: Vec<std::path::PathBuf> = if dir.is_file() {
|
||||
vec![dir.to_path_buf()]
|
||||
} else if dir.exists() {
|
||||
std::fs::read_dir(dir)?
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.path())
|
||||
.filter(|p| p.extension().map(|e| e == "txt").unwrap_or(false))
|
||||
.collect()
|
||||
} else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
|
||||
let mut paths = Vec::new();
|
||||
for file in &files {
|
||||
let f = std::fs::File::open(file)?;
|
||||
let reader = std::io::BufReader::new(f);
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let line = line.trim().to_string();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
let normalized = if line.starts_with('/') {
|
||||
line
|
||||
} else {
|
||||
format!("/{line}")
|
||||
};
|
||||
paths.push(normalized);
|
||||
}
|
||||
}
|
||||
Ok(paths)
|
||||
}
|
||||
|
||||
/// Generate scanner attack features from a specific path.
|
||||
fn generate_scanner_attack_features_from_path(path: &str, rng: &mut StdRng) -> Vec<f32> {
|
||||
let mut features = vec![0.0f32; NUM_SCANNER_FEATURES];
|
||||
|
||||
let lower = path.to_ascii_lowercase();
|
||||
|
||||
// 0: suspicious_path_score — check for known bad fragments.
|
||||
let suspicious_frags = [
|
||||
".env", "wp-admin", "wp-login", "phpinfo", "phpmyadmin",
|
||||
".git", "cgi-bin", "shell", "admin", "config",
|
||||
];
|
||||
let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
|
||||
let seg_count = segments.len().max(1) as f32;
|
||||
let matches = segments
|
||||
.iter()
|
||||
.filter(|s| {
|
||||
let sl = s.to_ascii_lowercase();
|
||||
suspicious_frags.iter().any(|f| sl.contains(f))
|
||||
})
|
||||
.count() as f32;
|
||||
features[0] = matches / seg_count;
|
||||
|
||||
// 1: path_depth.
|
||||
features[1] = path.bytes().filter(|&b| b == b'/').count().min(20) as f32;
|
||||
|
||||
// 2: has_suspicious_extension.
|
||||
let suspicious_exts = [".php", ".env", ".sql", ".bak", ".asp", ".jsp", ".cgi", ".tar", ".zip", ".git"];
|
||||
features[2] = if suspicious_exts.iter().any(|e| lower.ends_with(e)) {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// 3: has_cookies — scanners typically don't.
|
||||
features[3] = if rng.random_bool(0.05) { 1.0 } else { 0.0 };
|
||||
|
||||
// 4: has_referer — scanners rarely send referer.
|
||||
features[4] = if rng.random_bool(0.05) { 1.0 } else { 0.0 };
|
||||
|
||||
// 5: has_accept_language — scanners rarely send this.
|
||||
features[5] = if rng.random_bool(0.1) { 1.0 } else { 0.0 };
|
||||
|
||||
// 6: accept_quality — scanners use */* or empty.
|
||||
features[6] = 0.0;
|
||||
|
||||
// 7: ua_category — mix of curl/empty/bot UAs.
|
||||
let ua_roll: f64 = rng.random();
|
||||
features[7] = if ua_roll < 0.3 {
|
||||
0.0 // empty
|
||||
} else if ua_roll < 0.6 {
|
||||
0.25 // curl/wget
|
||||
} else {
|
||||
0.5 // random bot
|
||||
};
|
||||
|
||||
// 8: method_is_unusual — mostly GET.
|
||||
features[8] = if rng.random_bool(0.05) { 1.0 } else { 0.0 };
|
||||
|
||||
// 9: host_is_configured — often unknown host.
|
||||
features[9] = if rng.random_bool(0.2) { 1.0 } else { 0.0 };
|
||||
|
||||
// 10: content_length_mismatch.
|
||||
features[10] = if rng.random_bool(0.1) { 1.0 } else { 0.0 };
|
||||
|
||||
// 11: path_has_traversal.
|
||||
let traversal_patterns = ["..", "%00", "%0a", "%27", "%3c"];
|
||||
features[11] = if traversal_patterns.iter().any(|p| lower.contains(p)) {
|
||||
1.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Generate scanner attack features without a specific path (built-in defaults).
|
||||
fn generate_scanner_attack_features(rng: &mut StdRng) -> Vec<f32> {
|
||||
let mut features = vec![0.0f32; NUM_SCANNER_FEATURES];
|
||||
|
||||
features[0] = rng.random_range(0.3..1.0) as f32; // suspicious_path_score
|
||||
features[1] = rng.random_range(2.0..8.0) as f32; // path_depth
|
||||
features[2] = if rng.random_bool(0.6) { 1.0 } else { 0.0 }; // suspicious_ext
|
||||
features[3] = if rng.random_bool(0.05) { 1.0 } else { 0.0 }; // cookies
|
||||
features[4] = if rng.random_bool(0.05) { 1.0 } else { 0.0 }; // referer
|
||||
features[5] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; // accept_language
|
||||
features[6] = 0.0; // accept_quality
|
||||
features[7] = rng.random_range(0.0..0.5) as f32; // ua_category
|
||||
features[8] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; // unusual method
|
||||
features[9] = if rng.random_bool(0.2) { 1.0 } else { 0.0 }; // host_configured
|
||||
features[10] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; // content_len mismatch
|
||||
features[11] = if rng.random_bool(0.15) { 1.0 } else { 0.0 }; // traversal
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Generate normal scanner features (browser-like).
|
||||
fn generate_scanner_normal_features(rng: &mut StdRng) -> Vec<f32> {
|
||||
let mut features = vec![0.0f32; NUM_SCANNER_FEATURES];
|
||||
|
||||
features[0] = 0.0; // suspicious_path_score — clean path
|
||||
features[1] = rng.random_range(1.0..4.0) as f32; // path_depth — moderate
|
||||
features[2] = 0.0; // no suspicious extension
|
||||
features[3] = if rng.random_bool(0.85) { 1.0 } else { 0.0 }; // has_cookies — usually yes
|
||||
features[4] = if rng.random_bool(0.7) { 1.0 } else { 0.0 }; // has_referer — often
|
||||
features[5] = if rng.random_bool(0.9) { 1.0 } else { 0.0 }; // has_accept_language
|
||||
features[6] = 1.0; // accept_quality — browsers send proper accept
|
||||
features[7] = 1.0; // ua_category — browser UA
|
||||
features[8] = 0.0; // method_is_unusual — GET/POST
|
||||
features[9] = if rng.random_bool(0.95) { 1.0 } else { 0.0 }; // host_is_configured
|
||||
features[10] = 0.0; // no content_length_mismatch
|
||||
features[11] = 0.0; // no traversal
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
/// Sample from a normal distribution, clamped to positive values.
|
||||
fn sample_positive_normal(rng: &mut StdRng, mean: f64, std_dev: f64) -> f64 {
|
||||
if std_dev <= 0.0 {
|
||||
return mean.max(0.0);
|
||||
}
|
||||
// Box-Muller transform.
|
||||
let u1: f64 = rng.random::<f64>().max(1e-10);
|
||||
let u2: f64 = rng.random();
|
||||
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
|
||||
(mean + std_dev * z).max(0.0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // DDoS vectors generated from a timing profile must match NUM_FEATURES.
    #[test]
    fn test_ddos_feature_dimensions() {
        let profiles = vec![TimingProfile {
            attack_type: "DDoS".to_string(),
            inter_arrival_mean: 0.01,
            inter_arrival_std: 0.005,
            burst_duration_mean: 2.0,
            burst_duration_std: 1.0,
            flow_bytes_per_sec_mean: 50000.0,
            flow_bytes_per_sec_std: 20000.0,
            sample_count: 100,
        }];

        let config = SyntheticConfig {
            num_ddos_attack: 10,
            num_ddos_normal: 10,
            seed: 42,
            ..Default::default()
        };

        let samples = generate_ddos_samples(&profiles, &config);
        // 10 attack + 10 normal samples requested.
        assert_eq!(samples.len(), 20);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_FEATURES,
                "DDoS features should have {} dimensions",
                NUM_FEATURES
            );
        }
    }

    // Scanner vectors (no wordlist, no logs) must match NUM_SCANNER_FEATURES.
    #[test]
    fn test_scanner_feature_dimensions() {
        let config = SyntheticConfig {
            num_scanner_attack: 10,
            num_scanner_normal: 10,
            seed: 42,
            ..Default::default()
        };

        let samples = generate_scanner_samples(None, None, &config).unwrap();
        // 10 attack + 10 normal samples requested.
        assert_eq!(samples.len(), 20);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_SCANNER_FEATURES,
                "scanner features should have {} dimensions",
                NUM_SCANNER_FEATURES
            );
        }
    }

    // Two generation runs with the same seed must produce identical samples
    // (feature vectors and labels), for both the DDoS and scanner generators.
    #[test]
    fn test_deterministic_with_fixed_seed() {
        let config = SyntheticConfig {
            num_ddos_attack: 5,
            num_ddos_normal: 5,
            num_scanner_attack: 5,
            num_scanner_normal: 5,
            seed: 123,
        };

        let ddos_a = generate_ddos_samples(&[], &config);
        let ddos_b = generate_ddos_samples(&[], &config);

        for (a, b) in ddos_a.iter().zip(ddos_b.iter()) {
            assert_eq!(a.features, b.features, "DDoS samples should be deterministic");
            assert_eq!(a.label, b.label);
        }

        let scanner_a = generate_scanner_samples(None, None, &config).unwrap();
        let scanner_b = generate_scanner_samples(None, None, &config).unwrap();

        for (a, b) in scanner_a.iter().zip(scanner_b.iter()) {
            assert_eq!(a.features, b.features, "scanner samples should be deterministic");
            assert_eq!(a.label, b.label);
        }
    }

    // Sanity check that the two classes are separable on key features:
    // attacks should show a higher request rate and a lower cookie ratio
    // than normal traffic, on average.
    #[test]
    fn test_ddos_attack_vs_normal_distinguishable() {
        let config = SyntheticConfig {
            num_ddos_attack: 100,
            num_ddos_normal: 100,
            seed: 42,
            ..Default::default()
        };

        let samples = generate_ddos_samples(&[], &config);
        // Split by label: > 0.5 is attack, otherwise normal.
        let attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label > 0.5).collect();
        let normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label <= 0.5).collect();

        // Attack samples should have higher avg request_rate (feature 0).
        let avg_attack_rate: f32 =
            attacks.iter().map(|s| s.features[0]).sum::<f32>() / attacks.len() as f32;
        let avg_normal_rate: f32 =
            normals.iter().map(|s| s.features[0]).sum::<f32>() / normals.len() as f32;
        assert!(
            avg_attack_rate > avg_normal_rate,
            "attack rate {} should exceed normal rate {}",
            avg_attack_rate,
            avg_normal_rate
        );

        // Attack samples should have lower avg cookie_ratio (feature 10).
        let avg_attack_cookies: f32 =
            attacks.iter().map(|s| s.features[10]).sum::<f32>() / attacks.len() as f32;
        let avg_normal_cookies: f32 =
            normals.iter().map(|s| s.features[10]).sum::<f32>() / normals.len() as f32;
        assert!(
            avg_attack_cookies < avg_normal_cookies,
            "attack cookies {} should be lower than normal cookies {}",
            avg_attack_cookies,
            avg_normal_cookies
        );
    }

    // The clamp in sample_positive_normal must hold even with a large
    // std_dev relative to the mean (many raw deviates will be negative).
    #[test]
    fn test_sample_positive_normal_stays_positive() {
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..1000 {
            let val = sample_positive_normal(&mut rng, 1.0, 5.0);
            assert!(val >= 0.0, "value should be non-negative: {val}");
        }
    }

    // Wordlist-derived features: "/.env" must raise the suspicious path
    // score and trip the suspicious-extension flag.
    #[test]
    fn test_scanner_attack_from_wordlist_path() {
        let mut rng = StdRng::seed_from_u64(42);
        let features = generate_scanner_attack_features_from_path("/.env", &mut rng);
        assert_eq!(features.len(), NUM_SCANNER_FEATURES);
        // .env should trigger suspicious_path_score > 0.
        assert!(features[0] > 0.0, "suspicious_path_score should be positive for /.env");
        // .env should trigger suspicious_extension.
        assert_eq!(features[2], 1.0, "should detect .env as suspicious extension");
    }
}
|
||||
Reference in New Issue
Block a user