feat(dataset): add dataset preparation with auto-download and heuristic labeling

Unified prepare-dataset pipeline that automatically downloads and caches
upstream datasets (CSIC 2010, CIC-IDS2017), applies heuristic auto-labeling
to unlabeled production logs, generates synthetic samples for both models,
and serializes everything as a bincode DatasetManifest. Includes OWASP
ModSec parser, CIC-IDS2017 timing profile extractor, and synthetic data
generators with configurable distributions.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:21 +00:00
parent 067d822244
commit 1f4366566d
7 changed files with 2078 additions and 0 deletions

301
src/dataset/cicids.rs Normal file
View File

@@ -0,0 +1,301 @@
//! CIC-IDS2017 timing profile extractor.
//!
//! Parses CIC-IDS2017 CSV files and extracts statistical timing profiles
//! per attack type. These profiles are NOT training samples themselves —
//! they feed the synthetic data generator (`synthetic.rs`) which samples
//! from the learned distributions to produce DDoS training features.
use anyhow::{Context, Result};
use std::path::Path;
/// Statistical timing profile for one attack type from CIC-IDS2017.
///
/// Units: `inter_arrival_*` and `burst_duration_*` are in seconds (the parser
/// converts the dataset's microsecond columns by dividing by 1_000_000);
/// `flow_bytes_per_sec_*` are in bytes per second, taken as-is from the CSV.
#[derive(Debug, Clone)]
pub struct TimingProfile {
    /// Value of the CSV `Label` column (e.g. "BENIGN", "DDoS").
    pub attack_type: String,
    /// Mean of per-flow `Flow IAT Mean` values, in seconds.
    pub inter_arrival_mean: f64,
    /// Sample standard deviation of per-flow `Flow IAT Mean` values, in seconds.
    pub inter_arrival_std: f64,
    /// Mean of `Flow Duration` values, in seconds.
    pub burst_duration_mean: f64,
    /// Sample standard deviation of `Flow Duration` values, in seconds.
    pub burst_duration_std: f64,
    /// Mean of `Flow Bytes/s` values.
    pub flow_bytes_per_sec_mean: f64,
    /// Sample standard deviation of `Flow Bytes/s` values.
    pub flow_bytes_per_sec_std: f64,
    /// Number of CSV rows seen for this label — counted per record, even when
    /// a row's numeric cells fail to parse (see `LabelAccumulator`).
    pub sample_count: usize,
}
/// Single-pass mean/variance accumulator (Welford's online algorithm).
///
/// Non-finite inputs (NaN, ±inf) are silently ignored so a few corrupt CSV
/// cells cannot poison an entire profile.
#[derive(Default)]
struct StatsAccumulator {
    count: usize,
    mean: f64,
    m2: f64,
}

impl StatsAccumulator {
    /// Fold one observation into the running statistics.
    fn push(&mut self, value: f64) {
        if value.is_finite() {
            self.count += 1;
            let d1 = value - self.mean;
            self.mean += d1 / self.count as f64;
            // Welford update: uses the delta both before and after the mean moves.
            self.m2 += d1 * (value - self.mean);
        }
    }

    /// Running arithmetic mean (0.0 before any finite value is pushed).
    fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample standard deviation (Bessel-corrected, n-1 denominator).
    /// Returns 0.0 when fewer than two values have been accumulated.
    fn std_dev(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => (self.m2 / (n - 1) as f64).sqrt(),
        }
    }
}
/// Per-label accumulator for timing statistics.
///
/// Bundles one `StatsAccumulator` per metric plus a raw row count. Note that
/// `count` tracks every CSV row seen for the label, while each inner
/// accumulator only counts rows whose cell parsed to a finite number — the
/// two can legitimately diverge on dirty data.
struct LabelAccumulator {
    // CSV `Label` value this accumulator belongs to.
    label: String,
    // Per-flow `Flow IAT Mean` values (seconds).
    inter_arrival: StatsAccumulator,
    // `Flow Duration` values (seconds).
    burst_duration: StatsAccumulator,
    // `Flow Bytes/s` values.
    flow_bytes_per_sec: StatsAccumulator,
    // Total rows seen: incremented once per record, regardless of parse success.
    count: usize,
}

impl LabelAccumulator {
    /// Create an empty accumulator for the given label.
    fn new(label: String) -> Self {
        Self {
            label,
            inter_arrival: StatsAccumulator::default(),
            burst_duration: StatsAccumulator::default(),
            flow_bytes_per_sec: StatsAccumulator::default(),
            count: 0,
        }
    }

    /// Consume the accumulator and freeze its statistics into a `TimingProfile`.
    fn into_profile(self) -> TimingProfile {
        TimingProfile {
            attack_type: self.label,
            inter_arrival_mean: self.inter_arrival.mean(),
            inter_arrival_std: self.inter_arrival.std_dev(),
            burst_duration_mean: self.burst_duration.mean(),
            burst_duration_std: self.burst_duration.std_dev(),
            flow_bytes_per_sec_mean: self.flow_bytes_per_sec.mean(),
            flow_bytes_per_sec_std: self.flow_bytes_per_sec.std_dev(),
            sample_count: self.count,
        }
    }
}
/// Locate a header column by name, ignoring ASCII case and surrounding
/// whitespace. Returns the zero-based index of the first match, or `None`.
fn find_column(headers: &[String], name: &str) -> Option<usize> {
    for (idx, header) in headers.iter().enumerate() {
        if header.trim().eq_ignore_ascii_case(name) {
            return Some(idx);
        }
    }
    None
}
/// Extract timing profiles from all CSV files in the given directory.
///
/// Each CSV is expected to have CIC-IDS2017 columns including at minimum:
/// `Flow Duration`, `Flow IAT Mean`, `Flow IAT Std`, `Flow Bytes/s`, `Label`.
///
/// A single file path is also accepted and treated as a one-element list.
/// Returns one `TimingProfile` per unique label value, sorted by label.
pub fn extract_timing_profiles(csv_dir: &Path) -> Result<Vec<TimingProfile>> {
    let mut entries: Vec<std::path::PathBuf> = Vec::new();
    if csv_dir.is_file() {
        entries.push(csv_dir.to_path_buf());
    } else {
        // Directory case: collect every *.csv (case-insensitive extension),
        // sorted for deterministic processing order.
        let dir_iter = std::fs::read_dir(csv_dir)
            .with_context(|| format!("reading directory {}", csv_dir.display()))?;
        for dir_entry in dir_iter {
            let path = match dir_entry {
                Ok(e) => e.path(),
                Err(_) => continue,
            };
            let is_csv = path
                .extension()
                .map(|ext| ext.to_ascii_lowercase() == "csv")
                .unwrap_or(false);
            if is_csv {
                entries.push(path);
            }
        }
        entries.sort();
    }
    if entries.is_empty() {
        anyhow::bail!("no CSV files found in {}", csv_dir.display());
    }
    let mut accumulators = std::collections::HashMap::<String, LabelAccumulator>::new();
    for csv_path in &entries {
        parse_csv_file(csv_path, &mut accumulators)
            .with_context(|| format!("parsing {}", csv_path.display()))?;
    }
    // Drop labels that never saw a row, then emit in stable label order.
    let mut profiles = Vec::with_capacity(accumulators.len());
    for acc in accumulators.into_values() {
        if acc.count > 0 {
            profiles.push(acc.into_profile());
        }
    }
    profiles.sort_by(|a, b| a.attack_type.cmp(&b.attack_type));
    Ok(profiles)
}
/// Parse one CIC-IDS2017 CSV file, folding each row's timing columns into the
/// per-label accumulators.
///
/// Rows that fail CSV parsing or have an empty label are skipped; individual
/// numeric cells that fail to parse are skipped per-cell (the row still
/// counts toward `LabelAccumulator::count`).
///
/// Note: the `Flow IAT Std` column is intentionally not consumed — the
/// profile's `inter_arrival_std` is the cross-flow standard deviation of the
/// per-flow `Flow IAT Mean` values, computed by the Welford accumulator.
/// (The previous version looked the column up and parsed every row's value
/// only to discard it; that dead work has been removed.)
fn parse_csv_file(
    path: &Path,
    accumulators: &mut std::collections::HashMap<String, LabelAccumulator>,
) -> Result<()> {
    let mut rdr = csv::ReaderBuilder::new()
        .flexible(true)
        .trim(csv::Trim::All)
        .from_path(path)?;
    let headers: Vec<String> = rdr
        .headers()?
        .iter()
        .map(|h| h.to_string())
        .collect();
    // Locate columns. Only `Label` is mandatory; the numeric columns are
    // optional and simply contribute nothing when absent.
    let col_label = find_column(&headers, "Label")
        .with_context(|| format!("missing 'Label' column in {}", path.display()))?;
    let col_flow_duration = find_column(&headers, "Flow Duration");
    let col_iat_mean = find_column(&headers, "Flow IAT Mean");
    let col_bytes_per_sec = find_column(&headers, "Flow Bytes/s");
    for result in rdr.records() {
        let record = match result {
            Ok(r) => r,
            Err(_) => continue,
        };
        let label = match record.get(col_label) {
            Some(l) => l.trim().to_string(),
            None => continue,
        };
        if label.is_empty() {
            continue;
        }
        let acc = accumulators
            .entry(label.clone())
            .or_insert_with(|| LabelAccumulator::new(label));
        acc.count += 1;
        // Flow Duration (microseconds in CIC-IDS2017) → burst duration, in seconds.
        if let Some(col) = col_flow_duration {
            if let Some(val) = record.get(col).and_then(|v| v.trim().parse::<f64>().ok()) {
                acc.burst_duration.push(val / 1_000_000.0);
            }
        }
        // Flow IAT Mean (microseconds) → inter-arrival time, in seconds.
        if let Some(col) = col_iat_mean {
            if let Some(val) = record.get(col).and_then(|v| v.trim().parse::<f64>().ok()) {
                acc.inter_arrival.push(val / 1_000_000.0);
            }
        }
        // Flow Bytes/s (already bytes per second; no unit conversion).
        if let Some(col) = col_bytes_per_sec {
            if let Some(val) = record.get(col).and_then(|v| v.trim().parse::<f64>().ok()) {
                acc.flow_bytes_per_sec.push(val);
            }
        }
    }
    Ok(())
}
/// Parse timing profiles from an in-memory CSV string (useful for tests).
///
/// Writes the content to a file inside a fresh temp directory and delegates
/// to [`extract_timing_profiles`] (single-file branch). The temp directory is
/// removed when `dir` drops at the end of this function.
pub fn extract_timing_profiles_from_str(csv_content: &str) -> Result<Vec<TimingProfile>> {
    let dir = tempfile::tempdir()?;
    let path = dir.path().join("test.csv");
    std::fs::write(&path, csv_content)?;
    extract_timing_profiles(&path)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end parse of a small inline CSV: two labels, verified means.
    /// The CSV body lines must stay at column 0 — any indentation would
    /// become part of the field values inside the string literal.
    #[test]
    fn test_parse_inline_csv() {
        let csv = "\
Flow Duration,Total Fwd Packets,Flow Bytes/s,Flow IAT Mean,Flow IAT Std,Label
1000000,10,5000.0,100000,50000,BENIGN
2000000,20,10000.0,200000,100000,BENIGN
500000,100,50000.0,5000,2000,DDoS
300000,200,80000.0,3000,1000,DDoS
100000,50,30000.0,10000,5000,DDoS
";
        let profiles = extract_timing_profiles_from_str(csv).unwrap();
        assert_eq!(profiles.len(), 2, "should have BENIGN and DDoS profiles");
        let benign = profiles.iter().find(|p| p.attack_type == "BENIGN").unwrap();
        assert_eq!(benign.sample_count, 2);
        // IAT mean: mean of [0.1, 0.2] = 0.15 seconds
        assert!(
            (benign.inter_arrival_mean - 0.15).abs() < 1e-6,
            "benign iat mean: {}",
            benign.inter_arrival_mean
        );
        // Burst duration mean: mean of [1.0, 2.0] = 1.5 seconds
        assert!(
            (benign.burst_duration_mean - 1.5).abs() < 1e-6,
            "benign burst mean: {}",
            benign.burst_duration_mean
        );
        let ddos = profiles.iter().find(|p| p.attack_type == "DDoS").unwrap();
        assert_eq!(ddos.sample_count, 3);
        // Flow bytes/s mean: mean of [50000, 80000, 30000] = 53333.33
        let expected_bps = (50000.0 + 80000.0 + 30000.0) / 3.0;
        assert!(
            (ddos.flow_bytes_per_sec_mean - expected_bps).abs() < 1.0,
            "ddos bps mean: {}",
            ddos.flow_bytes_per_sec_mean
        );
    }

    /// Welford accumulator: mean and sample std-dev of a tiny known series.
    #[test]
    fn test_stats_accumulator() {
        let mut acc = StatsAccumulator::default();
        acc.push(2.0);
        acc.push(4.0);
        acc.push(6.0);
        assert!((acc.mean() - 4.0).abs() < 1e-10);
        // std_dev of [2,4,6] = 2.0
        assert!((acc.std_dev() - 2.0).abs() < 1e-10);
    }

    /// A directory with no CSV files must produce an error, not an empty Vec.
    #[test]
    fn test_empty_csv_dir() {
        let dir = tempfile::tempdir().unwrap();
        let result = extract_timing_profiles(dir.path());
        assert!(result.is_err());
    }

    /// Header lookup must ignore case and surrounding whitespace.
    #[test]
    fn test_find_column_case_insensitive() {
        let headers: Vec<String> = vec![
            " Flow Duration ".to_string(),
            "label".to_string(),
        ];
        assert_eq!(find_column(&headers, "Flow Duration"), Some(0));
        assert_eq!(find_column(&headers, "Label"), Some(1));
        assert_eq!(find_column(&headers, "LABEL"), Some(1));
        assert_eq!(find_column(&headers, "missing"), None);
    }
}

