diff --git a/src/dataset/cicids.rs b/src/dataset/cicids.rs
new file mode 100644
index 0000000..12abc47
--- /dev/null
+++ b/src/dataset/cicids.rs
@@ -0,0 +1,301 @@
//! CIC-IDS2017 timing profile extractor.
//!
//! Parses CIC-IDS2017 CSV files and extracts statistical timing profiles
//! per attack type. These profiles are NOT training samples themselves —
//! they feed the synthetic data generator (`synthetic.rs`) which samples
//! from the learned distributions to produce DDoS training features.

use anyhow::{Context, Result};
use std::path::Path;

/// Statistical timing profile for one attack type from CIC-IDS2017.
#[derive(Debug, Clone)]
pub struct TimingProfile {
    pub attack_type: String,
    pub inter_arrival_mean: f64,
    pub inter_arrival_std: f64,
    pub burst_duration_mean: f64,
    pub burst_duration_std: f64,
    pub flow_bytes_per_sec_mean: f64,
    pub flow_bytes_per_sec_std: f64,
    pub sample_count: usize,
}

/// Single-pass mean/variance accumulator (Welford's online algorithm).
#[derive(Default)]
struct StatsAccumulator {
    count: usize,
    mean: f64,
    m2: f64,
}

impl StatsAccumulator {
    /// Fold one observation into the running statistics.
    ///
    /// Non-finite values (NaN / ±inf) are silently dropped so a single bad
    /// CSV cell cannot poison an entire profile.
    fn push(&mut self, value: f64) {
        if !value.is_finite() {
            return;
        }
        self.count += 1;
        let before = value - self.mean;
        self.mean += before / self.count as f64;
        let after = value - self.mean;
        self.m2 += before * after;
    }

    /// Running mean; 0.0 when nothing has been pushed yet.
    fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample standard deviation (Bessel-corrected, n-1 denominator).
    /// Returns 0.0 for fewer than two observations, where the estimator
    /// is undefined.
    fn std_dev(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => (self.m2 / (n - 1) as f64).sqrt(),
        }
    }
}

/// Per-label accumulator for timing statistics.
+struct LabelAccumulator { + label: String, + inter_arrival: StatsAccumulator, + burst_duration: StatsAccumulator, + flow_bytes_per_sec: StatsAccumulator, + count: usize, +} + +impl LabelAccumulator { + fn new(label: String) -> Self { + Self { + label, + inter_arrival: StatsAccumulator::default(), + burst_duration: StatsAccumulator::default(), + flow_bytes_per_sec: StatsAccumulator::default(), + count: 0, + } + } + + fn into_profile(self) -> TimingProfile { + TimingProfile { + attack_type: self.label, + inter_arrival_mean: self.inter_arrival.mean(), + inter_arrival_std: self.inter_arrival.std_dev(), + burst_duration_mean: self.burst_duration.mean(), + burst_duration_std: self.burst_duration.std_dev(), + flow_bytes_per_sec_mean: self.flow_bytes_per_sec.mean(), + flow_bytes_per_sec_std: self.flow_bytes_per_sec.std_dev(), + sample_count: self.count, + } + } +} + +/// Find a column index by name (case-insensitive, trimmed). +fn find_column(headers: &[String], name: &str) -> Option { + let lower = name.to_ascii_lowercase(); + headers + .iter() + .position(|h| h.trim().to_ascii_lowercase() == lower) +} + +/// Extract timing profiles from all CSV files in the given directory. +/// +/// Each CSV is expected to have CIC-IDS2017 columns including at minimum: +/// `Flow Duration`, `Flow IAT Mean`, `Flow IAT Std`, `Flow Bytes/s`, `Label`. +/// +/// Returns one `TimingProfile` per unique label value. +pub fn extract_timing_profiles(csv_dir: &Path) -> Result> { + let entries: Vec = if csv_dir.is_file() { + vec![csv_dir.to_path_buf()] + } else { + let mut files: Vec = std::fs::read_dir(csv_dir) + .with_context(|| format!("reading directory {}", csv_dir.display()))? 
+ .filter_map(|e| e.ok()) + .map(|e| e.path()) + .filter(|p| { + p.extension() + .map(|e| e.to_ascii_lowercase() == "csv") + .unwrap_or(false) + }) + .collect(); + files.sort(); + files + }; + + if entries.is_empty() { + anyhow::bail!("no CSV files found in {}", csv_dir.display()); + } + + let mut accumulators: std::collections::HashMap = + std::collections::HashMap::new(); + + for csv_path in &entries { + parse_csv_file(csv_path, &mut accumulators) + .with_context(|| format!("parsing {}", csv_path.display()))?; + } + + let mut profiles: Vec = accumulators + .into_values() + .filter(|a| a.count > 0) + .map(|a| a.into_profile()) + .collect(); + profiles.sort_by(|a, b| a.attack_type.cmp(&b.attack_type)); + + Ok(profiles) +} + +fn parse_csv_file( + path: &Path, + accumulators: &mut std::collections::HashMap, +) -> Result<()> { + let mut rdr = csv::ReaderBuilder::new() + .flexible(true) + .trim(csv::Trim::All) + .from_path(path)?; + + let headers: Vec = rdr + .headers()? + .iter() + .map(|h| h.to_string()) + .collect(); + + // Locate required columns. + let col_label = find_column(&headers, "Label") + .with_context(|| format!("missing 'Label' column in {}", path.display()))?; + let col_flow_duration = find_column(&headers, "Flow Duration"); + let col_iat_mean = find_column(&headers, "Flow IAT Mean"); + let col_iat_std = find_column(&headers, "Flow IAT Std"); + let col_bytes_per_sec = find_column(&headers, "Flow Bytes/s"); + + for result in rdr.records() { + let record = match result { + Ok(r) => r, + Err(_) => continue, + }; + + let label = match record.get(col_label) { + Some(l) => l.trim().to_string(), + None => continue, + }; + if label.is_empty() { + continue; + } + + let acc = accumulators + .entry(label.clone()) + .or_insert_with(|| LabelAccumulator::new(label)); + acc.count += 1; + + // Flow Duration (microseconds in CIC-IDS2017) → we treat as burst duration. 
+ if let Some(col) = col_flow_duration { + if let Some(val) = record.get(col).and_then(|v| v.trim().parse::().ok()) { + // Convert microseconds to seconds. + acc.burst_duration.push(val / 1_000_000.0); + } + } + + // Flow IAT Mean (microseconds) → inter-arrival time. + if let Some(col) = col_iat_mean { + if let Some(val) = record.get(col).and_then(|v| v.trim().parse::().ok()) { + acc.inter_arrival.push(val / 1_000_000.0); + } + } + + // Flow IAT Std → used as inter_arrival_std contribution. + if let Some(col) = col_iat_std { + if let Some(_val) = record.get(col).and_then(|v| v.trim().parse::().ok()) { + // The per-flow IAT std contributes to the overall std via Welford above. + // We use the IAT Mean values; std is computed from those. + } + } + + // Flow Bytes/s. + if let Some(col) = col_bytes_per_sec { + if let Some(val) = record.get(col).and_then(|v| v.trim().parse::().ok()) { + acc.flow_bytes_per_sec.push(val); + } + } + } + + Ok(()) +} + +/// Parse timing profiles from an in-memory CSV string (useful for tests). 
+pub fn extract_timing_profiles_from_str(csv_content: &str) -> Result> { + let dir = tempfile::tempdir()?; + let path = dir.path().join("test.csv"); + std::fs::write(&path, csv_content)?; + extract_timing_profiles(&path) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_inline_csv() { + let csv = "\ +Flow Duration,Total Fwd Packets,Flow Bytes/s,Flow IAT Mean,Flow IAT Std,Label +1000000,10,5000.0,100000,50000,BENIGN +2000000,20,10000.0,200000,100000,BENIGN +500000,100,50000.0,5000,2000,DDoS +300000,200,80000.0,3000,1000,DDoS +100000,50,30000.0,10000,5000,DDoS +"; + let profiles = extract_timing_profiles_from_str(csv).unwrap(); + assert_eq!(profiles.len(), 2, "should have BENIGN and DDoS profiles"); + + let benign = profiles.iter().find(|p| p.attack_type == "BENIGN").unwrap(); + assert_eq!(benign.sample_count, 2); + // IAT mean: mean of [0.1, 0.2] = 0.15 seconds + assert!( + (benign.inter_arrival_mean - 0.15).abs() < 1e-6, + "benign iat mean: {}", + benign.inter_arrival_mean + ); + // Burst duration mean: mean of [1.0, 2.0] = 1.5 seconds + assert!( + (benign.burst_duration_mean - 1.5).abs() < 1e-6, + "benign burst mean: {}", + benign.burst_duration_mean + ); + + let ddos = profiles.iter().find(|p| p.attack_type == "DDoS").unwrap(); + assert_eq!(ddos.sample_count, 3); + // Flow bytes/s mean: mean of [50000, 80000, 30000] = 53333.33 + let expected_bps = (50000.0 + 80000.0 + 30000.0) / 3.0; + assert!( + (ddos.flow_bytes_per_sec_mean - expected_bps).abs() < 1.0, + "ddos bps mean: {}", + ddos.flow_bytes_per_sec_mean + ); + } + + #[test] + fn test_stats_accumulator() { + let mut acc = StatsAccumulator::default(); + acc.push(2.0); + acc.push(4.0); + acc.push(6.0); + assert!((acc.mean() - 4.0).abs() < 1e-10); + // std_dev of [2,4,6] = 2.0 + assert!((acc.std_dev() - 2.0).abs() < 1e-10); + } + + #[test] + fn test_empty_csv_dir() { + let dir = tempfile::tempdir().unwrap(); + let result = extract_timing_profiles(dir.path()); + assert!(result.is_err()); 
+ } + + #[test] + fn test_find_column_case_insensitive() { + let headers: Vec = vec![ + " Flow Duration ".to_string(), + "label".to_string(), + ]; + assert_eq!(find_column(&headers, "Flow Duration"), Some(0)); + assert_eq!(find_column(&headers, "Label"), Some(1)); + assert_eq!(find_column(&headers, "LABEL"), Some(1)); + assert_eq!(find_column(&headers, "missing"), None); + } +} diff --git a/src/dataset/download.rs b/src/dataset/download.rs new file mode 100644 index 0000000..c36b8cb --- /dev/null +++ b/src/dataset/download.rs @@ -0,0 +1,119 @@ +//! Download and cache upstream datasets for training. +//! +//! Cached under `~/.cache/sunbeam//`. Files are only downloaded +//! once; subsequent runs reuse the cached copy. + +use anyhow::{Context, Result}; +use std::path::PathBuf; + +/// Base cache directory for all sunbeam datasets. +fn cache_base() -> PathBuf { + let base = std::env::var("XDG_CACHE_HOME") + .map(PathBuf::from) + .unwrap_or_else(|_| { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); + PathBuf::from(home).join(".cache") + }); + base.join("sunbeam") +} + +// --- CIC-IDS2017 --- + +/// Only the Friday DDoS file — contains DDoS Hulk, Slowloris, slowhttptest, GoldenEye. +const CICIDS_FILE: &str = "Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv"; + +/// Hugging Face mirror (public, no auth required). +const CICIDS_BASE_URL: &str = + "https://huggingface.co/datasets/c01dsnap/CIC-IDS2017/resolve/main"; + +fn cicids_cache_dir() -> PathBuf { + cache_base().join("cicids") +} + +/// Return the path to the cached CIC-IDS2017 DDoS CSV, or `None` if not downloaded. +pub fn cicids_cached_path() -> Option { + let path = cicids_cache_dir().join(CICIDS_FILE); + if path.exists() { + Some(path) + } else { + None + } +} + +/// Download the CIC-IDS2017 Friday DDoS CSV to cache. Returns the cached path. 
+pub fn download_cicids() -> Result { + let dir = cicids_cache_dir(); + let path = dir.join(CICIDS_FILE); + + if path.exists() { + eprintln!(" cached: {}", path.display()); + return Ok(path); + } + + let url = format!("{CICIDS_BASE_URL}/{CICIDS_FILE}"); + eprintln!(" downloading: {url}"); + eprintln!(" (this is ~170 MB, may take a minute)"); + + std::fs::create_dir_all(&dir)?; + + // Stream to file to avoid holding 170MB in memory. + let resp = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build()? + .get(&url) + .send() + .with_context(|| format!("fetching {url}"))? + .error_for_status() + .with_context(|| format!("HTTP error for {url}"))?; + + let mut file = std::fs::File::create(&path) + .with_context(|| format!("creating {}", path.display()))?; + let bytes = resp.bytes().with_context(|| "reading response body")?; + std::io::Write::write_all(&mut file, &bytes)?; + + eprintln!(" saved: {}", path.display()); + Ok(path) +} + +// --- CSIC 2010 --- + +/// Download CSIC 2010 dataset files to cache (delegates to scanner::csic). +pub fn download_csic() -> Result<()> { + if crate::scanner::csic::csic_is_cached() { + eprintln!(" cached: {}", crate::scanner::csic::csic_cache_path().display()); + return Ok(()); + } + // fetch_csic_dataset downloads, caches, and parses — we only need the download side-effect. + crate::scanner::csic::fetch_csic_dataset()?; + Ok(()) +} + +/// Download all upstream datasets. 
+pub fn download_all() -> Result<()> { + eprintln!("downloading upstream datasets...\n"); + + eprintln!("[1/2] CSIC 2010 (scanner training data)"); + download_csic()?; + eprintln!(); + + eprintln!("[2/2] CIC-IDS2017 DDoS timing profiles"); + let path = download_cicids()?; + eprintln!(" ok: {}\n", path.display()); + + eprintln!("all datasets cached."); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cache_paths() { + let base = cache_base(); + assert!(base.to_str().unwrap().contains("sunbeam")); + + let cicids = cicids_cache_dir(); + assert!(cicids.to_str().unwrap().contains("cicids")); + } +} diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs new file mode 100644 index 0000000..6018c9c --- /dev/null +++ b/src/dataset/mod.rs @@ -0,0 +1,6 @@ +pub mod sample; +pub mod modsec; +pub mod cicids; +pub mod synthetic; +pub mod download; +pub mod prepare; diff --git a/src/dataset/modsec.rs b/src/dataset/modsec.rs new file mode 100644 index 0000000..a35eb23 --- /dev/null +++ b/src/dataset/modsec.rs @@ -0,0 +1,340 @@ +//! Parser for OWASP ModSecurity audit log files (Serial / concurrent format). +//! +//! ModSecurity audit logs consist of multi-section entries delimited by boundary +//! markers like `--xxxxxxxx-A--`, `--xxxxxxxx-B--`, etc. Each section contains +//! different data about the transaction: +//! +//! - **A**: Timestamp, transaction ID, source/dest IP+port +//! - **B**: Request line + headers +//! - **C**: Request body +//! - **F**: Response status + headers +//! - **H**: Audit log trailer (rule matches, messages, actions) +//! +//! Any entry with a rule match in section H is labeled "attack"; entries with +//! no rule matches are labeled "normal". + +use crate::ddos::audit_log::AuditFields; +use anyhow::{Context, Result}; +use std::collections::HashMap; +use std::path::Path; + +/// Parse a ModSecurity audit log file and return `(AuditFields, label)` pairs. 
+/// +/// The label is `"attack"` if section H contains rule match messages, otherwise `"normal"`. +pub fn parse_modsec_audit_log(path: &Path) -> Result> { + let content = std::fs::read_to_string(path) + .with_context(|| format!("reading ModSec audit log: {}", path.display()))?; + parse_modsec_content(&content) +} + +/// Parse ModSecurity audit log content from a string. +fn parse_modsec_content(content: &str) -> Result> { + let mut results = Vec::new(); + + // Collect sections per boundary ID. + // Boundary markers look like: --xxxxxxxx-A-- + // where xxxxxxxx is a hex/alnum transaction ID and A is the section letter. + let mut sections: HashMap>> = HashMap::new(); + let mut current_id: Option = None; + let mut current_section: Option = None; + let mut current_lines: Vec = Vec::new(); + // Track order of first appearance of each boundary ID. + let mut id_order: Vec = Vec::new(); + + for line in content.lines() { + if let Some((id, section)) = parse_boundary(line) { + // Flush previous section. + if let (Some(ref cid), Some(sec)) = (¤t_id, current_section) { + let entry = sections.entry(cid.clone()).or_default(); + entry.entry(sec).or_default().extend(current_lines.drain(..)); + } + if !sections.contains_key(&id) { + id_order.push(id.clone()); + } + current_id = Some(id); + current_section = Some(section); + current_lines.clear(); + } else if current_id.is_some() { + current_lines.push(line.to_string()); + } + } + // Flush last section. + if let (Some(ref cid), Some(sec)) = (¤t_id, current_section) { + let entry = sections.entry(cid.clone()).or_default(); + entry.entry(sec).or_default().extend(current_lines.drain(..)); + } + + // Convert each transaction into AuditFields. + for id in &id_order { + if let Some(secs) = sections.get(id) { + if let Some(fields) = transaction_to_audit_fields(secs) { + results.push(fields); + } + } + } + + Ok(results) +} + +/// Try to parse a boundary marker line. +/// Returns `(boundary_id, section_letter)` on success. 
+fn parse_boundary(line: &str) -> Option<(String, char)> { + let trimmed = line.trim(); + if !trimmed.starts_with("--") || !trimmed.ends_with("--") { + return None; + } + // Strip leading and trailing -- + let inner = &trimmed[2..trimmed.len() - 2]; + // Should be: boundary_id-SECTION_LETTER + let dash_pos = inner.rfind('-')?; + if dash_pos == 0 || dash_pos == inner.len() - 1 { + return None; + } + let id = &inner[..dash_pos]; + let section_str = &inner[dash_pos + 1..]; + if section_str.len() != 1 { + return None; + } + let section = section_str.chars().next()?; + if !section.is_ascii_alphabetic() { + return None; + } + Some((id.to_string(), section)) +} + +/// Convert parsed sections into `(AuditFields, label)`. +fn transaction_to_audit_fields( + sections: &HashMap>, +) -> Option<(AuditFields, String)> { + // Section A: timestamp + connection info + let client_ip = sections + .get(&'A') + .and_then(|lines| { + // Section A first line typically: + // [dd/Mon/yyyy:HH:MM:SS +offset] transaction_id source_ip source_port dest_ip dest_port + lines.first().and_then(|line| { + // Find the content after the timestamp bracket. + let after_bracket = line.find(']').map(|i| &line[i + 1..])?; + let parts: Vec<&str> = after_bracket.split_whitespace().collect(); + // parts: [transaction_id, source_ip, source_port, dest_ip, dest_port] + if parts.len() >= 3 { + Some(parts[1].to_string()) + } else { + None + } + }) + }) + .unwrap_or_else(|| "0.0.0.0".to_string()); + + // Section B: request line + headers + let section_b = sections.get(&'B')?; + if section_b.is_empty() { + return None; + } + + // First non-empty line is the request line. 
+ let request_line = section_b.iter().find(|l| !l.trim().is_empty())?; + let req_parts: Vec<&str> = request_line.splitn(3, ' ').collect(); + if req_parts.len() < 2 { + return None; + } + let method = req_parts[0].to_string(); + let raw_url = req_parts[1]; + + let (path, query) = if let Some(q) = raw_url.find('?') { + (raw_url[..q].to_string(), raw_url[q + 1..].to_string()) + } else { + (raw_url.to_string(), String::new()) + }; + + // Parse headers from remaining lines. + let mut headers: Vec<(String, String)> = Vec::new(); + for line in section_b.iter().skip(1) { + let trimmed = line.trim(); + if trimmed.is_empty() { + break; + } + if let Some(colon) = trimmed.find(':') { + let key = trimmed[..colon].trim().to_ascii_lowercase(); + let value = trimmed[colon + 1..].trim().to_string(); + headers.push((key, value)); + } + } + + let get_header = |name: &str| -> Option { + headers + .iter() + .find(|(k, _)| k == name) + .map(|(_, v)| v.clone()) + }; + + let host = get_header("host").unwrap_or_else(|| "unknown".to_string()); + let user_agent = get_header("user-agent").unwrap_or_else(|| "-".to_string()); + let has_cookies = get_header("cookie").is_some(); + let referer = get_header("referer").filter(|r| r != "-" && !r.is_empty()); + let accept_language = get_header("accept-language").filter(|a| a != "-" && !a.is_empty()); + let content_length: u64 = get_header("content-length") + .and_then(|v| v.parse().ok()) + .unwrap_or(0); + + // Section F: response status + let status = sections + .get(&'F') + .and_then(|lines| { + // First non-empty line: "HTTP/1.1 403 Forbidden" + lines.iter().find(|l| !l.trim().is_empty()).and_then(|line| { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 2 { + parts[1].parse::().ok() + } else { + None + } + }) + }) + .unwrap_or(0); + + // Section H: rule matches → determines label. 
+ let has_rule_match = sections + .get(&'H') + .map(|lines| { + lines.iter().any(|l| { + let lower = l.to_ascii_lowercase(); + lower.contains("matched") || lower.contains("warning") || lower.contains("id:") + }) + }) + .unwrap_or(false); + + let label = if has_rule_match { "attack" } else { "normal" }.to_string(); + + let fields = AuditFields { + method, + host, + path, + query, + client_ip, + status, + duration_ms: 0, + content_length, + user_agent, + has_cookies: Some(has_cookies), + referer, + accept_language, + backend: "-".to_string(), + label: Some(label.clone()), + }; + + Some((fields, label)) +} + +/// Cache directory for ModSec data (mirrors the CSIC caching pattern). +#[allow(dead_code)] +fn cache_dir() -> std::path::PathBuf { + let base = std::env::var("XDG_CACHE_HOME") + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); + std::path::PathBuf::from(home).join(".cache") + }); + base.join("sunbeam").join("modsec") +} + +#[cfg(test)] +mod tests { + use super::*; + + const SAMPLE_AUDIT_LOG: &str = r#"--a1b2c3d4-A-- +[01/Jan/2026:12:00:00 +0000] XYZ123 192.168.1.100 54321 10.0.0.1 80 +--a1b2c3d4-B-- +GET /admin/config.php?debug=1 HTTP/1.1 +Host: example.com +User-Agent: curl/7.68.0 +Accept: */* +--a1b2c3d4-C-- +--a1b2c3d4-F-- +HTTP/1.1 403 Forbidden +Content-Type: text/html +--a1b2c3d4-H-- +Message: Warning. 
Matched "Operator `Rx' with parameter" [id "941100"] +Action: Intercepted (phase 2) +--a1b2c3d4-Z-- +--e5f6a7b8-A-- +[01/Jan/2026:12:00:01 +0000] ABC456 10.0.0.50 12345 10.0.0.1 80 +--e5f6a7b8-B-- +GET /index.html HTTP/1.1 +Host: example.com +User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120 +Accept: text/html +Accept-Language: en-US,en;q=0.9 +Cookie: session=abc123 +Referer: https://example.com/ +--e5f6a7b8-F-- +HTTP/1.1 200 OK +Content-Type: text/html +--e5f6a7b8-H-- +--e5f6a7b8-Z-- +"#; + + #[test] + fn test_parse_boundary() { + assert_eq!( + parse_boundary("--a1b2c3d4-A--"), + Some(("a1b2c3d4".to_string(), 'A')) + ); + assert_eq!( + parse_boundary("--a1b2c3d4-H--"), + Some(("a1b2c3d4".to_string(), 'H')) + ); + assert_eq!(parse_boundary("not a boundary"), None); + assert_eq!(parse_boundary("--invalid--"), None); + } + + #[test] + fn test_parse_modsec_audit_log_snippet() { + let results = parse_modsec_content(SAMPLE_AUDIT_LOG).unwrap(); + assert_eq!(results.len(), 2, "should parse two transactions"); + + // First entry: attack (has rule match in section H). + let (attack_fields, attack_label) = &results[0]; + assert_eq!(attack_label, "attack"); + assert_eq!(attack_fields.method, "GET"); + assert_eq!(attack_fields.path, "/admin/config.php"); + assert_eq!(attack_fields.query, "debug=1"); + assert_eq!(attack_fields.client_ip, "192.168.1.100"); + assert_eq!(attack_fields.user_agent, "curl/7.68.0"); + assert_eq!(attack_fields.status, 403); + assert!(!attack_fields.has_cookies.unwrap_or(true)); + + // Second entry: normal (no rule match). 
+ let (normal_fields, normal_label) = &results[1]; + assert_eq!(normal_label, "normal"); + assert_eq!(normal_fields.method, "GET"); + assert_eq!(normal_fields.path, "/index.html"); + assert_eq!(normal_fields.client_ip, "10.0.0.50"); + assert_eq!(normal_fields.status, 200); + assert!(normal_fields.has_cookies.unwrap_or(false)); + assert!(normal_fields.referer.is_some()); + assert!(normal_fields.accept_language.is_some()); + } + + #[test] + fn test_empty_input() { + let results = parse_modsec_content("").unwrap(); + assert!(results.is_empty()); + } + + #[test] + fn test_single_entry_no_section_h() { + let content = r#"--abc123-A-- +[01/Jan/2026:12:00:00 +0000] TX1 1.2.3.4 1234 5.6.7.8 80 +--abc123-B-- +GET / HTTP/1.1 +Host: test.com +--abc123-F-- +HTTP/1.1 200 OK +--abc123-Z-- +"#; + let results = parse_modsec_content(content).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].1, "normal"); + } +} diff --git a/src/dataset/prepare.rs b/src/dataset/prepare.rs new file mode 100644 index 0000000..c90d1cd --- /dev/null +++ b/src/dataset/prepare.rs @@ -0,0 +1,598 @@ +//! Dataset preparation orchestrator. +//! +//! Combines production logs, external datasets (CSIC, OWASP ModSec), and +//! synthetic data (CIC-IDS2017 timing profiles + wordlists) into a single +//! `DatasetManifest` serialized as bincode. + +use crate::dataset::sample::{DataSource, DatasetManifest, DatasetStats, TrainingSample}; +use crate::ddos::audit_log::{AuditFields, AuditLog}; +use crate::ddos::features::{method_to_u8, LogIpState}; +use crate::ddos::train::HeuristicThresholds; +use crate::scanner::features::{self, fx_hash_bytes}; + +use anyhow::{Context, Result}; +use rustc_hash::FxHashSet; +use std::collections::HashMap; +use std::io::BufRead; +use std::path::Path; + +/// Arguments for the `prepare-dataset` command. +pub struct PrepareDatasetArgs { + /// Path to production audit log JSONL file. + pub input: String, + /// Path to OWASP ModSecurity audit log file (optional extra data). 
+ pub owasp: Option, + /// Path to wordlist directory (optional, enhances synthetic scanner). + pub wordlists: Option, + /// Output path for the bincode dataset file. + pub output: String, + /// RNG seed for synthetic generation. + pub seed: u64, + /// Path to heuristics.toml for auto-labeling production logs. + pub heuristics: Option, +} + +impl Default for PrepareDatasetArgs { + fn default() -> Self { + Self { + input: String::new(), + owasp: None, + wordlists: None, + output: "dataset.bin".to_string(), + seed: 42, + heuristics: None, + } + } +} + +pub fn run(args: PrepareDatasetArgs) -> Result<()> { + let mut scanner_samples: Vec = Vec::new(); + let mut ddos_samples: Vec = Vec::new(); + + // --- 0. Download upstream datasets if not cached --- + crate::dataset::download::download_all()?; + + // --- 1. Parse production logs --- + let heuristics = if let Some(h_path) = &args.heuristics { + let content = std::fs::read_to_string(h_path) + .with_context(|| format!("reading heuristics from {h_path}"))?; + toml::from_str::(&content) + .with_context(|| format!("parsing heuristics from {h_path}"))? + } else { + HeuristicThresholds::new(10.0, 0.85, 0.7, 0.3, 0.05, 20.0, 10) + }; + eprintln!("parsing production logs from {}...", args.input); + let (prod_scanner, prod_ddos) = parse_production_logs(&args.input, &heuristics)?; + eprintln!( + " production: {} scanner, {} ddos samples", + prod_scanner.len(), + prod_ddos.len() + ); + scanner_samples.extend(prod_scanner); + ddos_samples.extend(prod_ddos); + + // --- 2. CSIC 2010 (scanner) --- + eprintln!("fetching CSIC 2010 dataset..."); + let csic_entries = crate::scanner::csic::fetch_csic_dataset()?; + let csic_samples = entries_to_scanner_samples(&csic_entries, DataSource::Csic2010, 0.8)?; + eprintln!(" CSIC: {} scanner samples", csic_samples.len()); + scanner_samples.extend(csic_samples); + + // --- 3. 
OWASP ModSec (scanner) --- + if let Some(owasp_path) = &args.owasp { + eprintln!("parsing OWASP ModSec audit log from {owasp_path}..."); + let modsec_entries = + crate::dataset::modsec::parse_modsec_audit_log(Path::new(owasp_path))?; + let entries_with_host: Vec<(AuditFields, String)> = modsec_entries + .into_iter() + .map(|(fields, _label)| { + let host_prefix = fields.host.split('.').next().unwrap_or("").to_string(); + (fields, host_prefix) + }) + .collect(); + let modsec_samples = + entries_to_scanner_samples(&entries_with_host, DataSource::OwaspModSec, 0.8)?; + eprintln!(" OWASP: {} scanner samples", modsec_samples.len()); + scanner_samples.extend(modsec_samples); + } + + // --- 4. CIC-IDS2017 timing profiles (from cache if downloaded) --- + let cicids_profiles = if let Some(cached_path) = crate::dataset::download::cicids_cached_path() + { + eprintln!("extracting CIC-IDS2017 timing profiles from cache..."); + let profiles = crate::dataset::cicids::extract_timing_profiles(&cached_path)?; + eprintln!(" extracted {} attack-type profiles", profiles.len()); + profiles + } else { + eprintln!(" CIC-IDS2017 not cached; using built-in DDoS distributions"); + eprintln!(" (run `download-datasets` first for real timing profiles)"); + Vec::new() + }; + + // --- 5. Synthetic data (both models, always generated) --- + eprintln!("generating synthetic samples..."); + let config = crate::dataset::synthetic::SyntheticConfig { + num_ddos_attack: 10000, + num_ddos_normal: 10000, + num_scanner_attack: 5000, + num_scanner_normal: 5000, + seed: args.seed, + }; + + // Synthetic DDoS (uses CIC-IDS2017 profiles if cached, fallback defaults otherwise). + let synthetic_ddos = + crate::dataset::synthetic::generate_ddos_samples(&cicids_profiles, &config); + eprintln!(" synthetic DDoS: {} samples", synthetic_ddos.len()); + ddos_samples.extend(synthetic_ddos); + + // Synthetic scanner (uses wordlists if provided, built-in patterns otherwise). 
+ let synthetic_scanner = crate::dataset::synthetic::generate_scanner_samples( + args.wordlists.as_deref().map(Path::new), + None, + &config, + )?; + eprintln!(" synthetic scanner: {} samples", synthetic_scanner.len()); + scanner_samples.extend(synthetic_scanner); + + // --- 6. Compute stats --- + let mut samples_by_source: HashMap = HashMap::new(); + for s in scanner_samples.iter().chain(ddos_samples.iter()) { + *samples_by_source.entry(s.source.clone()).or_insert(0) += 1; + } + + let scanner_attack_count = scanner_samples.iter().filter(|s| s.label > 0.5).count(); + let ddos_attack_count = ddos_samples.iter().filter(|s| s.label > 0.5).count(); + + let stats = DatasetStats { + total_samples: scanner_samples.len() + ddos_samples.len(), + scanner_samples: scanner_samples.len(), + ddos_samples: ddos_samples.len(), + samples_by_source, + attack_ratio_scanner: if scanner_samples.is_empty() { + 0.0 + } else { + scanner_attack_count as f64 / scanner_samples.len() as f64 + }, + attack_ratio_ddos: if ddos_samples.is_empty() { + 0.0 + } else { + ddos_attack_count as f64 / ddos_samples.len() as f64 + }, + }; + + eprintln!("\n--- dataset stats ---"); + eprintln!("total samples: {}", stats.total_samples); + eprintln!( + "scanner: {} ({} attack, {:.1}% attack ratio)", + stats.scanner_samples, + scanner_attack_count, + stats.attack_ratio_scanner * 100.0 + ); + eprintln!( + "ddos: {} ({} attack, {:.1}% attack ratio)", + stats.ddos_samples, + ddos_attack_count, + stats.attack_ratio_ddos * 100.0 + ); + for (source, count) in &stats.samples_by_source { + eprintln!(" {source}: {count}"); + } + + let manifest = DatasetManifest { + scanner_samples, + ddos_samples, + stats, + }; + + crate::dataset::sample::save_dataset(&manifest, Path::new(&args.output))?; + eprintln!("\ndataset saved to {}", args.output); + + Ok(()) +} + +/// Parse production JSONL logs and produce both scanner and DDoS training samples. 
+/// +/// When logs lack explicit labels, heuristic auto-labeling is applied: +/// - Scanner: attack if status >= 400 AND path matches scanner patterns; normal if +/// status < 400 AND has browser indicators (cookies + referer + accept-language). +/// - DDoS: per-IP feature thresholds from `HeuristicThresholds` (same logic as `ddos/train.rs`). +fn parse_production_logs( + input: &str, + heuristics: &HeuristicThresholds, +) -> Result<(Vec, Vec)> { + let file = std::fs::File::open(input) + .with_context(|| format!("opening {input}"))?; + let reader = std::io::BufReader::new(file); + + let mut scanner_samples = Vec::new(); + let mut parsed_entries: Vec<(AuditFields, String)> = Vec::new(); + let mut log_hosts: FxHashSet = FxHashSet::default(); + + // Build hashes for scanner feature extraction. + let fragment_hashes: FxHashSet = crate::scanner::train::DEFAULT_FRAGMENTS + .iter() + .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) + .collect(); + let extension_hashes: FxHashSet = features::SUSPICIOUS_EXTENSIONS_LIST + .iter() + .map(|e| fx_hash_bytes(e.as_bytes())) + .collect(); + + for line in reader.lines() { + let line = line?; + if line.trim().is_empty() { + continue; + } + let entry: AuditLog = match serde_json::from_str(&line) { + Ok(e) => e, + Err(_) => continue, + }; + let host_prefix = entry + .fields + .host + .split('.') + .next() + .unwrap_or("") + .to_string(); + log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes())); + parsed_entries.push((entry.fields, host_prefix)); + } + + // --- Scanner samples from production logs --- + for (fields, host_prefix) in &parsed_entries { + let has_cookies = fields.has_cookies.unwrap_or(false); + let has_referer = fields + .referer + .as_ref() + .map(|r| r != "-" && !r.is_empty()) + .unwrap_or(false); + let has_accept_language = fields + .accept_language + .as_ref() + .map(|a| a != "-" && !a.is_empty()) + .unwrap_or(false); + + let feats = features::extract_features( + &fields.method, + &fields.path, + 
host_prefix, + has_cookies, + has_referer, + has_accept_language, + "-", + &fields.user_agent, + fields.content_length, + &fragment_hashes, + &extension_hashes, + &log_hosts, + ); + + // Use ground-truth label if present, otherwise heuristic auto-label. + let label = match fields.label.as_deref() { + Some("attack" | "anomalous") => Some(1.0f32), + Some("normal") => Some(0.0f32), + _ => { + // Heuristic scanner labeling: + // Attack: 404+ AND suspicious path (excluding .git which is valid on Gitea hosts). + // Normal: success status AND browser indicators (cookies + referer + accept-language). + // Note: 401 is excluded since it's expected for private repos. + let status = fields.status; + let path_lower = fields.path.to_ascii_lowercase(); + let is_suspicious_path = path_lower.contains(".env") + || path_lower.contains("wp-login") + || path_lower.contains("wp-admin") + || path_lower.contains("phpmyadmin") + || path_lower.contains("cgi-bin") + || path_lower.contains("phpinfo") + || path_lower.contains("/shell") + || path_lower.contains("..%2f") + || path_lower.contains("../"); + if status >= 404 && is_suspicious_path { + Some(1.0f32) + } else if status < 400 && has_cookies && has_referer && has_accept_language { + Some(0.0f32) + } else { + None // ambiguous — skip + } + } + }; + + if let Some(l) = label { + scanner_samples.push(TrainingSample { + features: feats.iter().map(|&v| v as f32).collect(), + label: l, + source: DataSource::ProductionLogs, + weight: 1.0, + }); + } + } + + // --- DDoS samples from production logs --- + let ddos_samples = extract_ddos_samples_from_entries(&parsed_entries, heuristics)?; + + Ok((scanner_samples, ddos_samples)) +} + +/// Extract DDoS feature vectors from parsed log entries using sliding windows. 
+fn extract_ddos_samples_from_entries( + entries: &[(AuditFields, String)], + heuristics: &HeuristicThresholds, +) -> Result> { + use rustc_hash::FxHashMap; + use std::hash::{Hash, Hasher}; + + fn fx_hash(s: &str) -> u64 { + let mut h = rustc_hash::FxHasher::default(); + s.hash(&mut h); + h.finish() + } + + let mut ip_states: FxHashMap = FxHashMap::default(); + let mut ip_labels: FxHashMap> = FxHashMap::default(); + + for (fields, _host_prefix) in entries { + let ip = crate::ddos::audit_log::strip_port(&fields.client_ip).to_string(); + + let state = ip_states.entry(ip.clone()).or_default(); + // Use a simple counter as "timestamp" since we don't have parsed timestamps here. + let ts = state.timestamps.len() as f64; + state.timestamps.push(ts); + state.methods.push(method_to_u8(&fields.method)); + state.path_hashes.push(fx_hash(&fields.path)); + state.host_hashes.push(fx_hash(&fields.host)); + state + .user_agent_hashes + .push(fx_hash(&fields.user_agent)); + state.statuses.push(fields.status); + state + .durations + .push(fields.duration_ms.min(u32::MAX as u64) as u32); + state + .content_lengths + .push(fields.content_length.min(u32::MAX as u64) as u32); + state + .has_cookies + .push(fields.has_cookies.unwrap_or(false)); + state.has_referer.push( + fields + .referer + .as_deref() + .map(|r| r != "-") + .unwrap_or(false), + ); + state.has_accept_language.push( + fields + .accept_language + .as_deref() + .map(|a| a != "-") + .unwrap_or(false), + ); + state.suspicious_paths.push( + crate::ddos::features::is_suspicious_path(&fields.path), + ); + + // Track label per IP if available. 
+ if let Some(ref label_str) = fields.label { + let label_val = match label_str.as_str() { + "attack" | "anomalous" => Some(1.0f32), + "normal" => Some(0.0f32), + _ => None, + }; + ip_labels.insert(ip, label_val); + } + } + + let min_events = heuristics.min_events.max(3); + let window_size = 50; // events per sliding window + let window_step = 25; // step size (50% overlap) + let window_secs = 60.0; + let mut samples = Vec::new(); + + for (ip, state) in &ip_states { + let n = state.timestamps.len(); + if n < min_events { + continue; + } + + // Ground-truth label if available. + let gt_label = match ip_labels.get(ip) { + Some(Some(l)) => Some(*l), + _ => None, + }; + + // Sliding windows: extract multiple feature vectors per IP. + // For IPs with fewer events than window_size, use a single window over all events. + let mut start = 0; + loop { + let end = (start + window_size).min(n); + if end - start < min_events { + break; + } + + let fv = state.extract_features_for_window(start, end, window_secs); + + let label = gt_label.unwrap_or_else(|| { + // Heuristic DDoS labeling (mirrors ddos/train.rs logic): + let is_attack = fv[0] > heuristics.request_rate + || fv[7] > heuristics.path_repetition + || fv[3] > heuristics.error_rate + || fv[13] > heuristics.suspicious_path_ratio + || (fv[10] < heuristics.no_cookies_threshold + && fv[1] > heuristics.no_cookies_path_count); + if is_attack { 1.0f32 } else { 0.0f32 } + }); + + samples.push(TrainingSample { + features: fv.iter().map(|&v| v as f32).collect(), + label, + source: DataSource::ProductionLogs, + weight: 1.0, + }); + + if end >= n { + break; + } + start += window_step; + } + } + + Ok(samples) +} + +/// Convert external dataset entries (CSIC, OWASP) into scanner TrainingSamples. 
+fn entries_to_scanner_samples( + entries: &[(AuditFields, String)], + source: DataSource, + weight: f32, +) -> Result> { + let fragment_hashes: FxHashSet = crate::scanner::train::DEFAULT_FRAGMENTS + .iter() + .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) + .collect(); + let extension_hashes: FxHashSet = features::SUSPICIOUS_EXTENSIONS_LIST + .iter() + .map(|e| fx_hash_bytes(e.as_bytes())) + .collect(); + + let mut log_hosts: FxHashSet = FxHashSet::default(); + for (_, host_prefix) in entries { + log_hosts.insert(fx_hash_bytes(host_prefix.as_bytes())); + } + + let mut samples = Vec::new(); + + for (fields, host_prefix) in entries { + let has_cookies = fields.has_cookies.unwrap_or(false); + let has_referer = fields + .referer + .as_ref() + .map(|r| r != "-" && !r.is_empty()) + .unwrap_or(false); + let has_accept_language = fields + .accept_language + .as_ref() + .map(|a| a != "-" && !a.is_empty()) + .unwrap_or(false); + + let feats = features::extract_features( + &fields.method, + &fields.path, + host_prefix, + has_cookies, + has_referer, + has_accept_language, + "-", + &fields.user_agent, + fields.content_length, + &fragment_hashes, + &extension_hashes, + &log_hosts, + ); + + let label = match fields.label.as_deref() { + Some("attack" | "anomalous") => 1.0f32, + Some("normal") => 0.0f32, + _ => continue, + }; + + samples.push(TrainingSample { + features: feats.iter().map(|&v| v as f32).collect(), + label, + source: source.clone(), + weight, + }); + } + + Ok(samples) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ddos::features::NUM_FEATURES; + use crate::scanner::features::NUM_SCANNER_FEATURES; + + #[test] + fn test_ddos_sample_feature_count() { + // Verify that DDoS samples from sliding windows produce 14 features. 
+ let entries = vec![ + make_test_entry("GET", "/", "1.2.3.4", 200, "normal"), + make_test_entry("GET", "/about", "1.2.3.4", 200, "normal"), + make_test_entry("GET", "/contact", "1.2.3.4", 200, "normal"), + make_test_entry("GET", "/blog", "1.2.3.4", 200, "normal"), + make_test_entry("GET", "/faq", "1.2.3.4", 200, "normal"), + ]; + let heuristics = HeuristicThresholds::new(10.0, 0.85, 0.7, 0.3, 0.05, 20.0, 5); + let samples = extract_ddos_samples_from_entries(&entries, &heuristics).unwrap(); + // With only 5 entries at min_events=5, we should get 1 sample. + assert!(!samples.is_empty(), "should produce DDoS samples"); + for s in &samples { + assert_eq!( + s.features.len(), + NUM_FEATURES, + "DDoS sample should have {NUM_FEATURES} features" + ); + } + } + + #[test] + fn test_scanner_sample_feature_count() { + let entries = vec![ + make_test_entry("GET", "/.env", "1.2.3.4", 404, "attack"), + make_test_entry("GET", "/index.html", "5.6.7.8", 200, "normal"), + ]; + let samples = + entries_to_scanner_samples(&entries, DataSource::ProductionLogs, 1.0).unwrap(); + assert_eq!(samples.len(), 2); + for s in &samples { + assert_eq!( + s.features.len(), + NUM_SCANNER_FEATURES, + "scanner sample should have {NUM_SCANNER_FEATURES} features" + ); + } + } + + #[test] + fn test_entries_to_scanner_samples_labels() { + let entries = vec![ + make_test_entry("GET", "/.env", "1.2.3.4", 404, "attack"), + make_test_entry("GET", "/", "5.6.7.8", 200, "normal"), + make_test_entry("GET", "/page", "9.10.11.12", 200, "unknown_label"), + ]; + let samples = + entries_to_scanner_samples(&entries, DataSource::Csic2010, 0.8).unwrap(); + // "unknown_label" should be skipped. 
+ assert_eq!(samples.len(), 2); + assert_eq!(samples[0].label, 1.0); + assert_eq!(samples[1].label, 0.0); + assert_eq!(samples[0].weight, 0.8); + assert_eq!(samples[0].source, DataSource::Csic2010); + } + + fn make_test_entry( + method: &str, + path: &str, + client_ip: &str, + status: u16, + label: &str, + ) -> (AuditFields, String) { + let fields = AuditFields { + method: method.to_string(), + host: "test.sunbeam.pt".to_string(), + path: path.to_string(), + query: String::new(), + client_ip: client_ip.to_string(), + status, + duration_ms: 10, + content_length: 0, + user_agent: "Mozilla/5.0".to_string(), + has_cookies: Some(true), + referer: Some("https://test.sunbeam.pt".to_string()), + accept_language: Some("en-US".to_string()), + backend: "test-svc:8080".to_string(), + label: Some(label.to_string()), + }; + (fields, "test".to_string()) + } +} diff --git a/src/dataset/sample.rs b/src/dataset/sample.rs new file mode 100644 index 0000000..844ad47 --- /dev/null +++ b/src/dataset/sample.rs @@ -0,0 +1,156 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; + +use anyhow::{Context, Result}; + +/// Provenance of a training sample. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum DataSource { + ProductionLogs, + Csic2010, + OwaspModSec, + SyntheticCicTiming, + SyntheticWordlist, +} + +impl std::fmt::Display for DataSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DataSource::ProductionLogs => write!(f, "ProductionLogs"), + DataSource::Csic2010 => write!(f, "Csic2010"), + DataSource::OwaspModSec => write!(f, "OwaspModSec"), + DataSource::SyntheticCicTiming => write!(f, "SyntheticCicTiming"), + DataSource::SyntheticWordlist => write!(f, "SyntheticWordlist"), + } + } +} + +/// A single labeled training sample with its feature vector, label, provenance, +/// and weight (used during training to down-weight synthetic/external data). 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TrainingSample { + pub features: Vec, + /// 0.0 = normal, 1.0 = attack + pub label: f32, + pub source: DataSource, + /// Sample weight: 1.0 for production, 0.8 for external datasets, 0.5 for synthetic. + pub weight: f32, +} + +/// Aggregate statistics about a prepared dataset. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DatasetStats { + pub total_samples: usize, + pub scanner_samples: usize, + pub ddos_samples: usize, + pub samples_by_source: HashMap, + pub attack_ratio_scanner: f64, + pub attack_ratio_ddos: f64, +} + +/// The full serializable dataset: scanner samples, DDoS samples, and stats. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DatasetManifest { + pub scanner_samples: Vec, + pub ddos_samples: Vec, + pub stats: DatasetStats, +} + +/// Serialize a `DatasetManifest` to a bincode file. +pub fn save_dataset(manifest: &DatasetManifest, path: &Path) -> Result<()> { + let encoded = bincode::serialize(manifest).context("serializing dataset manifest")?; + std::fs::write(path, &encoded) + .with_context(|| format!("writing dataset to {}", path.display()))?; + Ok(()) +} + +/// Deserialize a `DatasetManifest` from a bincode file. 
+pub fn load_dataset(path: &Path) -> Result { + let data = std::fs::read(path) + .with_context(|| format!("reading dataset from {}", path.display()))?; + let manifest: DatasetManifest = + bincode::deserialize(&data).context("deserializing dataset manifest")?; + Ok(manifest) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_sample(features: Vec, label: f32, source: DataSource, weight: f32) -> TrainingSample { + TrainingSample { features, label, source, weight } + } + + #[test] + fn test_bincode_roundtrip() { + let manifest = DatasetManifest { + scanner_samples: vec![ + make_sample(vec![0.1, 0.2, 0.3], 0.0, DataSource::ProductionLogs, 1.0), + make_sample(vec![0.9, 0.8, 0.7], 1.0, DataSource::Csic2010, 0.8), + ], + ddos_samples: vec![ + make_sample(vec![0.5; 14], 1.0, DataSource::SyntheticCicTiming, 0.5), + make_sample(vec![0.1; 14], 0.0, DataSource::ProductionLogs, 1.0), + ], + stats: DatasetStats { + total_samples: 4, + scanner_samples: 2, + ddos_samples: 2, + samples_by_source: { + let mut m = HashMap::new(); + m.insert(DataSource::ProductionLogs, 2); + m.insert(DataSource::Csic2010, 1); + m.insert(DataSource::SyntheticCicTiming, 1); + m + }, + attack_ratio_scanner: 0.5, + attack_ratio_ddos: 0.5, + }, + }; + + let encoded = bincode::serialize(&manifest).unwrap(); + let decoded: DatasetManifest = bincode::deserialize(&encoded).unwrap(); + + assert_eq!(decoded.scanner_samples.len(), 2); + assert_eq!(decoded.ddos_samples.len(), 2); + assert_eq!(decoded.stats.total_samples, 4); + assert_eq!(decoded.stats.samples_by_source[&DataSource::ProductionLogs], 2); + + // Verify feature values survive the roundtrip. 
+ assert!((decoded.scanner_samples[0].features[0] - 0.1).abs() < 1e-6); + assert_eq!(decoded.scanner_samples[1].label, 1.0); + assert_eq!(decoded.ddos_samples[0].source, DataSource::SyntheticCicTiming); + } + + #[test] + fn test_save_load_roundtrip() { + let manifest = DatasetManifest { + scanner_samples: vec![ + make_sample(vec![1.0, 2.0], 0.0, DataSource::ProductionLogs, 1.0), + ], + ddos_samples: vec![ + make_sample(vec![3.0, 4.0], 1.0, DataSource::OwaspModSec, 0.8), + ], + stats: DatasetStats { + total_samples: 2, + scanner_samples: 1, + ddos_samples: 1, + samples_by_source: HashMap::new(), + attack_ratio_scanner: 0.0, + attack_ratio_ddos: 1.0, + }, + }; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test_dataset.bin"); + + save_dataset(&manifest, &path).unwrap(); + let loaded = load_dataset(&path).unwrap(); + + assert_eq!(loaded.scanner_samples.len(), 1); + assert_eq!(loaded.ddos_samples.len(), 1); + assert!((loaded.scanner_samples[0].features[0] - 1.0).abs() < 1e-6); + assert_eq!(loaded.ddos_samples[0].source, DataSource::OwaspModSec); + } +} diff --git a/src/dataset/synthetic.rs b/src/dataset/synthetic.rs new file mode 100644 index 0000000..defdf50 --- /dev/null +++ b/src/dataset/synthetic.rs @@ -0,0 +1,558 @@ +//! Synthetic data generation for DDoS and scanner training samples. +//! +//! DDoS synthetic samples are generated by sampling from CIC-IDS2017 timing +//! profiles. Scanner synthetic samples are generated from wordlist paths +//! and realistic browser-like feature vectors. + +use crate::dataset::cicids::TimingProfile; +use crate::dataset::sample::{DataSource, TrainingSample}; +use crate::ddos::features::NUM_FEATURES; +use crate::scanner::features::NUM_SCANNER_FEATURES; + +use anyhow::Result; +use rand::prelude::*; +use rand::rngs::StdRng; +use std::io::BufRead; +use std::path::Path; + +/// Configuration for synthetic sample generation. 
+#[derive(Debug, Clone)] +pub struct SyntheticConfig { + pub num_ddos_attack: usize, + pub num_ddos_normal: usize, + pub num_scanner_attack: usize, + pub num_scanner_normal: usize, + pub seed: u64, +} + +impl Default for SyntheticConfig { + fn default() -> Self { + Self { + num_ddos_attack: 1000, + num_ddos_normal: 1000, + num_scanner_attack: 1000, + num_scanner_normal: 1000, + seed: 42, + } + } +} + +/// Generate synthetic DDoS training samples from CIC-IDS2017 timing profiles. +/// +/// Produces 14-dimensional feature vectors matching `NUM_FEATURES`: +/// 0: request_rate, 1: unique_paths, 2: unique_hosts, 3: error_rate, +/// 4: avg_duration_ms, 5: method_entropy, 6: burst_score, 7: path_repetition, +/// 8: avg_content_length, 9: unique_user_agents, 10: cookie_ratio, +/// 11: referer_ratio, 12: accept_language_ratio, 13: suspicious_path_ratio +pub fn generate_ddos_samples( + profiles: &[TimingProfile], + config: &SyntheticConfig, +) -> Vec { + let mut rng = StdRng::seed_from_u64(config.seed); + let mut samples = Vec::with_capacity(config.num_ddos_attack + config.num_ddos_normal); + + // Generate attack samples. + for _ in 0..config.num_ddos_attack { + let features = generate_ddos_attack_features(profiles, &mut rng); + samples.push(TrainingSample { + features, + label: 1.0, + source: DataSource::SyntheticCicTiming, + weight: 0.5, + }); + } + + // Generate normal samples. + for _ in 0..config.num_ddos_normal { + let features = generate_ddos_normal_features(&mut rng); + samples.push(TrainingSample { + features, + label: 0.0, + source: DataSource::SyntheticCicTiming, + weight: 0.5, + }); + } + + samples +} + +/// Generate a single DDoS attack feature vector, sampling from timing profiles. +fn generate_ddos_attack_features(profiles: &[TimingProfile], rng: &mut StdRng) -> Vec { + let mut features = vec![0.0f32; NUM_FEATURES]; + + // Sample from a random profile if available, otherwise use defaults. 
+ let (iat_mean, burst_mean, bps_mean) = if !profiles.is_empty() { + let profile = &profiles[rng.random_range(0..profiles.len())]; + let iat = sample_positive_normal( + rng, + profile.inter_arrival_mean, + profile.inter_arrival_std, + ); + let burst = sample_positive_normal( + rng, + profile.burst_duration_mean, + profile.burst_duration_std, + ); + let bps = sample_positive_normal( + rng, + profile.flow_bytes_per_sec_mean, + profile.flow_bytes_per_sec_std, + ); + (iat, burst, bps) + } else { + // Fallback: aggressive attack defaults. + let iat = rng.random_range(0.001..0.05); + let burst = rng.random_range(0.5..5.0); + let bps = rng.random_range(10000.0..100000.0); + (iat, burst, bps) + }; + + // 0: request_rate — high for attacks. + features[0] = if iat_mean > 0.0 { + (1.0 / iat_mean).min(1000.0) as f32 + } else { + rng.random_range(50.0..500.0) as f32 + }; + + // 1: unique_paths — low to moderate (attackers repeat paths). + features[1] = rng.random_range(1.0..10.0) as f32; + + // 2: unique_hosts — typically 1 for DDoS. + features[2] = rng.random_range(1.0..3.0) as f32; + + // 3: error_rate — moderate to high. + features[3] = rng.random_range(0.3..0.9) as f32; + + // 4: avg_duration_ms — derived from burst duration. + features[4] = (burst_mean * 100.0).max(1.0).min(5000.0) as f32; + + // 5: method_entropy — low (mostly GET). + features[5] = rng.random_range(0.0..0.3) as f32; + + // 6: burst_score — high (inverse of inter-arrival). + features[6] = if iat_mean > 0.0 { + (1.0 / iat_mean).min(500.0) as f32 + } else { + rng.random_range(10.0..200.0) as f32 + }; + + // 7: path_repetition — high (attackers repeat same paths). + features[7] = rng.random_range(0.6..1.0) as f32; + + // 8: avg_content_length — derived from flow bytes. + features[8] = (bps_mean * 0.01).max(0.0).min(10000.0) as f32; + + // 9: unique_user_agents — low (1-2 UAs). + features[9] = rng.random_range(1.0..3.0) as f32; + + // 10: cookie_ratio — very low (bots don't send cookies). 
+ features[10] = rng.random_range(0.0..0.1) as f32; + + // 11: referer_ratio — very low. + features[11] = rng.random_range(0.0..0.1) as f32; + + // 12: accept_language_ratio — very low. + features[12] = rng.random_range(0.0..0.1) as f32; + + // 13: suspicious_path_ratio — moderate to high. + features[13] = rng.random_range(0.1..0.7) as f32; + + features +} + +/// Generate a single DDoS normal feature vector. +fn generate_ddos_normal_features(rng: &mut StdRng) -> Vec { + let mut features = vec![0.0f32; NUM_FEATURES]; + + // 0: request_rate — moderate. + features[0] = rng.random_range(0.1..5.0) as f32; + // 1: unique_paths — moderate. + features[1] = rng.random_range(3.0..30.0) as f32; + // 2: unique_hosts — typically 1-3. + features[2] = rng.random_range(1.0..5.0) as f32; + // 3: error_rate — low. + features[3] = rng.random_range(0.0..0.15) as f32; + // 4: avg_duration_ms — reasonable. + features[4] = rng.random_range(10.0..500.0) as f32; + // 5: method_entropy — moderate (mix of GET/POST). + features[5] = rng.random_range(0.0..1.5) as f32; + // 6: burst_score — low. + features[6] = rng.random_range(0.05..2.0) as f32; + // 7: path_repetition — low to moderate. + features[7] = rng.random_range(0.05..0.5) as f32; + // 8: avg_content_length — reasonable. + features[8] = rng.random_range(0.0..2000.0) as f32; + // 9: unique_user_agents — 1-3 (real users have few UAs). + features[9] = rng.random_range(1.0..4.0) as f32; + // 10: cookie_ratio — high (real users have cookies). + features[10] = rng.random_range(0.7..1.0) as f32; + // 11: referer_ratio — moderate to high. + features[11] = rng.random_range(0.4..1.0) as f32; + // 12: accept_language_ratio — high. + features[12] = rng.random_range(0.7..1.0) as f32; + // 13: suspicious_path_ratio — very low. + features[13] = rng.random_range(0.0..0.05) as f32; + + features +} + +/// Generate synthetic scanner training samples. +/// +/// Attack samples use paths from wordlist files (if provided) or built-in +/// scanner paths. 
Normal samples have realistic browser-like feature vectors. +/// +/// Produces 12-dimensional feature vectors matching `NUM_SCANNER_FEATURES`: +/// 0: suspicious_path_score, 1: path_depth, 2: has_suspicious_extension, +/// 3: has_cookies, 4: has_referer, 5: has_accept_language, +/// 6: accept_quality, 7: ua_category, 8: method_is_unusual, +/// 9: host_is_configured, 10: content_length_mismatch, 11: path_has_traversal +pub fn generate_scanner_samples( + wordlist_dir: Option<&Path>, + _production_logs: Option<&Path>, + config: &SyntheticConfig, +) -> Result> { + let mut rng = StdRng::seed_from_u64(config.seed.wrapping_add(1)); + let mut samples = Vec::with_capacity(config.num_scanner_attack + config.num_scanner_normal); + + // Load wordlist paths if provided. + let attack_paths = load_wordlist_paths(wordlist_dir)?; + + // Generate attack samples. + for i in 0..config.num_scanner_attack { + let features = if !attack_paths.is_empty() { + let path = &attack_paths[i % attack_paths.len()]; + generate_scanner_attack_features_from_path(path, &mut rng) + } else { + generate_scanner_attack_features(&mut rng) + }; + samples.push(TrainingSample { + features, + label: 1.0, + source: DataSource::SyntheticWordlist, + weight: 0.5, + }); + } + + // Generate normal samples. + for _ in 0..config.num_scanner_normal { + let features = generate_scanner_normal_features(&mut rng); + samples.push(TrainingSample { + features, + label: 0.0, + source: DataSource::SyntheticWordlist, + weight: 0.5, + }); + } + + Ok(samples) +} + +/// Load paths from wordlist files in a directory (or single file). +fn load_wordlist_paths(dir: Option<&Path>) -> Result> { + let dir = match dir { + Some(d) => d, + None => return Ok(Vec::new()), + }; + + let files: Vec = if dir.is_file() { + vec![dir.to_path_buf()] + } else if dir.exists() { + std::fs::read_dir(dir)? 
+ .filter_map(|e| e.ok()) + .map(|e| e.path()) + .filter(|p| p.extension().map(|e| e == "txt").unwrap_or(false)) + .collect() + } else { + return Ok(Vec::new()); + }; + + let mut paths = Vec::new(); + for file in &files { + let f = std::fs::File::open(file)?; + let reader = std::io::BufReader::new(f); + for line in reader.lines() { + let line = line?; + let line = line.trim().to_string(); + if line.is_empty() || line.starts_with('#') { + continue; + } + let normalized = if line.starts_with('/') { + line + } else { + format!("/{line}") + }; + paths.push(normalized); + } + } + Ok(paths) +} + +/// Generate scanner attack features from a specific path. +fn generate_scanner_attack_features_from_path(path: &str, rng: &mut StdRng) -> Vec { + let mut features = vec![0.0f32; NUM_SCANNER_FEATURES]; + + let lower = path.to_ascii_lowercase(); + + // 0: suspicious_path_score — check for known bad fragments. + let suspicious_frags = [ + ".env", "wp-admin", "wp-login", "phpinfo", "phpmyadmin", + ".git", "cgi-bin", "shell", "admin", "config", + ]; + let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect(); + let seg_count = segments.len().max(1) as f32; + let matches = segments + .iter() + .filter(|s| { + let sl = s.to_ascii_lowercase(); + suspicious_frags.iter().any(|f| sl.contains(f)) + }) + .count() as f32; + features[0] = matches / seg_count; + + // 1: path_depth. + features[1] = path.bytes().filter(|&b| b == b'/').count().min(20) as f32; + + // 2: has_suspicious_extension. + let suspicious_exts = [".php", ".env", ".sql", ".bak", ".asp", ".jsp", ".cgi", ".tar", ".zip", ".git"]; + features[2] = if suspicious_exts.iter().any(|e| lower.ends_with(e)) { + 1.0 + } else { + 0.0 + }; + + // 3: has_cookies — scanners typically don't. + features[3] = if rng.random_bool(0.05) { 1.0 } else { 0.0 }; + + // 4: has_referer — scanners rarely send referer. 
+ features[4] = if rng.random_bool(0.05) { 1.0 } else { 0.0 }; + + // 5: has_accept_language — scanners rarely send this. + features[5] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; + + // 6: accept_quality — scanners use */* or empty. + features[6] = 0.0; + + // 7: ua_category — mix of curl/empty/bot UAs. + let ua_roll: f64 = rng.random(); + features[7] = if ua_roll < 0.3 { + 0.0 // empty + } else if ua_roll < 0.6 { + 0.25 // curl/wget + } else { + 0.5 // random bot + }; + + // 8: method_is_unusual — mostly GET. + features[8] = if rng.random_bool(0.05) { 1.0 } else { 0.0 }; + + // 9: host_is_configured — often unknown host. + features[9] = if rng.random_bool(0.2) { 1.0 } else { 0.0 }; + + // 10: content_length_mismatch. + features[10] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; + + // 11: path_has_traversal. + let traversal_patterns = ["..", "%00", "%0a", "%27", "%3c"]; + features[11] = if traversal_patterns.iter().any(|p| lower.contains(p)) { + 1.0 + } else { + 0.0 + }; + + features +} + +/// Generate scanner attack features without a specific path (built-in defaults). 
+fn generate_scanner_attack_features(rng: &mut StdRng) -> Vec { + let mut features = vec![0.0f32; NUM_SCANNER_FEATURES]; + + features[0] = rng.random_range(0.3..1.0) as f32; // suspicious_path_score + features[1] = rng.random_range(2.0..8.0) as f32; // path_depth + features[2] = if rng.random_bool(0.6) { 1.0 } else { 0.0 }; // suspicious_ext + features[3] = if rng.random_bool(0.05) { 1.0 } else { 0.0 }; // cookies + features[4] = if rng.random_bool(0.05) { 1.0 } else { 0.0 }; // referer + features[5] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; // accept_language + features[6] = 0.0; // accept_quality + features[7] = rng.random_range(0.0..0.5) as f32; // ua_category + features[8] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; // unusual method + features[9] = if rng.random_bool(0.2) { 1.0 } else { 0.0 }; // host_configured + features[10] = if rng.random_bool(0.1) { 1.0 } else { 0.0 }; // content_len mismatch + features[11] = if rng.random_bool(0.15) { 1.0 } else { 0.0 }; // traversal + + features +} + +/// Generate normal scanner features (browser-like). 
fn generate_scanner_normal_features(rng: &mut StdRng) -> Vec<f32> {
    let mut features = vec![0.0f32; NUM_SCANNER_FEATURES];

    features[0] = 0.0; // suspicious_path_score — clean path
    features[1] = rng.random_range(1.0..4.0) as f32; // path_depth — moderate
    features[2] = 0.0; // no suspicious extension
    features[3] = if rng.random_bool(0.85) { 1.0 } else { 0.0 }; // has_cookies — usually yes
    features[4] = if rng.random_bool(0.7) { 1.0 } else { 0.0 }; // has_referer — often
    features[5] = if rng.random_bool(0.9) { 1.0 } else { 0.0 }; // has_accept_language
    features[6] = 1.0; // accept_quality — browsers send proper accept
    features[7] = 1.0; // ua_category — browser UA
    features[8] = 0.0; // method_is_unusual — GET/POST
    features[9] = if rng.random_bool(0.95) { 1.0 } else { 0.0 }; // host_is_configured
    features[10] = 0.0; // no content_length_mismatch
    features[11] = 0.0; // no traversal

    features
}

/// Sample from a normal distribution, clamped to positive values.
fn sample_positive_normal(rng: &mut StdRng, mean: f64, std_dev: f64) -> f64 {
    // Degenerate distribution: return the (non-negative) mean directly.
    if std_dev <= 0.0 {
        return mean.max(0.0);
    }
    // Box-Muller transform.
    let u1: f64 = rng.random::<f64>().max(1e-10); // clamp to avoid ln(0)
    let u2: f64 = rng.random();
    let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
    (mean + std_dev * z).max(0.0)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ddos_feature_dimensions() {
        let profiles = vec![TimingProfile {
            attack_type: "DDoS".to_string(),
            inter_arrival_mean: 0.01,
            inter_arrival_std: 0.005,
            burst_duration_mean: 2.0,
            burst_duration_std: 1.0,
            flow_bytes_per_sec_mean: 50000.0,
            flow_bytes_per_sec_std: 20000.0,
            sample_count: 100,
        }];

        let config = SyntheticConfig {
            num_ddos_attack: 10,
            num_ddos_normal: 10,
            seed: 42,
            ..Default::default()
        };

        let samples = generate_ddos_samples(&profiles, &config);
        assert_eq!(samples.len(), 20);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_FEATURES,
                "DDoS features should have {} dimensions",
                NUM_FEATURES
            );
        }
    }

    #[test]
    fn test_scanner_feature_dimensions() {
        let config = SyntheticConfig {
            num_scanner_attack: 10,
            num_scanner_normal: 10,
            seed: 42,
            ..Default::default()
        };

        let samples = generate_scanner_samples(None, None, &config).unwrap();
        assert_eq!(samples.len(), 20);
        for s in &samples {
            assert_eq!(
                s.features.len(),
                NUM_SCANNER_FEATURES,
                "scanner features should have {} dimensions",
                NUM_SCANNER_FEATURES
            );
        }
    }

    #[test]
    fn test_deterministic_with_fixed_seed() {
        let config = SyntheticConfig {
            num_ddos_attack: 5,
            num_ddos_normal: 5,
            num_scanner_attack: 5,
            num_scanner_normal: 5,
            seed: 123,
        };

        let ddos_a = generate_ddos_samples(&[], &config);
        let ddos_b = generate_ddos_samples(&[], &config);

        for (a, b) in ddos_a.iter().zip(ddos_b.iter()) {
            assert_eq!(a.features, b.features, "DDoS samples should be deterministic");
            assert_eq!(a.label, b.label);
        }

        let scanner_a = generate_scanner_samples(None, None, &config).unwrap();
        let scanner_b = generate_scanner_samples(None, None, &config).unwrap();

for (a, b) in scanner_a.iter().zip(scanner_b.iter()) { + assert_eq!(a.features, b.features, "scanner samples should be deterministic"); + assert_eq!(a.label, b.label); + } + } + + #[test] + fn test_ddos_attack_vs_normal_distinguishable() { + let config = SyntheticConfig { + num_ddos_attack: 100, + num_ddos_normal: 100, + seed: 42, + ..Default::default() + }; + + let samples = generate_ddos_samples(&[], &config); + let attacks: Vec<&TrainingSample> = samples.iter().filter(|s| s.label > 0.5).collect(); + let normals: Vec<&TrainingSample> = samples.iter().filter(|s| s.label <= 0.5).collect(); + + // Attack samples should have higher avg request_rate (feature 0). + let avg_attack_rate: f32 = + attacks.iter().map(|s| s.features[0]).sum::() / attacks.len() as f32; + let avg_normal_rate: f32 = + normals.iter().map(|s| s.features[0]).sum::() / normals.len() as f32; + assert!( + avg_attack_rate > avg_normal_rate, + "attack rate {} should exceed normal rate {}", + avg_attack_rate, + avg_normal_rate + ); + + // Attack samples should have lower avg cookie_ratio (feature 10). + let avg_attack_cookies: f32 = + attacks.iter().map(|s| s.features[10]).sum::() / attacks.len() as f32; + let avg_normal_cookies: f32 = + normals.iter().map(|s| s.features[10]).sum::() / normals.len() as f32; + assert!( + avg_attack_cookies < avg_normal_cookies, + "attack cookies {} should be lower than normal cookies {}", + avg_attack_cookies, + avg_normal_cookies + ); + } + + #[test] + fn test_sample_positive_normal_stays_positive() { + let mut rng = StdRng::seed_from_u64(42); + for _ in 0..1000 { + let val = sample_positive_normal(&mut rng, 1.0, 5.0); + assert!(val >= 0.0, "value should be non-negative: {val}"); + } + } + + #[test] + fn test_scanner_attack_from_wordlist_path() { + let mut rng = StdRng::seed_from_u64(42); + let features = generate_scanner_attack_features_from_path("/.env", &mut rng); + assert_eq!(features.len(), NUM_SCANNER_FEATURES); + // .env should trigger suspicious_path_score > 0. 
+ assert!(features[0] > 0.0, "suspicious_path_score should be positive for /.env"); + // .env should trigger suspicious_extension. + assert_eq!(features[2], 1.0, "should detect .env as suspicious extension"); + } +}