119
src/dataset/download.rs Normal file
View File

@@ -0,0 +1,119 @@
//! Download and cache upstream datasets for training.
//!
//! Cached under `~/.cache/sunbeam/<dataset>/`. Files are only downloaded
//! once; subsequent runs reuse the cached copy.
use anyhow::{Context, Result};
use std::path::PathBuf;
/// Base cache directory for all sunbeam datasets.
///
/// Resolution order: `$XDG_CACHE_HOME`, then `$HOME/.cache`, then
/// `/tmp/.cache` as a last resort — always suffixed with `sunbeam`.
fn cache_base() -> PathBuf {
    let base = match std::env::var("XDG_CACHE_HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => {
            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
            PathBuf::from(home).join(".cache")
        }
    };
    base.join("sunbeam")
}
// --- CIC-IDS2017 ---
/// Only the Friday DDoS file — contains DDoS Hulk, Slowloris, slowhttptest, GoldenEye.
const CICIDS_FILE: &str = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv";
/// Hugging Face mirror (public, no auth required).
const CICIDS_BASE_URL: &str =
    "https://huggingface.co/datasets/c01dsnap/CIC-IDS2017/resolve/main";

/// Cache directory for CIC-IDS2017 files: `<cache_base>/cicids`.
fn cicids_cache_dir() -> PathBuf {
    cache_base().join("cicids")
}
/// Return the path to the cached CIC-IDS2017 DDoS CSV, or `None` if not downloaded.
pub fn cicids_cached_path() -> Option<PathBuf> {
    let path = cicids_cache_dir().join(CICIDS_FILE);
    // `exists()` runs before `then_some` takes ownership of the path.
    path.exists().then_some(path)
}
/// Download the CIC-IDS2017 Friday DDoS CSV to cache. Returns the cached path.
pub fn download_cicids() -> Result<PathBuf> {
let dir = cicids_cache_dir();
let path = dir.join(CICIDS_FILE);
if path.exists() {
eprintln!(" cached: {}", path.display());
return Ok(path);
}
let url = format!("{CICIDS_BASE_URL}/{CICIDS_FILE}");
eprintln!(" downloading: {url}");
eprintln!(" (this is ~170 MB, may take a minute)");
std::fs::create_dir_all(&dir)?;
// Stream to file to avoid holding 170MB in memory.
let resp = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(600))
.build()?
.get(&url)
.send()
.with_context(|| format!("fetching {url}"))?
.error_for_status()
.with_context(|| format!("HTTP error for {url}"))?;
let mut file = std::fs::File::create(&path)
.with_context(|| format!("creating {}", path.display()))?;
let bytes = resp.bytes().with_context(|| "reading response body")?;
std::io::Write::write_all(&mut file, &bytes)?;
eprintln!(" saved: {}", path.display());
Ok(path)
}
// --- CSIC 2010 ---
/// Download CSIC 2010 dataset files to cache (delegates to scanner::csic).
///
/// No-op when the dataset is already cached. The delegate also parses the
/// data; its parsed result is discarded here since only the download/cache
/// side-effect is wanted.
pub fn download_csic() -> Result<()> {
    if crate::scanner::csic::csic_is_cached() {
        eprintln!(" cached: {}", crate::scanner::csic::csic_cache_path().display());
        return Ok(());
    }
    // fetch_csic_dataset downloads, caches, and parses — we only need the download side-effect.
    crate::scanner::csic::fetch_csic_dataset()?;
    Ok(())
}
/// Download all upstream datasets.
///
/// Runs CSIC 2010 first, then CIC-IDS2017; each step reuses its cached copy
/// when present and the whole function aborts on the first failure.
pub fn download_all() -> Result<()> {
    eprintln!("downloading upstream datasets...\n");

    eprintln!("[1/2] CSIC 2010 (scanner training data)");
    download_csic()?;
    eprintln!();

    eprintln!("[2/2] CIC-IDS2017 DDoS timing profiles");
    let cicids_path = download_cicids()?;
    eprintln!(" ok: {}\n", cicids_path.display());

    eprintln!("all datasets cached.");
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the cache path components; touches no filesystem state.
    #[test]
    fn test_cache_paths() {
        let base = cache_base();
        assert!(base.to_str().unwrap().contains("sunbeam"));
        let cicids = cicids_cache_dir();
        assert!(cicids.to_str().unwrap().contains("cicids"));
    }
}

6
src/dataset/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
/// CIC-IDS2017 timing profile extraction.
pub mod cicids;
/// Download and caching of upstream datasets.
pub mod download;
/// OWASP ModSecurity audit log parser.
pub mod modsec;
/// prepare-dataset pipeline orchestrator.
pub mod prepare;
/// Training sample types and manifest serialization.
pub mod sample;
/// Synthetic sample generators.
pub mod synthetic;

340
src/dataset/modsec.rs Normal file
View File

@@ -0,0 +1,340 @@
//! Parser for OWASP ModSecurity audit log files (Serial / concurrent format).
//!
//! ModSecurity audit logs consist of multi-section entries delimited by boundary
//! markers like `--xxxxxxxx-A--`, `--xxxxxxxx-B--`, etc. Each section contains
//! different data about the transaction:
//!
//! - **A**: Timestamp, transaction ID, source/dest IP+port
//! - **B**: Request line + headers
//! - **C**: Request body
//! - **F**: Response status + headers
//! - **H**: Audit log trailer (rule matches, messages, actions)
//!
//! Any entry with a rule match in section H is labeled "attack"; entries with
//! no rule matches are labeled "normal".
use crate::ddos::audit_log::AuditFields;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
/// Parse a ModSecurity audit log file and return `(AuditFields, label)` pairs.
///
/// The label is `"attack"` if section H contains rule match messages, otherwise `"normal"`.
///
/// # Errors
/// Fails if the file cannot be read as UTF-8 text. Individual malformed
/// transactions inside the log are dropped silently, not reported.
pub fn parse_modsec_audit_log(path: &Path) -> Result<Vec<(AuditFields, String)>> {
    let content = std::fs::read_to_string(path)
        .with_context(|| format!("reading ModSec audit log: {}", path.display()))?;
    parse_modsec_content(&content)
}
/// Parse ModSecurity audit log content from a string.
///
/// Line-oriented state machine: a boundary marker opens a new
/// (transaction id, section letter) context and all following non-boundary
/// lines are buffered into it until the next marker. Buffered lines are
/// flushed into `sections` whenever a new boundary appears, and once more at
/// end of input. Transactions are emitted in order of first appearance;
/// ones that `transaction_to_audit_fields` cannot convert (e.g. missing
/// section B) are dropped silently.
fn parse_modsec_content(content: &str) -> Result<Vec<(AuditFields, String)>> {
    let mut results = Vec::new();
    // Collect sections per boundary ID.
    // Boundary markers look like: --xxxxxxxx-A--
    // where xxxxxxxx is a hex/alnum transaction ID and A is the section letter.
    let mut sections: HashMap<String, HashMap<char, Vec<String>>> = HashMap::new();
    let mut current_id: Option<String> = None;
    let mut current_section: Option<char> = None;
    let mut current_lines: Vec<String> = Vec::new();
    // Track order of first appearance of each boundary ID.
    let mut id_order: Vec<String> = Vec::new();
    for line in content.lines() {
        if let Some((id, section)) = parse_boundary(line) {
            // Flush previous section.
            if let (Some(ref cid), Some(sec)) = (&current_id, current_section) {
                let entry = sections.entry(cid.clone()).or_default();
                entry.entry(sec).or_default().extend(current_lines.drain(..));
            }
            // `sections` only gains a key when flushed above, so this check
            // records each transaction id exactly once, at its first boundary.
            if !sections.contains_key(&id) {
                id_order.push(id.clone());
            }
            current_id = Some(id);
            current_section = Some(section);
            current_lines.clear();
        } else if current_id.is_some() {
            // Content line inside the current section. Lines before the
            // first boundary marker are discarded.
            current_lines.push(line.to_string());
        }
    }
    // Flush last section.
    if let (Some(ref cid), Some(sec)) = (&current_id, current_section) {
        let entry = sections.entry(cid.clone()).or_default();
        entry.entry(sec).or_default().extend(current_lines.drain(..));
    }
    // Convert each transaction into AuditFields.
    for id in &id_order {
        if let Some(secs) = sections.get(id) {
            if let Some(fields) = transaction_to_audit_fields(secs) {
                results.push(fields);
            }
        }
    }
    Ok(results)
}
/// Try to parse a boundary marker line (`--<id>-<SECTION>--`).
/// Returns `(boundary_id, section_letter)` on success.
///
/// The id must be non-empty and the section must be exactly one ASCII
/// letter. Uses `strip_prefix`/`strip_suffix` instead of direct slicing:
/// the previous version sliced `&trimmed[2..trimmed.len() - 2]`, which
/// PANICS for degenerate inputs like `"--"` or `"---"` (they satisfy both
/// the `--` prefix and suffix checks while being shorter than 4 bytes,
/// producing an inverted slice range).
fn parse_boundary(line: &str) -> Option<(String, char)> {
    let trimmed = line.trim();
    // Both strippers return None when the marker characters overlap or are
    // missing, so short inputs fall out safely here.
    let inner = trimmed.strip_prefix("--")?.strip_suffix("--")?;
    // Split at the LAST dash: ids may themselves contain dashes.
    let (id, section_str) = inner.rsplit_once('-')?;
    if id.is_empty() {
        return None;
    }
    // The section must be a single ASCII letter.
    let mut chars = section_str.chars();
    let section = chars.next()?;
    if chars.next().is_some() || !section.is_ascii_alphabetic() {
        return None;
    }
    Some((id.to_string(), section))
}
/// Convert parsed sections into `(AuditFields, label)`.
///
/// Returns `None` when section B is missing/empty or its request line has
/// fewer than two space-separated tokens — such transactions are dropped by
/// the caller. Fields absent from the log get placeholders: client IP
/// "0.0.0.0", status 0, duration 0, backend "-".
fn transaction_to_audit_fields(
    sections: &HashMap<char, Vec<String>>,
) -> Option<(AuditFields, String)> {
    // Section A: timestamp + connection info
    let client_ip = sections
        .get(&'A')
        .and_then(|lines| {
            // Section A first line typically:
            // [dd/Mon/yyyy:HH:MM:SS +offset] transaction_id source_ip source_port dest_ip dest_port
            lines.first().and_then(|line| {
                // Find the content after the timestamp bracket.
                let after_bracket = line.find(']').map(|i| &line[i + 1..])?;
                let parts: Vec<&str> = after_bracket.split_whitespace().collect();
                // parts: [transaction_id, source_ip, source_port, dest_ip, dest_port]
                if parts.len() >= 3 {
                    Some(parts[1].to_string())
                } else {
                    None
                }
            })
        })
        .unwrap_or_else(|| "0.0.0.0".to_string());
    // Section B: request line + headers
    let section_b = sections.get(&'B')?;
    if section_b.is_empty() {
        return None;
    }
    // First non-empty line is the request line.
    let request_line = section_b.iter().find(|l| !l.trim().is_empty())?;
    // splitn(3, ' ') keeps "HTTP/1.1" (and anything after) in one piece.
    let req_parts: Vec<&str> = request_line.splitn(3, ' ').collect();
    if req_parts.len() < 2 {
        return None;
    }
    let method = req_parts[0].to_string();
    let raw_url = req_parts[1];
    // Split the URL into path and query at the first '?'.
    let (path, query) = if let Some(q) = raw_url.find('?') {
        (raw_url[..q].to_string(), raw_url[q + 1..].to_string())
    } else {
        (raw_url.to_string(), String::new())
    };
    // Parse headers from remaining lines.
    // NOTE(review): skip(1) assumes the request line is section B's FIRST
    // line; if blank lines ever precede it, the header scan would start at
    // the wrong offset — confirm against real logs.
    let mut headers: Vec<(String, String)> = Vec::new();
    for line in section_b.iter().skip(1) {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            break;
        }
        if let Some(colon) = trimmed.find(':') {
            let key = trimmed[..colon].trim().to_ascii_lowercase();
            let value = trimmed[colon + 1..].trim().to_string();
            headers.push((key, value));
        }
    }
    // Lookup by already-lowercased key; returns the first match.
    let get_header = |name: &str| -> Option<String> {
        headers
            .iter()
            .find(|(k, _)| k == name)
            .map(|(_, v)| v.clone())
    };
    let host = get_header("host").unwrap_or_else(|| "unknown".to_string());
    let user_agent = get_header("user-agent").unwrap_or_else(|| "-".to_string());
    let has_cookies = get_header("cookie").is_some();
    // "-" and empty values are normalized to None for these two headers.
    let referer = get_header("referer").filter(|r| r != "-" && !r.is_empty());
    let accept_language = get_header("accept-language").filter(|a| a != "-" && !a.is_empty());
    let content_length: u64 = get_header("content-length")
        .and_then(|v| v.parse().ok())
        .unwrap_or(0);
    // Section F: response status
    let status = sections
        .get(&'F')
        .and_then(|lines| {
            // First non-empty line: "HTTP/1.1 403 Forbidden"
            lines.iter().find(|l| !l.trim().is_empty()).and_then(|line| {
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 2 {
                    parts[1].parse::<u16>().ok()
                } else {
                    None
                }
            })
        })
        .unwrap_or(0);
    // Section H: rule matches → determines label. Any line mentioning
    // "matched", "warning", or "id:" (case-insensitive) counts as a match.
    let has_rule_match = sections
        .get(&'H')
        .map(|lines| {
            lines.iter().any(|l| {
                let lower = l.to_ascii_lowercase();
                lower.contains("matched") || lower.contains("warning") || lower.contains("id:")
            })
        })
        .unwrap_or(false);
    let label = if has_rule_match { "attack" } else { "normal" }.to_string();
    // duration_ms is not present in the audit log; backend is a placeholder.
    // The label is stored both in the fields and in the returned tuple.
    let fields = AuditFields {
        method,
        host,
        path,
        query,
        client_ip,
        status,
        duration_ms: 0,
        content_length,
        user_agent,
        has_cookies: Some(has_cookies),
        referer,
        accept_language,
        backend: "-".to_string(),
        label: Some(label.clone()),
    };
    Some((fields, label))
}
/// Cache directory for ModSec data (mirrors the CSIC caching pattern).
///
/// `$XDG_CACHE_HOME`, then `$HOME/.cache`, then `/tmp/.cache` — suffixed
/// with `sunbeam/modsec`.
#[allow(dead_code)]
fn cache_dir() -> std::path::PathBuf {
    let base = match std::env::var("XDG_CACHE_HOME") {
        Ok(dir) => std::path::PathBuf::from(dir),
        Err(_) => {
            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
            std::path::PathBuf::from(home).join(".cache")
        }
    };
    base.join("sunbeam").join("modsec")
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two-transaction fixture: one blocked attack (rule match in section H),
    // one normal browser request (empty section H). Content lines must stay
    // at column 0 inside the raw string — the request-line parser splits on
    // spaces without trimming, so indentation would break it.
    const SAMPLE_AUDIT_LOG: &str = r#"--a1b2c3d4-A--
[01/Jan/2026:12:00:00 +0000] XYZ123 192.168.1.100 54321 10.0.0.1 80
--a1b2c3d4-B--
GET /admin/config.php?debug=1 HTTP/1.1
Host: example.com
User-Agent: curl/7.68.0
Accept: */*
--a1b2c3d4-C--
--a1b2c3d4-F--
HTTP/1.1 403 Forbidden
Content-Type: text/html
--a1b2c3d4-H--
Message: Warning. Matched "Operator `Rx' with parameter" [id "941100"]
Action: Intercepted (phase 2)
--a1b2c3d4-Z--
--e5f6a7b8-A--
[01/Jan/2026:12:00:01 +0000] ABC456 10.0.0.50 12345 10.0.0.1 80
--e5f6a7b8-B--
GET /index.html HTTP/1.1
Host: example.com
User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120
Accept: text/html
Accept-Language: en-US,en;q=0.9
Cookie: session=abc123
Referer: https://example.com/
--e5f6a7b8-F--
HTTP/1.1 200 OK
Content-Type: text/html
--e5f6a7b8-H--
--e5f6a7b8-Z--
"#;

    /// Boundary marker recognition: valid markers parse, malformed ones don't.
    #[test]
    fn test_parse_boundary() {
        assert_eq!(
            parse_boundary("--a1b2c3d4-A--"),
            Some(("a1b2c3d4".to_string(), 'A'))
        );
        assert_eq!(
            parse_boundary("--a1b2c3d4-H--"),
            Some(("a1b2c3d4".to_string(), 'H'))
        );
        assert_eq!(parse_boundary("not a boundary"), None);
        assert_eq!(parse_boundary("--invalid--"), None);
    }

    /// Full fixture parse: both transactions extracted with correct fields
    /// and labels.
    #[test]
    fn test_parse_modsec_audit_log_snippet() {
        let results = parse_modsec_content(SAMPLE_AUDIT_LOG).unwrap();
        assert_eq!(results.len(), 2, "should parse two transactions");
        // First entry: attack (has rule match in section H).
        let (attack_fields, attack_label) = &results[0];
        assert_eq!(attack_label, "attack");
        assert_eq!(attack_fields.method, "GET");
        assert_eq!(attack_fields.path, "/admin/config.php");
        assert_eq!(attack_fields.query, "debug=1");
        assert_eq!(attack_fields.client_ip, "192.168.1.100");
        assert_eq!(attack_fields.user_agent, "curl/7.68.0");
        assert_eq!(attack_fields.status, 403);
        assert!(!attack_fields.has_cookies.unwrap_or(true));
        // Second entry: normal (no rule match).
        let (normal_fields, normal_label) = &results[1];
        assert_eq!(normal_label, "normal");
        assert_eq!(normal_fields.method, "GET");
        assert_eq!(normal_fields.path, "/index.html");
        assert_eq!(normal_fields.client_ip, "10.0.0.50");
        assert_eq!(normal_fields.status, 200);
        assert!(normal_fields.has_cookies.unwrap_or(false));
        assert!(normal_fields.referer.is_some());
        assert!(normal_fields.accept_language.is_some());
    }

    /// Empty input yields an empty result, not an error.
    #[test]
    fn test_empty_input() {
        let results = parse_modsec_content("").unwrap();
        assert!(results.is_empty());
    }

    /// A transaction with no section H at all must default to "normal".
    #[test]
    fn test_single_entry_no_section_h() {
        let content = r#"--abc123-A--
[01/Jan/2026:12:00:00 +0000] TX1 1.2.3.4 1234 5.6.7.8 80
--abc123-B--
GET / HTTP/1.1
Host: test.com
--abc123-F--
HTTP/1.1 200 OK
--abc123-Z--
"#;
        let results = parse_modsec_content(content).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, "normal");
    }
}

598
src/dataset/prepare.rs Normal file
View File

@@ -0,0 +1,598 @@
//! Dataset preparation orchestrator.
//!
//! Combines production logs, external datasets (CSIC, OWASP ModSec), and
//! synthetic data (CIC-IDS2017 timing profiles + wordlists) into a single
//! `DatasetManifest` serialized as bincode.
use crate::dataset::sample::{DataSource, DatasetManifest, DatasetStats, TrainingSample};
use crate::ddos::audit_log::{AuditFields, AuditLog};
use crate::ddos::features::{method_to_u8, LogIpState};
use crate::ddos::train::HeuristicThresholds;
use crate::scanner::features::{self, fx_hash_bytes};
use anyhow::{Context, Result};
use rustc_hash::FxHashSet;
use std::collections::HashMap;
use std::io::BufRead;
use std::path::Path;
/// Arguments for the `prepare-dataset` command.
///
/// Construct via `Default` and override fields; `input` has no usable
/// default and must be set before calling `run`.
pub struct PrepareDatasetArgs {
    /// Path to production audit log JSONL file.
    pub input: String,
    /// Path to OWASP ModSecurity audit log file (optional extra data).
    pub owasp: Option<String>,
    /// Path to wordlist directory (optional, enhances synthetic scanner).
    pub wordlists: Option<String>,
    /// Output path for the bincode dataset file.
    pub output: String,
    /// RNG seed for synthetic generation.
    pub seed: u64,
    /// Path to heuristics.toml for auto-labeling production logs.
    /// When `None`, built-in threshold defaults are used.
    pub heuristics: Option<String>,
}
impl Default for PrepareDatasetArgs {
    fn default() -> Self {
        Self {
            // No sensible default input path; callers must set this.
            input: String::new(),
            owasp: None,
            wordlists: None,
            output: "dataset.bin".to_string(),
            // Fixed seed keeps synthetic generation reproducible by default.
            seed: 42,
            heuristics: None,
        }
    }
}
/// Execute the full prepare-dataset pipeline and write the manifest to disk.
///
/// Stages, in order:
/// 0. Ensure upstream datasets are downloaded/cached.
/// 1. Parse production JSONL logs (heuristic auto-labeling when unlabeled).
/// 2. CSIC 2010 → scanner samples.
/// 3. Optional OWASP ModSec audit log → scanner samples.
/// 4. CIC-IDS2017 timing profiles from cache.
/// 5. Synthetic samples for both models.
/// 6. Stats computation + bincode serialization via `sample::save_dataset`.
///
/// # Errors
/// Propagates any failure from downloads, file I/O, parsing, or serialization.
pub fn run(args: PrepareDatasetArgs) -> Result<()> {
    let mut scanner_samples: Vec<TrainingSample> = Vec::new();
    let mut ddos_samples: Vec<TrainingSample> = Vec::new();
    // --- 0. Download upstream datasets if not cached ---
    crate::dataset::download::download_all()?;
    // --- 1. Parse production logs ---
    // Thresholds come from the optional TOML file; otherwise built-in defaults.
    let heuristics = if let Some(h_path) = &args.heuristics {
        let content = std::fs::read_to_string(h_path)
            .with_context(|| format!("reading heuristics from {h_path}"))?;
        toml::from_str::<HeuristicThresholds>(&content)
            .with_context(|| format!("parsing heuristics from {h_path}"))?
    } else {
        HeuristicThresholds::new(10.0, 0.85, 0.7, 0.3, 0.05, 20.0, 10)
    };
    eprintln!("parsing production logs from {}...", args.input);
    let (prod_scanner, prod_ddos) = parse_production_logs(&args.input, &heuristics)?;
    eprintln!(
        " production: {} scanner, {} ddos samples",
        prod_scanner.len(),
        prod_ddos.len()
    );
    scanner_samples.extend(prod_scanner);
    ddos_samples.extend(prod_ddos);
    // --- 2. CSIC 2010 (scanner) ---
    eprintln!("fetching CSIC 2010 dataset...");
    let csic_entries = crate::scanner::csic::fetch_csic_dataset()?;
    let csic_samples = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?;
    eprintln!(" CSIC: {} scanner samples", csic_samples.len());
    scanner_samples.extend(csic_samples);
    // --- 3. OWASP ModSec (scanner) ---
    if let Some(owasp_path) = &args.owasp {
        eprintln!("parsing OWASP ModSec audit log from {owasp_path}...");
        let modsec_entries =
            crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?;
        // Replace the tuple's label with the host prefix expected by
        // entries_to_scanner_samples. Nothing is lost: the parser also stores
        // the same label inside `fields.label` (see modsec.rs).
        let entries_with_host: Vec<(AuditFields, String)> = modsec_entries
            .into_iter()
            .map(|(fields, _label)| {
                let host_prefix = fields.host.split('.').next().unwrap_or("").to_string();
                (fields, host_prefix)
            })
            .collect();
        let modsec_samples =
            entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?;
        eprintln!(" OWASP: {} scanner samples", modsec_samples.len());
        scanner_samples.extend(modsec_samples);
    }
    // --- 4. CIC-IDS2017 timing profiles (from cache if downloaded) ---
    // NOTE(review): stage 0 already downloads CIC-IDS2017 (and `?`-aborts on
    // failure), so the else-branch should only trigger if the cached file
    // disappeared between stages — confirm whether it's reachable.
    let cicids_profiles = if let Some(cached_path) = crate::dataset::download::cicids_cached_path()
    {
        eprintln!("extracting CIC-IDS2017 timing profiles from cache...");
        let profiles = crate::dataset::cicids::extract_timing_profiles(&cached_path)?;
        eprintln!(" extracted {} attack-type profiles", profiles.len());
        profiles
    } else {
        eprintln!(" CIC-IDS2017 not cached; using built-in DDoS distributions");
        eprintln!(" (run `download-datasets` first for real timing profiles)");
        Vec::new()
    };
    // --- 5. Synthetic data (both models, always generated) ---
    eprintln!("generating synthetic samples...");
    let config = crate::dataset::synthetic::SyntheticConfig {
        num_ddos_attack: 10000,
        num_ddos_normal: 10000,
        num_scanner_attack: 5000,
        num_scanner_normal: 5000,
        seed: args.seed,
    };
    // Synthetic DDoS (uses CIC-IDS2017 profiles if cached, fallback defaults otherwise).
    let synthetic_ddos =
        crate::dataset::synthetic::generate_ddos_samples(&cicids_profiles, &config);
    eprintln!(" synthetic DDoS: {} samples", synthetic_ddos.len());
    ddos_samples.extend(synthetic_ddos);
    // Synthetic scanner (uses wordlists if provided, built-in patterns otherwise).
    let synthetic_scanner = crate::dataset::synthetic::generate_scanner_samples(
        args.wordlists.as_deref().map(Path::new),
        None,
        &config,
    )?;
    eprintln!(" synthetic scanner: {} samples", synthetic_scanner.len());
    scanner_samples.extend(synthetic_scanner);
    // --- 6. Compute stats ---
    let mut samples_by_source: HashMap<DataSource, usize> = HashMap::new();
    for s in scanner_samples.iter().chain(ddos_samples.iter()) {
        *samples_by_source.entry(s.source.clone()).or_insert(0) += 1;
    }
    // Labels are f32; > 0.5 treats them as binary attack/normal.
    let scanner_attack_count = scanner_samples.iter().filter(|s| s.label > 0.5).count();
    let ddos_attack_count = ddos_samples.iter().filter(|s| s.label > 0.5).count();
    let stats = DatasetStats {
        total_samples: scanner_samples.len() + ddos_samples.len(),
        scanner_samples: scanner_samples.len(),
        ddos_samples: ddos_samples.len(),
        samples_by_source,
        attack_ratio_scanner: if scanner_samples.is_empty() {
            0.0
        } else {
            scanner_attack_count as f64 / scanner_samples.len() as f64
        },
        attack_ratio_ddos: if ddos_samples.is_empty() {
            0.0
        } else {
            ddos_attack_count as f64 / ddos_samples.len() as f64
        },
    };
    eprintln!("\n--- dataset stats ---");
    eprintln!("total samples: {}", stats.total_samples);
    eprintln!(
        "scanner: {} ({} attack, {:.1}% attack ratio)",
        stats.scanner_samples,
        scanner_attack_count,
        stats.attack_ratio_scanner * 100.0
    );
    eprintln!(
        "ddos: {} ({} attack, {:.1}% attack ratio)",
        stats.ddos_samples,
        ddos_attack_count,
        stats.attack_ratio_ddos * 100.0
    );
    // HashMap iteration order is unspecified, so per-source lines print in
    // arbitrary order from run to run.
    for (source, count) in &stats.samples_by_source {
        eprintln!(" {source}: {count}");
    }
    let manifest = DatasetManifest {
        scanner_samples,
        ddos_samples,
        stats,
    };
    crate::dataset::sample::save_dataset(&manifest, Path::new(&args.output))?;
    eprintln!("\ndataset saved to {}", args.output);
    Ok(())
}
/// Parse production JSONL logs and produce both scanner and DDoS training samples.
///
/// When logs lack explicit labels, heuristic auto-labeling is applied:
/// - Scanner: attack if status >= 400 AND path matches scanner patterns; normal if
///   status < 400 AND has browser indicators (cookies + referer + accept-language).
/// - DDoS: per-IP feature thresholds from `HeuristicThresholds` (same logic as `ddos/train.rs`).
///
/// Two passes over the file's entries: the first collects all entries and the
/// full set of host-prefix hashes; the second extracts features, so every
/// entry sees the complete host set.
fn parse_production_logs(
    input: &str,
    heuristics: &HeuristicThresholds,
) -> Result<(Vec<TrainingSample>, Vec<TrainingSample>)> {
    let file = std::fs::File::open(input)
        .with_context(|| format!("opening {input}"))?;
    let reader = std::io::BufReader::new(file);
    let mut scanner_samples = Vec::new();
    let mut parsed_entries: Vec<(AuditFields, String)> = Vec::new();
    // Hashes of every host prefix seen in the log; passed to extract_features
    // (presumably its known-hosts set — confirm against scanner::features).
    let mut log_hosts: FxHashSet<u64> = FxHashSet::default();
    // Build hashes for scanner feature extraction.
    let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
        .iter()
        .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
        .collect();
    let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
        .iter()
        .map(|e| fx_hash_bytes(e.as_bytes()))
        .collect();
    for line in reader.lines() {
        let line = line?;
        if line.trim().is_empty() {
            continue;
        }
        // Malformed JSON lines are skipped silently (best-effort ingestion).
        let entry: AuditLog = match serde_json::from_str(&line) {
            Ok(e) => e,
            Err(_) => continue,
        };
        // First DNS label of the host, e.g. "git" from "git.example.com".
        let host_prefix = entry
            .fields
            .host
            .split('.')
            .next()
            .unwrap_or("")
            .to_string();
        log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes()));
        parsed_entries.push((entry.fields, host_prefix));
    }
    // --- Scanner samples from production logs ---
    for (fields, host_prefix) in &parsed_entries {
        let has_cookies = fields.has_cookies.unwrap_or(false);
        // "-" and empty are treated as absent for referer/accept-language.
        let has_referer = fields
            .referer
            .as_ref()
            .map(|r| r != "-" && !r.is_empty())
            .unwrap_or(false);
        let has_accept_language = fields
            .accept_language
            .as_ref()
            .map(|a| a != "-" && !a.is_empty())
            .unwrap_or(false);
        let feats = features::extract_features(
            &fields.method,
            &fields.path,
            host_prefix,
            has_cookies,
            has_referer,
            has_accept_language,
            "-",
            &fields.user_agent,
            fields.content_length,
            &fragment_hashes,
            &extension_hashes,
            &log_hosts,
        );
        // Use ground-truth label if present, otherwise heuristic auto-label.
        let label = match fields.label.as_deref() {
            Some("attack" | "anomalous") => Some(1.0f32),
            Some("normal") => Some(0.0f32),
            _ => {
                // Heuristic scanner labeling:
                // Attack: 404+ AND suspicious path (excluding .git which is valid on Gitea hosts).
                // Normal: success status AND browser indicators (cookies + referer + accept-language).
                // Note: 401 is excluded since it's expected for private repos.
                let status = fields.status;
                let path_lower = fields.path.to_ascii_lowercase();
                let is_suspicious_path = path_lower.contains(".env")
                    || path_lower.contains("wp-login")
                    || path_lower.contains("wp-admin")
                    || path_lower.contains("phpmyadmin")
                    || path_lower.contains("cgi-bin")
                    || path_lower.contains("phpinfo")
                    || path_lower.contains("/shell")
                    || path_lower.contains("..%2f")
                    || path_lower.contains("../");
                if status >= 404 && is_suspicious_path {
                    Some(1.0f32)
                } else if status < 400 && has_cookies && has_referer && has_accept_language {
                    Some(0.0f32)
                } else {
                    None // ambiguous — skip
                }
            }
        };
        if let Some(l) = label {
            scanner_samples.push(TrainingSample {
                features: feats.iter().map(|&v| v as f32).collect(),
                label: l,
                source: DataSource::ProductionLogs,
                weight: 1.0,
            });
        }
    }
    // --- DDoS samples from production logs ---
    let ddos_samples = extract_ddos_samples_from_entries(&parsed_entries, heuristics)?;
    Ok((scanner_samples, ddos_samples))
}
/// Extract DDoS feature vectors from parsed log entries using sliding windows.
///
/// Groups entries per client IP (port stripped), accumulates per-request
/// signals into a `LogIpState`, then cuts overlapping windows of up to 50
/// events (25-event step) and extracts one feature vector per window.
/// Windows are labeled from ground truth when the log carried a label,
/// otherwise via the heuristic thresholds (mirrors ddos/train.rs).
fn extract_ddos_samples_from_entries(
    entries: &[(AuditFields, String)],
    heuristics: &HeuristicThresholds,
) -> Result<Vec<TrainingSample>> {
    use rustc_hash::FxHashMap;
    use std::hash::{Hash, Hasher};
    // Cheap, non-cryptographic hash — used only to count distinct values
    // (paths, hosts, user agents), never for security.
    fn fx_hash(s: &str) -> u64 {
        let mut h = rustc_hash::FxHasher::default();
        s.hash(&mut h);
        h.finish()
    }
    let mut ip_states: FxHashMap<String, LogIpState> = FxHashMap::default();
    let mut ip_labels: FxHashMap<String, Option<f32>> = FxHashMap::default();
    for (fields, _host_prefix) in entries {
        let ip = crate::ddos::audit_log::strip_port(&fields.client_ip).to_string();
        let state = ip_states.entry(ip.clone()).or_default();
        // Use a simple counter as "timestamp" since we don't have parsed timestamps here.
        let ts = state.timestamps.len() as f64;
        state.timestamps.push(ts);
        state.methods.push(method_to_u8(&fields.method));
        state.path_hashes.push(fx_hash(&fields.path));
        state.host_hashes.push(fx_hash(&fields.host));
        state
            .user_agent_hashes
            .push(fx_hash(&fields.user_agent));
        state.statuses.push(fields.status);
        // Durations and content lengths are saturated into u32 range.
        state
            .durations
            .push(fields.duration_ms.min(u32::MAX as u64) as u32);
        state
            .content_lengths
            .push(fields.content_length.min(u32::MAX as u64) as u32);
        state
            .has_cookies
            .push(fields.has_cookies.unwrap_or(false));
        // "-" is the access-log placeholder for an absent header.
        state.has_referer.push(
            fields
                .referer
                .as_deref()
                .map(|r| r != "-")
                .unwrap_or(false),
        );
        state.has_accept_language.push(
            fields
                .accept_language
                .as_deref()
                .map(|a| a != "-")
                .unwrap_or(false),
        );
        state.suspicious_paths.push(
            crate::ddos::features::is_suspicious_path(&fields.path),
        );
        // Track label per IP if available.
        // NOTE(review): last-write-wins — a later unrecognized label string
        // stores None and discards earlier ground truth for this IP;
        // confirm that is intended.
        if let Some(ref label_str) = fields.label {
            let label_val = match label_str.as_str() {
                "attack" | "anomalous" => Some(1.0f32),
                "normal" => Some(0.0f32),
                _ => None,
            };
            ip_labels.insert(ip, label_val);
        }
    }
    // Window geometry: 50 events per window with 50% overlap; never accept
    // a window smaller than min_events (floored at 3).
    let min_events = heuristics.min_events.max(3);
    let window_size = 50; // events per sliding window
    let window_step = 25; // step size (50% overlap)
    let window_secs = 60.0;
    let mut samples = Vec::new();
    // FxHashMap iteration order is unspecified, so the resulting sample
    // order may vary between runs.
    for (ip, state) in &ip_states {
        let n = state.timestamps.len();
        if n < min_events {
            continue;
        }
        // Ground-truth label if available.
        let gt_label = match ip_labels.get(ip) {
            Some(Some(l)) => Some(*l),
            _ => None,
        };
        // Sliding windows: extract multiple feature vectors per IP.
        // For IPs with fewer events than window_size, use a single window over all events.
        let mut start = 0;
        loop {
            let end = (start + window_size).min(n);
            if end - start < min_events {
                break;
            }
            let fv = state.extract_features_for_window(start, end, window_secs);
            let label = gt_label.unwrap_or_else(|| {
                // Heuristic DDoS labeling (mirrors ddos/train.rs logic):
                // fv indices: 0 request_rate, 1 unique_paths, 3 error_rate,
                // 7 path_repetition, 10 cookie_ratio, 13 suspicious_path_ratio.
                let is_attack = fv[0] > heuristics.request_rate
                    || fv[7] > heuristics.path_repetition
                    || fv[3] > heuristics.error_rate
                    || fv[13] > heuristics.suspicious_path_ratio
                    || (fv[10] < heuristics.no_cookies_threshold
                        && fv[1] > heuristics.no_cookies_path_count);
                if is_attack { 1.0f32 } else { 0.0f32 }
            });
            samples.push(TrainingSample {
                features: fv.iter().map(|&v| v as f32).collect(),
                label,
                source: DataSource::ProductionLogs,
                weight: 1.0,
            });
            if end >= n {
                break;
            }
            start += window_step;
        }
    }
    Ok(samples)
}
/// Convert external dataset entries (CSIC, OWASP) into scanner TrainingSamples.
///
/// Only entries carrying a ground-truth label of "attack"/"anomalous"
/// (-> 1.0) or "normal" (-> 0.0) are converted; any other label is skipped.
/// `source` records provenance and `weight` is attached to every sample.
fn entries_to_scanner_samples(
    entries: &[(AuditFields, String)],
    source: DataSource,
    weight: f32,
) -> Result<Vec<TrainingSample>> {
    // Precomputed hash sets shared by every feature extraction below.
    let fragment_hashes: FxHashSet<u64> = crate::scanner::train::DEFAULT_FRAGMENTS
        .iter()
        .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes()))
        .collect();
    let extension_hashes: FxHashSet<u64> = features::SUSPICIOUS_EXTENSIONS_LIST
        .iter()
        .map(|e| fx_hash_bytes(e.as_bytes()))
        .collect();
    // Known hosts are derived from the dataset itself, so the
    // "host is configured" feature reflects in-dataset hosts.
    let log_hosts: FxHashSet<u64> = entries
        .iter()
        .map(|(_, host_prefix)| fx_hash_bytes(host_prefix.as_bytes()))
        .collect();
    let mut samples = Vec::with_capacity(entries.len());
    for (fields, host_prefix) in entries {
        // Resolve the label first: entries without a usable ground-truth
        // label are skipped, so don't pay for feature extraction on them.
        let label = match fields.label.as_deref() {
            Some("attack" | "anomalous") => 1.0f32,
            Some("normal") => 0.0f32,
            _ => continue,
        };
        let has_cookies = fields.has_cookies.unwrap_or(false);
        // "-" is the access-log placeholder for an absent header.
        let has_referer = fields
            .referer
            .as_ref()
            .map(|r| r != "-" && !r.is_empty())
            .unwrap_or(false);
        let has_accept_language = fields
            .accept_language
            .as_ref()
            .map(|a| a != "-" && !a.is_empty())
            .unwrap_or(false);
        let feats = features::extract_features(
            &fields.method,
            &fields.path,
            host_prefix,
            has_cookies,
            has_referer,
            has_accept_language,
            "-",
            &fields.user_agent,
            fields.content_length,
            &fragment_hashes,
            &extension_hashes,
            &log_hosts,
        );
        samples.push(TrainingSample {
            features: feats.iter().map(|&v| v as f32).collect(),
            label,
            source: source.clone(),
            weight,
        });
    }
    Ok(samples)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ddos::features::NUM_FEATURES;
    use crate::scanner::features::NUM_SCANNER_FEATURES;
    // A small burst from a single IP must yield at least one DDoS window
    // whose feature vector matches the model's input width.
    #[test]
    fn test_ddos_sample_feature_count() {
        // Verify that DDoS samples from sliding windows produce 14 features.
        let entries = vec![
            make_test_entry("GET", "/", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/about", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/contact", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/blog", "1.2.3.4", 200, "normal"),
            make_test_entry("GET", "/faq", "1.2.3.4", 200, "normal"),
        ];
        // NOTE(review): positional thresholds — presumably (request_rate,
        // path_repetition, error_rate, suspicious_path_ratio,
        // no_cookies_threshold, no_cookies_path_count, min_events); confirm
        // against HeuristicThresholds::new.
        let heuristics = HeuristicThresholds::new(10.0, 0.85, 0.7, 0.3, 0.05, 20.0, 5);
        let samples = extract_ddos_samples_from_entries(&entries, &heuristics).unwrap();
        // With only 5 entries at min_events=5, we should get 1 sample.
        assert!(!samples.is_empty(), "should produce DDoS samples");
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_FEATURES,
                "DDoS sample should have {NUM_FEATURES} features"
            );
        }
    }
    // Scanner conversion must emit one fixed-width vector per labeled entry.
    #[test]
    fn test_scanner_sample_feature_count() {
        let entries = vec![
            make_test_entry("GET", "/.env", "1.2.3.4", 404, "attack"),
            make_test_entry("GET", "/index.html", "5.6.7.8", 200, "normal"),
        ];
        let samples =
            entries_to_scanner_samples(&entries, DataSource::ProductionLogs, 1.0).unwrap();
        assert_eq!(samples.len(), 2);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_SCANNER_FEATURES,
                "scanner sample should have {NUM_SCANNER_FEATURES} features"
            );
        }
    }
    // Ground-truth labels map to 1.0/0.0, unrecognized label strings are
    // dropped, and source/weight propagate onto every produced sample.
    #[test]
    fn test_entries_to_scanner_samples_labels() {
        let entries = vec![
            make_test_entry("GET", "/.env", "1.2.3.4", 404, "attack"),
            make_test_entry("GET", "/", "5.6.7.8", 200, "normal"),
            make_test_entry("GET", "/page", "9.10.11.12", 200, "unknown_label"),
        ];
        let samples =
            entries_to_scanner_samples(&entries, DataSource::Csic2010, 0.8).unwrap();
        // "unknown_label" should be skipped.
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[0].label, 1.0);
        assert_eq!(samples[1].label, 0.0);
        assert_eq!(samples[0].weight, 0.8);
        assert_eq!(samples[0].source, DataSource::Csic2010);
    }
    // Builds a fully populated AuditFields entry with browser-like headers;
    // the second tuple element is the host prefix used for host hashing.
    fn make_test_entry(
        method: &str,
        path: &str,
        client_ip: &str,
        status: u16,
        label: &str,
    ) -> (AuditFields, String) {
        let fields = AuditFields {
            method: method.to_string(),
            host: "test.sunbeam.pt".to_string(),
            path: path.to_string(),
            query: String::new(),
            client_ip: client_ip.to_string(),
            status,
            duration_ms: 10,
            content_length: 0,
            user_agent: "Mozilla/5.0".to_string(),
            has_cookies: Some(true),
            referer: Some("https://test.sunbeam.pt".to_string()),
            accept_language: Some("en-US".to_string()),
            backend: "test-svc:8080".to_string(),
            label: Some(label.to_string()),
        };
        (fields, "test".to_string())
    }
}

156
src/dataset/sample.rs Normal file
View File

@@ -0,0 +1,156 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use anyhow::{Context, Result};
/// Provenance of a training sample.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DataSource {
    /// Extracted from this deployment's own access/audit logs.
    ProductionLogs,
    /// CSIC 2010 HTTP dataset.
    Csic2010,
    /// OWASP ModSecurity dataset.
    OwaspModSec,
    /// Synthetic samples driven by CIC-IDS2017 timing profiles.
    SyntheticCicTiming,
    /// Synthetic scanner samples built from wordlist paths.
    SyntheticWordlist,
}
impl std::fmt::Display for DataSource {
    /// Human-readable name; intentionally identical to the variant identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            DataSource::ProductionLogs => "ProductionLogs",
            DataSource::Csic2010 => "Csic2010",
            DataSource::OwaspModSec => "OwaspModSec",
            DataSource::SyntheticCicTiming => "SyntheticCicTiming",
            DataSource::SyntheticWordlist => "SyntheticWordlist",
        };
        f.write_str(name)
    }
}
/// A single labeled training sample with its feature vector, label, provenance,
/// and weight (used during training to down-weight synthetic/external data).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
    /// Feature vector (12 dims for scanner samples, 14 for DDoS samples).
    pub features: Vec<f32>,
    /// 0.0 = normal, 1.0 = attack
    pub label: f32,
    /// Where this sample came from (production logs, external dataset, synthetic).
    pub source: DataSource,
    /// Sample weight: 1.0 for production, 0.8 for external datasets, 0.5 for synthetic.
    pub weight: f32,
}
/// Aggregate statistics about a prepared dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
    /// Total number of samples across both models.
    pub total_samples: usize,
    /// Number of scanner-model samples.
    pub scanner_samples: usize,
    /// Number of DDoS-model samples.
    pub ddos_samples: usize,
    /// Per-provenance sample counts.
    // NOTE(review): non-string map keys serialize fine with bincode but would
    // fail with serde_json — confirm bincode is the only format in use.
    pub samples_by_source: HashMap<DataSource, usize>,
    /// Fraction of scanner samples labeled as attack.
    pub attack_ratio_scanner: f64,
    /// Fraction of DDoS samples labeled as attack.
    pub attack_ratio_ddos: f64,
}
/// The full serializable dataset: scanner samples, DDoS samples, and stats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetManifest {
    /// Training samples for the scanner model.
    pub scanner_samples: Vec<TrainingSample>,
    /// Training samples for the DDoS model.
    pub ddos_samples: Vec<TrainingSample>,
    /// Aggregate counts and class ratios over both sample sets.
    pub stats: DatasetStats,
}
/// Serialize a `DatasetManifest` to a bincode file.
///
/// # Errors
/// Fails with context if encoding fails or the file cannot be written.
pub fn save_dataset(manifest: &DatasetManifest, path: &Path) -> Result<()> {
    let bytes = bincode::serialize(manifest).context("serializing dataset manifest")?;
    std::fs::write(path, &bytes)
        .with_context(|| format!("writing dataset to {}", path.display()))
}
/// Deserialize a `DatasetManifest` from a bincode file.
///
/// # Errors
/// Fails with context if the file cannot be read or decoding fails.
pub fn load_dataset(path: &Path) -> Result<DatasetManifest> {
    let bytes = std::fs::read(path)
        .with_context(|| format!("reading dataset from {}", path.display()))?;
    bincode::deserialize(&bytes).context("deserializing dataset manifest")
}
#[cfg(test)]
mod tests {
    use super::*;
    // Shorthand constructor for a TrainingSample literal.
    fn make_sample(features: Vec<f32>, label: f32, source: DataSource, weight: f32) -> TrainingSample {
        TrainingSample { features, label, source, weight }
    }
    // In-memory encode/decode must preserve samples, stats, and per-source
    // counts exactly.
    #[test]
    fn test_bincode_roundtrip() {
        let manifest = DatasetManifest {
            scanner_samples: vec![
                make_sample(vec![0.1, 0.2, 0.3], 0.0, DataSource::ProductionLogs, 1.0),
                make_sample(vec![0.9, 0.8, 0.7], 1.0, DataSource::Csic2010, 0.8),
            ],
            ddos_samples: vec![
                make_sample(vec![0.5; 14], 1.0, DataSource::SyntheticCicTiming, 0.5),
                make_sample(vec![0.1; 14], 0.0, DataSource::ProductionLogs, 1.0),
            ],
            stats: DatasetStats {
                total_samples: 4,
                scanner_samples: 2,
                ddos_samples: 2,
                samples_by_source: {
                    let mut m = HashMap::new();
                    m.insert(DataSource::ProductionLogs, 2);
                    m.insert(DataSource::Csic2010, 1);
                    m.insert(DataSource::SyntheticCicTiming, 1);
                    m
                },
                attack_ratio_scanner: 0.5,
                attack_ratio_ddos: 0.5,
            },
        };
        let encoded = bincode::serialize(&manifest).unwrap();
        let decoded: DatasetManifest = bincode::deserialize(&encoded).unwrap();
        assert_eq!(decoded.scanner_samples.len(), 2);
        assert_eq!(decoded.ddos_samples.len(), 2);
        assert_eq!(decoded.stats.total_samples, 4);
        assert_eq!(decoded.stats.samples_by_source[&DataSource::ProductionLogs], 2);
        // Verify feature values survive the roundtrip.
        assert!((decoded.scanner_samples[0].features[0] - 0.1).abs() < 1e-6);
        assert_eq!(decoded.scanner_samples[1].label, 1.0);
        assert_eq!(decoded.ddos_samples[0].source, DataSource::SyntheticCicTiming);
    }
    // Same as above, but through the filesystem via save_dataset/load_dataset.
    #[test]
    fn test_save_load_roundtrip() {
        let manifest = DatasetManifest {
            scanner_samples: vec![
                make_sample(vec![1.0, 2.0], 0.0, DataSource::ProductionLogs, 1.0),
            ],
            ddos_samples: vec![
                make_sample(vec![3.0, 4.0], 1.0, DataSource::OwaspModSec, 0.8),
            ],
            stats: DatasetStats {
                total_samples: 2,
                scanner_samples: 1,
                ddos_samples: 1,
                samples_by_source: HashMap::new(),
                attack_ratio_scanner: 0.0,
                attack_ratio_ddos: 1.0,
            },
        };
        // tempdir is cleaned up on drop, so nothing leaks on failure.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_dataset.bin");
        save_dataset(&manifest, &path).unwrap();
        let loaded = load_dataset(&path).unwrap();
        assert_eq!(loaded.scanner_samples.len(), 1);
        assert_eq!(loaded.ddos_samples.len(), 1);
        assert!((loaded.scanner_samples[0].features[0] - 1.0).abs() < 1e-6);
        assert_eq!(loaded.ddos_samples[0].source, DataSource::OwaspModSec);
    }
}

558
src/dataset/synthetic.rs Normal file
View File

@@ -0,0 +1,558 @@
//! Synthetic data generation for DDoS and scanner training samples.
//!
//! DDoS synthetic samples are generated by sampling from CIC-IDS2017 timing
//! profiles. Scanner synthetic samples are generated from wordlist paths
//! and realistic browser-like feature vectors.
use crate::dataset::cicids::TimingProfile;
use crate::dataset::sample::{DataSource, TrainingSample};
use crate::ddos::features::NUM_FEATURES;
use crate::scanner::features::NUM_SCANNER_FEATURES;
use anyhow::Result;
use rand::prelude::*;
use rand::rngs::StdRng;
use std::io::BufRead;
use std::path::Path;
/// Configuration for synthetic sample generation.
#[derive(Debug, Clone)]
pub struct SyntheticConfig {
    /// Number of synthetic DDoS attack samples to generate.
    pub num_ddos_attack: usize,
    /// Number of synthetic DDoS normal (benign) samples to generate.
    pub num_ddos_normal: usize,
    /// Number of synthetic scanner attack samples to generate.
    pub num_scanner_attack: usize,
    /// Number of synthetic scanner normal (benign) samples to generate.
    pub num_scanner_normal: usize,
    /// RNG seed; generation is deterministic for a fixed seed.
    pub seed: u64,
}
impl Default for SyntheticConfig {
    /// Defaults: 1000 samples per class for each model, fixed seed 42 so
    /// generation is reproducible out of the box.
    fn default() -> Self {
        Self {
            num_ddos_attack: 1000,
            num_ddos_normal: 1000,
            num_scanner_attack: 1000,
            num_scanner_normal: 1000,
            seed: 42,
        }
    }
}
/// Generate synthetic DDoS training samples from CIC-IDS2017 timing profiles.
///
/// Produces 14-dimensional feature vectors matching `NUM_FEATURES`:
/// 0: request_rate, 1: unique_paths, 2: unique_hosts, 3: error_rate,
/// 4: avg_duration_ms, 5: method_entropy, 6: burst_score, 7: path_repetition,
/// 8: avg_content_length, 9: unique_user_agents, 10: cookie_ratio,
/// 11: referer_ratio, 12: accept_language_ratio, 13: suspicious_path_ratio
pub fn generate_ddos_samples(
    profiles: &[TimingProfile],
    config: &SyntheticConfig,
) -> Vec<TrainingSample> {
    let mut rng = StdRng::seed_from_u64(config.seed);
    let mut samples =
        Vec::with_capacity(config.num_ddos_attack + config.num_ddos_normal);
    // Attack windows first, sampled from the CIC timing profiles...
    samples.extend((0..config.num_ddos_attack).map(|_| TrainingSample {
        features: generate_ddos_attack_features(profiles, &mut rng),
        label: 1.0,
        source: DataSource::SyntheticCicTiming,
        weight: 0.5,
    }));
    // ...then benign, browser-like windows.
    samples.extend((0..config.num_ddos_normal).map(|_| TrainingSample {
        features: generate_ddos_normal_features(&mut rng),
        label: 0.0,
        source: DataSource::SyntheticCicTiming,
        weight: 0.5,
    }));
    samples
}
/// Generate a single DDoS attack feature vector, sampling from timing profiles.
///
/// When `profiles` is non-empty, one profile is picked uniformly at random
/// and its inter-arrival / burst-duration / bytes-per-second distributions
/// are sampled (truncated at zero); otherwise aggressive fallback defaults
/// are drawn directly. Feature indices follow the layout documented on
/// [`generate_ddos_samples`].
fn generate_ddos_attack_features(profiles: &[TimingProfile], rng: &mut StdRng) -> Vec<f32> {
    let mut features = vec![0.0f32; NUM_FEATURES];
    // Sample from a random profile if available, otherwise use defaults.
    let (iat_mean, burst_mean, bps_mean) = if !profiles.is_empty() {
        let profile = &profiles[rng.random_range(0..profiles.len())];
        let iat = sample_positive_normal(
            rng,
            profile.inter_arrival_mean,
            profile.inter_arrival_std,
        );
        let burst = sample_positive_normal(
            rng,
            profile.burst_duration_mean,
            profile.burst_duration_std,
        );
        let bps = sample_positive_normal(
            rng,
            profile.flow_bytes_per_sec_mean,
            profile.flow_bytes_per_sec_std,
        );
        (iat, burst, bps)
    } else {
        // Fallback: aggressive attack defaults.
        let iat = rng.random_range(0.001..0.05);
        let burst = rng.random_range(0.5..5.0);
        let bps = rng.random_range(10000.0..100000.0);
        (iat, burst, bps)
    };
    // 0: request_rate — high for attacks (inverse of mean inter-arrival,
    // capped; randomized when the sampled inter-arrival degenerates to 0).
    features[0] = if iat_mean > 0.0 {
        (1.0 / iat_mean).min(1000.0) as f32
    } else {
        rng.random_range(50.0..500.0) as f32
    };
    // 1: unique_paths — low to moderate (attackers repeat paths).
    features[1] = rng.random_range(1.0..10.0) as f32;
    // 2: unique_hosts — typically 1 for DDoS.
    features[2] = rng.random_range(1.0..3.0) as f32;
    // 3: error_rate — moderate to high.
    features[3] = rng.random_range(0.3..0.9) as f32;
    // 4: avg_duration_ms — derived from burst duration, clamped to sane ms.
    features[4] = (burst_mean * 100.0).clamp(1.0, 5000.0) as f32;
    // 5: method_entropy — low (mostly GET).
    features[5] = rng.random_range(0.0..0.3) as f32;
    // 6: burst_score — high (inverse of inter-arrival), tighter cap than rate.
    features[6] = if iat_mean > 0.0 {
        (1.0 / iat_mean).min(500.0) as f32
    } else {
        rng.random_range(10.0..200.0) as f32
    };
    // 7: path_repetition — high (attackers repeat same paths).
    features[7] = rng.random_range(0.6..1.0) as f32;
    // 8: avg_content_length — derived from flow bytes, clamped non-negative.
    features[8] = (bps_mean * 0.01).clamp(0.0, 10000.0) as f32;
    // 9: unique_user_agents — low (1-2 UAs).
    features[9] = rng.random_range(1.0..3.0) as f32;
    // 10: cookie_ratio — very low (bots don't send cookies).
    features[10] = rng.random_range(0.0..0.1) as f32;
    // 11: referer_ratio — very low.
    features[11] = rng.random_range(0.0..0.1) as f32;
    // 12: accept_language_ratio — very low.
    features[12] = rng.random_range(0.0..0.1) as f32;
    // 13: suspicious_path_ratio — moderate to high.
    features[13] = rng.random_range(0.1..0.7) as f32;
    features
}
/// Generate a single DDoS normal feature vector.
///
/// Models ordinary browsing traffic: modest request rates, low error rates,
/// and browser headers (cookies / referer / accept-language) mostly present.
fn generate_ddos_normal_features(rng: &mut StdRng) -> Vec<f32> {
    // Values are drawn in feature-index order; comments name each slot.
    let drawn = [
        rng.random_range(0.1..5.0),    // 0: request_rate — moderate
        rng.random_range(3.0..30.0),   // 1: unique_paths — moderate
        rng.random_range(1.0..5.0),    // 2: unique_hosts — typically 1-3
        rng.random_range(0.0..0.15),   // 3: error_rate — low
        rng.random_range(10.0..500.0), // 4: avg_duration_ms — reasonable
        rng.random_range(0.0..1.5),    // 5: method_entropy — GET/POST mix
        rng.random_range(0.05..2.0),   // 6: burst_score — low
        rng.random_range(0.05..0.5),   // 7: path_repetition — low to moderate
        rng.random_range(0.0..2000.0), // 8: avg_content_length — reasonable
        rng.random_range(1.0..4.0),    // 9: unique_user_agents — real users have few
        rng.random_range(0.7..1.0),    // 10: cookie_ratio — high for real users
        rng.random_range(0.4..1.0),    // 11: referer_ratio — moderate to high
        rng.random_range(0.7..1.0),    // 12: accept_language_ratio — high
        rng.random_range(0.0..0.05),   // 13: suspicious_path_ratio — very low
    ];
    let mut features = vec![0.0f32; NUM_FEATURES];
    for (slot, value) in features.iter_mut().zip(drawn) {
        *slot = value as f32;
    }
    features
}
/// Generate synthetic scanner training samples.
///
/// Attack samples use paths from wordlist files (if provided) or built-in
/// scanner-like defaults. Normal samples have realistic browser-like
/// feature vectors.
///
/// Produces 12-dimensional feature vectors matching `NUM_SCANNER_FEATURES`:
/// 0: suspicious_path_score, 1: path_depth, 2: has_suspicious_extension,
/// 3: has_cookies, 4: has_referer, 5: has_accept_language,
/// 6: accept_quality, 7: ua_category, 8: method_is_unusual,
/// 9: host_is_configured, 10: content_length_mismatch, 11: path_has_traversal
pub fn generate_scanner_samples(
    wordlist_dir: Option<&Path>,
    _production_logs: Option<&Path>,
    config: &SyntheticConfig,
) -> Result<Vec<TrainingSample>> {
    // Offset the seed so the scanner stream differs from the DDoS stream.
    let mut rng = StdRng::seed_from_u64(config.seed.wrapping_add(1));
    // Wordlist paths, when present, are cycled round-robin for attacks.
    let attack_paths = load_wordlist_paths(wordlist_dir)?;
    let mut samples =
        Vec::with_capacity(config.num_scanner_attack + config.num_scanner_normal);
    for i in 0..config.num_scanner_attack {
        let features = if attack_paths.is_empty() {
            generate_scanner_attack_features(&mut rng)
        } else {
            generate_scanner_attack_features_from_path(
                &attack_paths[i % attack_paths.len()],
                &mut rng,
            )
        };
        samples.push(TrainingSample {
            features,
            label: 1.0,
            source: DataSource::SyntheticWordlist,
            weight: 0.5,
        });
    }
    for _ in 0..config.num_scanner_normal {
        samples.push(TrainingSample {
            features: generate_scanner_normal_features(&mut rng),
            label: 0.0,
            source: DataSource::SyntheticWordlist,
            weight: 0.5,
        });
    }
    Ok(samples)
}
/// Load paths from wordlist files in a directory (or single file).
///
/// Accepts either a single file or a directory of `*.txt` files. Blank lines
/// and `#` comments are skipped; every entry is normalized to start with `/`.
/// Returns an empty list when no location is given or it does not exist.
///
/// # Errors
/// Propagates I/O errors from listing the directory or reading a wordlist.
fn load_wordlist_paths(dir: Option<&Path>) -> Result<Vec<String>> {
    let dir = match dir {
        Some(d) => d,
        None => return Ok(Vec::new()),
    };
    let files: Vec<std::path::PathBuf> = if dir.is_file() {
        vec![dir.to_path_buf()]
    } else if dir.exists() {
        std::fs::read_dir(dir)?
            .filter_map(|e| e.ok()) // unreadable entries are skipped, not fatal
            .map(|e| e.path())
            .filter(|p| p.extension().map(|e| e == "txt").unwrap_or(false))
            .collect()
    } else {
        return Ok(Vec::new());
    };
    let mut paths = Vec::new();
    for file in &files {
        let f = std::fs::File::open(file)?;
        let reader = std::io::BufReader::new(f);
        for line in reader.lines() {
            let line = line?;
            // Trim as a borrow first so skipped lines cost no extra allocation.
            let trimmed = line.trim();
            if trimmed.is_empty() || trimmed.starts_with('#') {
                continue;
            }
            // Normalize every entry to a leading slash.
            let normalized = if trimmed.starts_with('/') {
                trimmed.to_string()
            } else {
                format!("/{trimmed}")
            };
            paths.push(normalized);
        }
    }
    Ok(paths)
}
/// Generate scanner attack features from a specific path.
///
/// Path-derived signals (fragments, depth, extension, traversal) are
/// deterministic; header/UA signals are drawn to mimic scanner behavior.
fn generate_scanner_attack_features_from_path(path: &str, rng: &mut StdRng) -> Vec<f32> {
    // Bernoulli draw mapped onto a 0/1 feature value.
    fn coin(rng: &mut StdRng, p: f64) -> f32 {
        if rng.random_bool(p) { 1.0 } else { 0.0 }
    }
    let lower = path.to_ascii_lowercase();
    // --- Deterministic, path-derived signals ---
    let suspicious_frags = [
        ".env", "wp-admin", "wp-login", "phpinfo", "phpmyadmin",
        ".git", "cgi-bin", "shell", "admin", "config",
    ];
    let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
    let hit_count = segments
        .iter()
        .filter(|s| {
            let sl = s.to_ascii_lowercase();
            suspicious_frags.iter().any(|f| sl.contains(f))
        })
        .count() as f32;
    // Fraction of path segments containing a known-bad fragment.
    let frag_score = hit_count / segments.len().max(1) as f32;
    let depth = path.bytes().filter(|&b| b == b'/').count().min(20) as f32;
    let suspicious_exts = [".php", ".env", ".sql", ".bak", ".asp", ".jsp", ".cgi", ".tar", ".zip", ".git"];
    let bad_ext = if suspicious_exts.iter().any(|e| lower.ends_with(e)) { 1.0 } else { 0.0 };
    let traversal_patterns = ["..", "%00", "%0a", "%27", "%3c"];
    let traversal = if traversal_patterns.iter().any(|p| lower.contains(p)) { 1.0 } else { 0.0 };
    // --- Assemble the vector (randomized slots drawn in index order) ---
    let mut features = vec![0.0f32; NUM_SCANNER_FEATURES];
    features[0] = frag_score;      // suspicious_path_score
    features[1] = depth;           // path_depth
    features[2] = bad_ext;         // has_suspicious_extension
    features[3] = coin(rng, 0.05); // has_cookies — scanners typically don't
    features[4] = coin(rng, 0.05); // has_referer — scanners rarely send it
    features[5] = coin(rng, 0.1);  // has_accept_language — rarely present
    features[6] = 0.0;             // accept_quality — scanners use */* or empty
    // ua_category — mix of empty / curl-like / random bot user agents.
    let ua_roll: f64 = rng.random();
    features[7] = if ua_roll < 0.3 {
        0.0 // empty
    } else if ua_roll < 0.6 {
        0.25 // curl/wget
    } else {
        0.5 // random bot
    };
    features[8] = coin(rng, 0.05); // method_is_unusual — mostly GET
    features[9] = coin(rng, 0.2);  // host_is_configured — often unknown host
    features[10] = coin(rng, 0.1); // content_length_mismatch
    features[11] = traversal;      // path_has_traversal
    features
}
/// Generate scanner attack features without a specific path (built-in defaults).
fn generate_scanner_attack_features(rng: &mut StdRng) -> Vec<f32> {
    // Bernoulli draw mapped onto a 0/1 feature value.
    fn coin(rng: &mut StdRng, p: f64) -> f32 {
        if rng.random_bool(p) { 1.0 } else { 0.0 }
    }
    // Drawn in feature-index order so the RNG stream layout is stable.
    let drawn: [f32; 12] = [
        rng.random_range(0.3..1.0) as f32, // suspicious_path_score
        rng.random_range(2.0..8.0) as f32, // path_depth
        coin(rng, 0.6),                    // suspicious_ext
        coin(rng, 0.05),                   // cookies
        coin(rng, 0.05),                   // referer
        coin(rng, 0.1),                    // accept_language
        0.0,                               // accept_quality
        rng.random_range(0.0..0.5) as f32, // ua_category
        coin(rng, 0.1),                    // unusual method
        coin(rng, 0.2),                    // host_configured
        coin(rng, 0.1),                    // content_len mismatch
        coin(rng, 0.15),                   // traversal
    ];
    let mut features = vec![0.0f32; NUM_SCANNER_FEATURES];
    for (slot, value) in features.iter_mut().zip(drawn) {
        *slot = value;
    }
    features
}
/// Generate normal scanner features (browser-like).
fn generate_scanner_normal_features(rng: &mut StdRng) -> Vec<f32> {
    // Bernoulli draw mapped onto a 0/1 feature value.
    fn coin(rng: &mut StdRng, p: f64) -> f32 {
        if rng.random_bool(p) { 1.0 } else { 0.0 }
    }
    // Browser-like defaults, drawn in feature-index order.
    let drawn: [f32; 12] = [
        0.0,                               // suspicious_path_score — clean path
        rng.random_range(1.0..4.0) as f32, // path_depth — moderate
        0.0,                               // no suspicious extension
        coin(rng, 0.85),                   // has_cookies — usually yes
        coin(rng, 0.7),                    // has_referer — often
        coin(rng, 0.9),                    // has_accept_language
        1.0,                               // accept_quality — proper accept header
        1.0,                               // ua_category — browser UA
        0.0,                               // method_is_unusual — GET/POST
        coin(rng, 0.95),                   // host_is_configured
        0.0,                               // no content_length_mismatch
        0.0,                               // no traversal
    ];
    let mut features = vec![0.0f32; NUM_SCANNER_FEATURES];
    for (slot, value) in features.iter_mut().zip(drawn) {
        *slot = value;
    }
    features
}
/// Sample from a normal distribution, clamped to positive values.
///
/// Degenerate spreads (`std_dev <= 0`) collapse to the (non-negative) mean.
fn sample_positive_normal(rng: &mut StdRng, mean: f64, std_dev: f64) -> f64 {
    if std_dev <= 0.0 {
        return mean.max(0.0);
    }
    // Box-Muller: two independent uniforms yield one standard normal.
    // The first uniform is bounded away from zero so ln() stays finite.
    let uniform_a: f64 = rng.random::<f64>().max(1e-10);
    let uniform_b: f64 = rng.random();
    let standard_normal =
        (-2.0 * uniform_a.ln()).sqrt() * (2.0 * std::f64::consts::PI * uniform_b).cos();
    (mean + std_dev * standard_normal).max(0.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Every generated DDoS vector must match the model's input width, with
    // or without timing profiles supplied.
    #[test]
    fn test_ddos_feature_dimensions() {
        let profiles = vec![TimingProfile {
            attack_type: "DDoS".to_string(),
            inter_arrival_mean: 0.01,
            inter_arrival_std: 0.005,
            burst_duration_mean: 2.0,
            burst_duration_std: 1.0,
            flow_bytes_per_sec_mean: 50000.0,
            flow_bytes_per_sec_std: 20000.0,
            sample_count: 100,
        }];
        let config = SyntheticConfig {
            num_ddos_attack: 10,
            num_ddos_normal: 10,
            seed: 42,
            ..Default::default()
        };
        let samples = generate_ddos_samples(&profiles, &config);
        assert_eq!(samples.len(), 20);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_FEATURES,
                "DDoS features should have {} dimensions",
                NUM_FEATURES
            );
        }
    }
    // Scanner vectors (no wordlist) must match the scanner model's width.
    #[test]
    fn test_scanner_feature_dimensions() {
        let config = SyntheticConfig {
            num_scanner_attack: 10,
            num_scanner_normal: 10,
            seed: 42,
            ..Default::default()
        };
        let samples = generate_scanner_samples(None, None, &config).unwrap();
        assert_eq!(samples.len(), 20);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_SCANNER_FEATURES,
                "scanner features should have {} dimensions",
                NUM_SCANNER_FEATURES
            );
        }
    }
    // Two runs with the same seed must produce identical sample streams for
    // both generators (reproducible datasets).
    #[test]
    fn test_deterministic_with_fixed_seed() {
        let config = SyntheticConfig {
            num_ddos_attack: 5,
            num_ddos_normal: 5,
            num_scanner_attack: 5,
            num_scanner_normal: 5,
            seed: 123,
        };
        let ddos_a = generate_ddos_samples(&[], &config);
        let ddos_b = generate_ddos_samples(&[], &config);
        for (a, b) in ddos_a.iter().zip(ddos_b.iter()) {
            assert_eq!(a.features, b.features, "DDoS samples should be deterministic");
            assert_eq!(a.label, b.label);
        }
        let scanner_a = generate_scanner_samples(None, None, &config).unwrap();
        let scanner_b = generate_scanner_samples(None, None, &config).unwrap();
        for (a, b) in scanner_a.iter().zip(scanner_b.iter()) {
            assert_eq!(a.features, b.features, "scanner samples should be deterministic");
            assert_eq!(a.label, b.label);
        }
    }
    // Statistical sanity: the two classes must be separable on the features
    // the heuristics rely on (request_rate up, cookie_ratio down for attacks).
    #[test]
    fn test_ddos_attack_vs_normal_distinguishable() {
        let config = SyntheticConfig {
            num_ddos_attack: 100,
            num_ddos_normal: 100,
            seed: 42,
            ..Default::default()
        };
        let samples = generate_ddos_samples(&[], &config);
        let attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label > 0.5).collect();
        let normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label <= 0.5).collect();
        // Attack samples should have higher avg request_rate (feature 0).
        let avg_attack_rate: f32 =
            attacks.iter().map(|s| s.features[0]).sum::<f32>() / attacks.len() as f32;
        let avg_normal_rate: f32 =
            normals.iter().map(|s| s.features[0]).sum::<f32>() / normals.len() as f32;
        assert!(
            avg_attack_rate > avg_normal_rate,
            "attack rate {} should exceed normal rate {}",
            avg_attack_rate,
            avg_normal_rate
        );
        // Attack samples should have lower avg cookie_ratio (feature 10).
        let avg_attack_cookies: f32 =
            attacks.iter().map(|s| s.features[10]).sum::<f32>() / attacks.len() as f32;
        let avg_normal_cookies: f32 =
            normals.iter().map(|s| s.features[10]).sum::<f32>() / normals.len() as f32;
        assert!(
            avg_attack_cookies < avg_normal_cookies,
            "attack cookies {} should be lower than normal cookies {}",
            avg_attack_cookies,
            avg_normal_cookies
        );
    }
    // The truncated-normal sampler must never go negative, even with a
    // spread much larger than the mean.
    #[test]
    fn test_sample_positive_normal_stays_positive() {
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..1000 {
            let val = sample_positive_normal(&mut rng, 1.0, 5.0);
            assert!(val >= 0.0, "value should be non-negative: {val}");
        }
    }
    // A known-bad wordlist path must light up the path-derived features.
    #[test]
    fn test_scanner_attack_from_wordlist_path() {
        let mut rng = StdRng::seed_from_u64(42);
        let features = generate_scanner_attack_features_from_path("/.env", &mut rng);
        assert_eq!(features.len(), NUM_SCANNER_FEATURES);
        // .env should trigger suspicious_path_score > 0.
        assert!(features[0] > 0.0, "suspicious_path_score should be positive for /.env");
        // .env should trigger suspicious_extension.
        assert_eq!(features[2], 1.0, "should detect .env as suspicious extension");
    }
}