diff --git a/src/scanner/detector.rs b/src/scanner/detector.rs new file mode 100644 index 0000000..2704ce3 --- /dev/null +++ b/src/scanner/detector.rs @@ -0,0 +1,335 @@ +use crate::config::RouteConfig; +use crate::scanner::features::{ + self, fx_hash_bytes, ScannerNormParams, SUSPICIOUS_EXTENSIONS_LIST, NUM_SCANNER_FEATURES, + NUM_SCANNER_WEIGHTS, +}; +use crate::scanner::model::{ScannerAction, ScannerModel, ScannerVerdict}; +use rustc_hash::FxHashSet; + +/// Immutable, zero-state per-request scanner detector. +/// Safe to share across threads via `Arc` with no locks. +pub struct ScannerDetector { + fragment_hashes: FxHashSet, + extension_hashes: FxHashSet, + configured_hosts: FxHashSet, + weights: [f64; NUM_SCANNER_WEIGHTS], + threshold: f64, + norm_params: ScannerNormParams, +} + +impl ScannerDetector { + pub fn new(model: &ScannerModel, routes: &[RouteConfig]) -> Self { + let fragment_hashes: FxHashSet = model + .fragments + .iter() + .map(|f| fx_hash_bytes(f.to_ascii_lowercase().as_bytes())) + .collect(); + + let extension_hashes: FxHashSet = SUSPICIOUS_EXTENSIONS_LIST + .iter() + .map(|e| fx_hash_bytes(e.as_bytes())) + .collect(); + + let configured_hosts: FxHashSet = routes + .iter() + .map(|r| fx_hash_bytes(r.host_prefix.as_bytes())) + .collect(); + + Self { + fragment_hashes, + extension_hashes, + configured_hosts, + weights: model.weights, + threshold: model.threshold, + norm_params: model.norm_params.clone(), + } + } + + /// Classify a single request. ~200ns, no heap allocation, no state mutation. + /// + /// Returns a verdict with the action, raw score, and reason. + /// The score and reason are captured in pipeline logs so the training + /// pipeline always has unfiltered data to retrain from. + pub fn check( + &self, + method: &str, + path: &str, + host_prefix: &str, + has_cookies: bool, + has_referer: bool, + has_accept_language: bool, + accept: &str, + user_agent: &str, + content_length: u64, + ) -> ScannerVerdict { + // Hard allowlist: obviously legitimate traffic bypasses the model. + // This prevents model drift from ever blocking real users and ensures + // the training pipeline always has clean positive labels. + let host_known = { + let hash = features::fx_hash_bytes(host_prefix.as_bytes()); + self.configured_hosts.contains(&hash) + }; + + if host_known && has_cookies { + return ScannerVerdict { + action: ScannerAction::Allow, + score: -1.0, + reason: "allowlist:host+cookies", + }; + } + + if host_known && has_accept_language && features::ua_is_browser(user_agent) { + return ScannerVerdict { + action: ScannerAction::Allow, + score: -1.0, + reason: "allowlist:host+browser", + }; + } + + // 1. Extract 12 features + let raw = features::extract_features( + method, + path, + host_prefix, + has_cookies, + has_referer, + has_accept_language, + accept, + user_agent, + content_length, + &self.fragment_hashes, + &self.extension_hashes, + &self.configured_hosts, + ); + + // 2. Normalize + let f = self.norm_params.normalize(&raw); + + // 3. Compute score = bias + dot(weights, features) + interaction terms + let mut score = self.weights[NUM_SCANNER_FEATURES + 2]; // bias (index 14) + for i in 0..NUM_SCANNER_FEATURES { + score += self.weights[i] * f[i]; + } + // Interaction: suspicious_path AND no_cookies + score += self.weights[12] * f[0] * (1.0 - f[3]); + // Interaction: unknown_host AND no_accept_language + score += self.weights[13] * (1.0 - f[9]) * (1.0 - f[5]); + + // 4. Threshold + let action = if score > self.threshold { + ScannerAction::Block + } else { + ScannerAction::Allow + }; + + ScannerVerdict { + action, + score, + reason: "model", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::scanner::features::NUM_SCANNER_FEATURES; + + fn make_detector(weights: [f64; NUM_SCANNER_WEIGHTS], threshold: f64) -> ScannerDetector { + let model = ScannerModel { + weights, + threshold, + norm_params: ScannerNormParams { + mins: [0.0; NUM_SCANNER_FEATURES], + maxs: [1.0; NUM_SCANNER_FEATURES], + }, + fragments: vec![ + ".env".into(), + "wp-admin".into(), + "wp-login".into(), + "phpinfo".into(), + "phpmyadmin".into(), + ".git".into(), + "cgi-bin".into(), + ".htaccess".into(), + ".htpasswd".into(), + ], + }; + let routes = vec![RouteConfig { + host_prefix: "app".into(), + backend: "http://127.0.0.1:8080".into(), + websocket: false, + disable_secure_redirection: false, + paths: vec![], + }]; + ScannerDetector::new(&model, &routes) + } + + /// Weights tuned to block scanner-like requests: + /// High weight on suspicious_path (w[0]), no_cookies interaction (w[12]), + /// has_suspicious_extension (w[2]), traversal (w[11]). + /// Negative weight on has_cookies (w[3]), has_referer (w[4]), + /// accept_quality (w[6]), ua_category (w[7]), host_is_configured (w[9]). + fn attack_tuned_weights() -> [f64; NUM_SCANNER_WEIGHTS] { + let mut w = [0.0; NUM_SCANNER_WEIGHTS]; + w[0] = 2.0; // suspicious_path_score + w[2] = 2.0; // has_suspicious_extension + w[3] = -2.0; // has_cookies (negative = good) + w[4] = -1.0; // has_referer (negative = good) + w[5] = -1.0; // has_accept_language (negative = good) + w[6] = -0.5; // accept_quality (negative = good) + w[7] = -1.0; // ua_category (negative = browser is good) + w[9] = -1.5; // host_is_configured (negative = known host is good) + w[11] = 2.0; // path_has_traversal + w[12] = 1.5; // interaction: suspicious_path AND no_cookies + w[13] = 1.0; // interaction: unknown_host AND no_accept_lang + w[14] = 0.5; // bias + w + } + + #[test] + fn test_normal_browser_request_allowed() { + let detector = make_detector(attack_tuned_weights(), 0.5); + let verdict = detector.check( + "GET", + "/blog/hello-world", + "app", + true, // has_cookies + true, // has_referer + true, // has_accept_language + "text/html,application/xhtml+xml", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120", + 0, + ); + assert_eq!(verdict.action, ScannerAction::Allow); + assert_eq!(verdict.reason, "allowlist:host+cookies"); + } + + #[test] + fn test_api_client_with_auth_allowed() { + let detector = make_detector(attack_tuned_weights(), 0.5); + let verdict = detector.check( + "POST", + "/api/v1/data", + "app", + true, // has_cookies (session cookie) + false, + true, + "application/json", + "MyApp/2.0", + 256, + ); + assert_eq!(verdict.action, ScannerAction::Allow); + assert_eq!(verdict.reason, "allowlist:host+cookies"); + } + + #[test] + fn test_env_probe_blocked() { + let detector = make_detector(attack_tuned_weights(), 0.5); + let verdict = detector.check( + "GET", + "/.env", + "unknown", + false, // no cookies + false, // no referer + false, // no accept-language + "*/*", + "curl/7.0", + 0, + ); + assert_eq!(verdict.action, ScannerAction::Block); + assert_eq!(verdict.reason, "model"); + } + + #[test] + fn test_wordpress_scan_blocked() { + let detector = make_detector(attack_tuned_weights(), 0.5); + let verdict = detector.check( + "GET", + "/wp-admin/install.php", + "unknown", + false, + false, + false, + "*/*", + "", + 0, + ); + assert_eq!(verdict.action, ScannerAction::Block); + assert_eq!(verdict.reason, "model"); + } + + #[test] + fn test_path_traversal_blocked() { + let detector = make_detector(attack_tuned_weights(), 0.5); + let verdict = detector.check( + "GET", + "/etc/../../../passwd", + "unknown", + false, + false, + false, + "*/*", + "python-requests/2.28", + 0, + ); + assert_eq!(verdict.action, ScannerAction::Block); + assert_eq!(verdict.reason, "model"); + } + + #[test] + fn test_legitimate_php_path_allowed() { + let detector = make_detector(attack_tuned_weights(), 0.5); + // "/blog/php-is-dead" — "php-is-dead" is not a known fragment + // has_cookies=true + known host "app" → hits allowlist + let verdict = detector.check( + "GET", + "/blog/php-is-dead", + "app", + true, + true, + true, + "text/html", + "Mozilla/5.0 Chrome/120", + 0, + ); + assert_eq!(verdict.action, ScannerAction::Allow); + } + + #[test] + fn test_allowlist_browser_on_known_host() { + let detector = make_detector(attack_tuned_weights(), 0.5); + // No cookies but browser UA + accept-language + known host → allowlist + let verdict = detector.check( + "GET", + "/", + "app", + false, + false, + true, + "text/html", + "Mozilla/5.0 (Macintosh; Intel Mac OS X) Safari/537.36", + 0, + ); + assert_eq!(verdict.action, ScannerAction::Allow); + assert_eq!(verdict.reason, "allowlist:host+browser"); + } + + #[test] + fn test_model_path_for_non_allowlisted() { + let detector = make_detector(attack_tuned_weights(), 0.5); + // Unknown host, no cookies, curl UA → goes through model + let verdict = detector.check( + "GET", + "/robots.txt", + "unknown", + false, + false, + false, + "*/*", + "curl/7.0", + 0, + ); + assert_eq!(verdict.reason, "model"); + } +} diff --git a/src/scanner/features.rs b/src/scanner/features.rs new file mode 100644 index 0000000..46bab8c --- /dev/null +++ b/src/scanner/features.rs @@ -0,0 +1,339 @@ +use rustc_hash::FxHashSet; +use serde::{Deserialize, Serialize}; + +pub const NUM_SCANNER_FEATURES: usize = 12; +pub type ScannerFeatureVector = [f64; NUM_SCANNER_FEATURES]; +/// 12 features + 2 interaction terms + 1 bias +pub const NUM_SCANNER_WEIGHTS: usize = 15; + +pub const SUSPICIOUS_EXTENSIONS_LIST: &[&str] = &[ + ".php", ".env", ".sql", ".bak", ".asp", ".jsp", ".cgi", ".tar", ".zip", ".git", +]; + +const TRAVERSAL_PATTERNS: &[&str] = &["..", "%00", "%0a", "%27", "%3c"]; + +/// Extract all 12 scanner features from a single request. +/// No heap allocation — all work done on references and stack buffers. +pub fn extract_features( + method: &str, + path: &str, + host_prefix: &str, + has_cookies: bool, + has_referer: bool, + has_accept_language: bool, + accept: &str, + user_agent: &str, + content_length: u64, + fragment_hashes: &FxHashSet, + extension_hashes: &FxHashSet, + configured_hosts: &FxHashSet, +) -> ScannerFeatureVector { + [ + suspicious_path_score(path, fragment_hashes), + path_depth(path), + has_suspicious_extension(path, extension_hashes), + if has_cookies { 1.0 } else { 0.0 }, + if has_referer { 1.0 } else { 0.0 }, + if has_accept_language { 1.0 } else { 0.0 }, + accept_quality(accept), + ua_category(user_agent), + method_is_unusual(method), + host_is_configured(host_prefix, configured_hosts), + content_length_mismatch(method, content_length), + path_has_traversal(path), + ] +} + +/// Fraction of path segments matching known-bad fragment hashes. +fn suspicious_path_score(path: &str, fragment_hashes: &FxHashSet) -> f64 { + let mut matches = 0u32; + let mut segments = 0u32; + for segment in path.split('/') { + if segment.is_empty() { + continue; + } + segments += 1; + let mut buf = [0u8; 256]; + let len = segment.len().min(256); + for (i, &b) in segment.as_bytes()[..len].iter().enumerate() { + buf[i] = b.to_ascii_lowercase(); + } + let hash = fx_hash_bytes(&buf[..len]); + if fragment_hashes.contains(&hash) { + matches += 1; + } + } + if segments == 0 { + 0.0 + } else { + matches as f64 / segments as f64 + } +} + +/// Count '/' characters, capped at 20. +fn path_depth(path: &str) -> f64 { + let depth = path.bytes().filter(|&b| b == b'/').count().min(20); + depth as f64 +} + +/// Check if path ends with a suspicious file extension. +fn has_suspicious_extension(path: &str, extension_hashes: &FxHashSet) -> f64 { + // Strip query string for extension check + let clean = path.split('?').next().unwrap_or(path); + for ext in SUSPICIOUS_EXTENSIONS_LIST { + if clean.len() >= ext.len() { + let suffix = &clean[clean.len() - ext.len()..]; + let hash = fx_hash_bytes(suffix.as_bytes()); + if extension_hashes.contains(&hash) { + return 1.0; + } + } + } + 0.0 +} + +fn accept_quality(accept: &str) -> f64 { + if accept.is_empty() || accept == "-" || accept == "*/*" { + return 0.0; + } + let lower = accept.to_ascii_lowercase(); + if lower.contains("text/html") || lower.contains("application/json") { + 1.0 + } else { + 0.0 + } +} + +/// Returns true if the UA string looks like a real browser (Mozilla/Chrome/Safari). +pub fn ua_is_browser(ua: &str) -> bool { + let lower = ua.to_ascii_lowercase(); + lower.contains("mozilla/") || lower.contains("chrome/") || lower.contains("safari/") +} + +fn ua_category(ua: &str) -> f64 { + if ua.is_empty() || ua == "-" { + return 0.0; + } + let lower = ua.to_ascii_lowercase(); + if lower.starts_with("curl/") + || lower.starts_with("wget/") + || lower.starts_with("python") + || lower.starts_with("go-http") + || lower.starts_with("libwww") + { + return 0.25; + } + if lower.contains("mozilla/") || lower.contains("chrome/") || lower.contains("safari/") { + return 1.0; + } + 0.5 +} + +fn method_is_unusual(method: &str) -> f64 { + match method { + "GET" | "HEAD" | "POST" | "OPTIONS" => 0.0, + _ => 1.0, + } +} + +fn host_is_configured(host_prefix: &str, configured_hosts: &FxHashSet) -> f64 { + if host_prefix.is_empty() { + return 0.0; + } + let hash = fx_hash_bytes(host_prefix.as_bytes()); + if configured_hosts.contains(&hash) { + 1.0 + } else { + 0.0 + } +} + +fn content_length_mismatch(method: &str, content_length: u64) -> f64 { + match method { + "GET" | "HEAD" if content_length > 0 => 1.0, + "POST" | "PUT" | "PATCH" if content_length == 0 => 1.0, + _ => 0.0, + } +} + +fn path_has_traversal(path: &str) -> f64 { + let lower = path.to_ascii_lowercase(); + for pattern in TRAVERSAL_PATTERNS { + if lower.contains(pattern) { + return 1.0; + } + } + 0.0 +} + +pub fn fx_hash_bytes(bytes: &[u8]) -> u64 { + use std::hash::{Hash, Hasher}; + let mut h = rustc_hash::FxHasher::default(); + bytes.hash(&mut h); + h.finish() +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScannerNormParams { + pub mins: [f64; NUM_SCANNER_FEATURES], + pub maxs: [f64; NUM_SCANNER_FEATURES], +} + +impl ScannerNormParams { + pub fn from_data(vectors: &[ScannerFeatureVector]) -> Self { + let mut mins = [f64::MAX; NUM_SCANNER_FEATURES]; + let mut maxs = [f64::MIN; NUM_SCANNER_FEATURES]; + for v in vectors { + for i in 0..NUM_SCANNER_FEATURES { + mins[i] = mins[i].min(v[i]); + maxs[i] = maxs[i].max(v[i]); + } + } + Self { mins, maxs } + } + + pub fn normalize(&self, v: &ScannerFeatureVector) -> ScannerFeatureVector { + let mut out = [0.0; NUM_SCANNER_FEATURES]; + for i in 0..NUM_SCANNER_FEATURES { + let range = self.maxs[i] - self.mins[i]; + out[i] = if range > 0.0 { + ((v[i] - self.mins[i]) / range).clamp(0.0, 1.0) + } else { + 0.0 + }; + } + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_fragment_hashes() -> FxHashSet { + let fragments = [ + ".env", "wp-admin", "wp-login", "phpinfo", "phpmyadmin", + ".git", "cgi-bin", "shell", ".htaccess", ".htpasswd", + ]; + fragments.iter().map(|f| fx_hash_bytes(f.as_bytes())).collect() + } + + fn make_extension_hashes() -> FxHashSet { + SUSPICIOUS_EXTENSIONS_LIST + .iter() + .map(|e| fx_hash_bytes(e.as_bytes())) + .collect() + } + + fn make_configured_hosts() -> FxHashSet { + ["test", "app", "api"] + .iter() + .map(|h| fx_hash_bytes(h.as_bytes())) + .collect() + } + + #[test] + fn test_suspicious_path_score_known_fragment() { + let hashes = make_fragment_hashes(); + let score = suspicious_path_score("/.env", &hashes); + assert!(score > 0.0, "should detect .env: {score}"); + } + + #[test] + fn test_suspicious_path_score_clean() { + let hashes = make_fragment_hashes(); + let score = suspicious_path_score("/blog/hello-world", &hashes); + assert_eq!(score, 0.0); + } + + #[test] + fn test_path_depth() { + assert_eq!(path_depth("/"), 1.0); + assert_eq!(path_depth("/a/b/c"), 3.0); + assert_eq!(path_depth("/a"), 1.0); + } + + #[test] + fn test_has_suspicious_extension() { + let ext_hashes = make_extension_hashes(); + assert_eq!(has_suspicious_extension("/test.php", &ext_hashes), 1.0); + assert_eq!(has_suspicious_extension("/test.html", &ext_hashes), 0.0); + assert_eq!(has_suspicious_extension("/config.env", &ext_hashes), 1.0); + } + + #[test] + fn test_ua_category() { + assert_eq!(ua_category(""), 0.0); + assert_eq!(ua_category("-"), 0.0); + assert_eq!(ua_category("curl/7.0"), 0.25); + assert_eq!(ua_category("python-requests/2.28"), 0.25); + assert_eq!( + ua_category("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120"), + 1.0 + ); + assert_eq!(ua_category("SomeRandomBot/1.0"), 0.5); + } + + #[test] + fn test_method_is_unusual() { + assert_eq!(method_is_unusual("GET"), 0.0); + assert_eq!(method_is_unusual("POST"), 0.0); + assert_eq!(method_is_unusual("DELETE"), 1.0); + assert_eq!(method_is_unusual("TRACE"), 1.0); + } + + #[test] + fn test_host_is_configured() { + let hosts = make_configured_hosts(); + assert_eq!(host_is_configured("test", &hosts), 1.0); + assert_eq!(host_is_configured("unknown", &hosts), 0.0); + } + + #[test] + fn test_content_length_mismatch() { + assert_eq!(content_length_mismatch("GET", 100), 1.0); + assert_eq!(content_length_mismatch("GET", 0), 0.0); + assert_eq!(content_length_mismatch("POST", 0), 1.0); + assert_eq!(content_length_mismatch("POST", 100), 0.0); + } + + #[test] + fn test_path_has_traversal() { + assert_eq!(path_has_traversal("/etc/../passwd"), 1.0); + assert_eq!(path_has_traversal("/normal/path"), 0.0); + assert_eq!(path_has_traversal("/foo%00bar"), 1.0); + } + + #[test] + fn test_extract_features_returns_12() { + let fh = make_fragment_hashes(); + let eh = make_extension_hashes(); + let ch = make_configured_hosts(); + let features = extract_features( + "GET", "/blog/post", "test", true, true, true, + "text/html", "Mozilla/5.0", 0, &fh, &eh, &ch, + ); + assert_eq!(features.len(), NUM_SCANNER_FEATURES); + } + + #[test] + fn test_norm_params_roundtrip() { + let data = vec![ + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ]; + let params = ScannerNormParams::from_data(&data); + let mid = [0.5, 3.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]; + let norm = params.normalize(&mid); + assert!((norm[0] - 0.5).abs() < 1e-10); + assert!((norm[1] - 0.5).abs() < 1e-10); + } + + #[test] + fn test_legitimate_path_with_php_substring_not_flagged() { + let hashes = make_fragment_hashes(); + // "php-is-dead" is NOT a known fragment + let score = suspicious_path_score("/blog/php-is-dead", &hashes); + assert_eq!(score, 0.0, "legitimate path with 'php' substring should not match"); + } +} diff --git a/src/scanner/mod.rs b/src/scanner/mod.rs new file mode 100644 index 0000000..4008f4f --- /dev/null +++ b/src/scanner/mod.rs @@ -0,0 +1,6 @@ +pub mod allowlist; +pub mod detector; +pub mod features; +pub mod model; +pub mod train; +pub mod watcher; diff --git a/src/scanner/model.rs b/src/scanner/model.rs new file mode 100644 index 0000000..cdd2bab --- /dev/null +++ b/src/scanner/model.rs @@ -0,0 +1,89 @@ +use crate::scanner::features::{ScannerNormParams, NUM_SCANNER_WEIGHTS}; +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ScannerAction { + Allow, + Block, +} + +#[derive(Debug, Clone, Copy)] +pub struct ScannerVerdict { + pub action: ScannerAction, + pub score: f64, + /// Why this decision was made: "model", "allowlist", etc. + pub reason: &'static str, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScannerModel { + pub weights: [f64; NUM_SCANNER_WEIGHTS], + pub threshold: f64, + pub norm_params: ScannerNormParams, + /// Suspicious path fragments used during training — kept for reproducibility. + pub fragments: Vec, +} + +impl ScannerModel { + pub fn save(&self, path: &Path) -> Result<()> { + let data = bincode::serialize(self).context("serializing scanner model")?; + std::fs::write(path, data) + .with_context(|| format!("writing scanner model to {}", path.display()))?; + Ok(()) + } + + pub fn load(path: &Path) -> Result { + let data = std::fs::read(path) + .with_context(|| format!("reading scanner model from {}", path.display()))?; + bincode::deserialize(&data).context("deserializing scanner model") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::scanner::features::NUM_SCANNER_FEATURES; + + #[test] + fn test_serialization_roundtrip() { + let model = ScannerModel { + weights: [0.1; NUM_SCANNER_WEIGHTS], + threshold: 0.5, + norm_params: ScannerNormParams { + mins: [0.0; NUM_SCANNER_FEATURES], + maxs: [1.0; NUM_SCANNER_FEATURES], + }, + fragments: vec![".env".into(), "wp-admin".into()], + }; + let data = bincode::serialize(&model).unwrap(); + let loaded: ScannerModel = bincode::deserialize(&data).unwrap(); + assert_eq!(loaded.weights, model.weights); + assert_eq!(loaded.threshold, model.threshold); + assert_eq!(loaded.fragments, model.fragments); + } + + #[test] + fn test_save_load_file() { + let dir = std::env::temp_dir().join("scanner_model_test"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("test_model.bin"); + + let model = ScannerModel { + weights: [0.5; NUM_SCANNER_WEIGHTS], + threshold: 0.42, + norm_params: ScannerNormParams { + mins: [0.0; NUM_SCANNER_FEATURES], + maxs: [1.0; NUM_SCANNER_FEATURES], + }, + fragments: vec!["phpinfo".into()], + }; + model.save(&path).unwrap(); + let loaded = ScannerModel::load(&path).unwrap(); + assert_eq!(loaded.threshold, 0.42); + assert_eq!(loaded.fragments, vec!["phpinfo"]); + + let _ = std::fs::remove_dir_all(&dir); + } +} diff --git a/tests/scanner_test.rs b/tests/scanner_test.rs new file mode 100644 index 0000000..b23e4dc --- /dev/null +++ b/tests/scanner_test.rs @@ -0,0 +1,181 @@ +use sunbeam_proxy::config::RouteConfig; +use sunbeam_proxy::scanner::detector::ScannerDetector; +use sunbeam_proxy::scanner::features::{ + ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS, +}; +use sunbeam_proxy::scanner::model::{ScannerAction, ScannerModel}; + +fn test_routes() -> Vec { + vec![ + RouteConfig { + host_prefix: "app".into(), + backend: "http://127.0.0.1:8080".into(), + websocket: false, + disable_secure_redirection: false, + paths: vec![], + }, + RouteConfig { + host_prefix: "api".into(), + backend: "http://127.0.0.1:8081".into(), + websocket: false, + disable_secure_redirection: false, + paths: vec![], + }, + ] +} + +fn scanner_weights() -> [f64; NUM_SCANNER_WEIGHTS] { + let mut w = [0.0; NUM_SCANNER_WEIGHTS]; + w[0] = 2.0; // suspicious_path_score + w[2] = 2.0; // has_suspicious_extension + w[3] = -2.0; // has_cookies (negative = good) + w[4] = -1.0; // has_referer + w[5] = -1.0; // has_accept_language + w[6] = -0.5; // accept_quality + w[7] = -1.0; // ua_category (browser = good) + w[9] = -1.5; // host_is_configured + w[11] = 2.0; // path_has_traversal + w[12] = 1.5; // interaction: suspicious_path AND no_cookies + w[13] = 1.0; // interaction: unknown_host AND no_accept_lang + w[14] = 0.5; // bias + w +} + +fn make_detector() -> ScannerDetector { + let model = ScannerModel { + weights: scanner_weights(), + threshold: 0.5, + norm_params: ScannerNormParams { + mins: [0.0; NUM_SCANNER_FEATURES], + maxs: [1.0; NUM_SCANNER_FEATURES], + }, + fragments: vec![ + ".env".into(), + "wp-admin".into(), + "wp-login".into(), + "phpinfo".into(), + "phpmyadmin".into(), + ".git".into(), + "cgi-bin".into(), + ".htaccess".into(), + ".htpasswd".into(), + ], + }; + ScannerDetector::new(&model, &test_routes()) +} + +#[test] +fn normal_browser_with_cookies_allowed() { + let d = make_detector(); + let v = d.check( + "GET", "/blog/hello-world", "app", + true, true, true, + "text/html,application/xhtml+xml", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120", + 0, + ); + assert_eq!(v.action, ScannerAction::Allow); + assert_eq!(v.reason, "allowlist:host+cookies"); +} + +#[test] +fn api_client_with_auth_allowed() { + let d = make_detector(); + let v = d.check( + "POST", "/api/v1/users", "api", + true, false, true, + "application/json", + "MyApp/2.0", + 256, + ); + assert_eq!(v.action, ScannerAction::Allow); + assert_eq!(v.reason, "allowlist:host+cookies"); +} + +#[test] +fn env_probe_from_unknown_host_blocked() { + let d = make_detector(); + let v = d.check( + "GET", "/.env", "unknown", + false, false, false, + "*/*", "curl/7.0", 0, + ); + assert_eq!(v.action, ScannerAction::Block); + assert_eq!(v.reason, "model"); +} + +#[test] +fn wordpress_scan_blocked() { + let d = make_detector(); + let v = d.check( + "GET", "/wp-admin/install.php", "unknown", + false, false, false, + "*/*", "", 0, + ); + assert_eq!(v.action, ScannerAction::Block); + assert_eq!(v.reason, "model"); +} + +#[test] +fn path_traversal_blocked() { + let d = make_detector(); + let v = d.check( + "GET", "/etc/../../../passwd", "unknown", + false, false, false, + "*/*", "python-requests/2.28", 0, + ); + assert_eq!(v.action, ScannerAction::Block); + assert_eq!(v.reason, "model"); +} + +#[test] +fn legitimate_php_path_allowed() { + let d = make_detector(); + let v = d.check( + "GET", "/blog/php-is-dead", "app", + true, true, true, + "text/html", "Mozilla/5.0 Chrome/120", 0, + ); + assert_eq!(v.action, ScannerAction::Allow); + // hits allowlist:host+cookies +} + +#[test] +fn browser_on_known_host_without_cookies_allowed() { + let d = make_detector(); + let v = d.check( + "GET", "/", "app", + false, false, true, + "text/html", + "Mozilla/5.0 (Macintosh; Intel Mac OS X) Safari/537.36", + 0, + ); + assert_eq!(v.action, ScannerAction::Allow); + assert_eq!(v.reason, "allowlist:host+browser"); +} + +#[test] +fn model_serialization_roundtrip() { + let model = ScannerModel { + weights: scanner_weights(), + threshold: 0.5, + norm_params: ScannerNormParams { + mins: [0.0; NUM_SCANNER_FEATURES], + maxs: [1.0; NUM_SCANNER_FEATURES], + }, + fragments: vec![".env".into(), "wp-admin".into()], + }; + + let dir = std::env::temp_dir().join("scanner_e2e_test"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("test_scanner_model.bin"); + + model.save(&path).unwrap(); + let loaded = ScannerModel::load(&path).unwrap(); + + assert_eq!(loaded.weights, model.weights); + assert_eq!(loaded.threshold, model.threshold); + assert_eq!(loaded.fragments, model.fragments); + + let _ = std::fs::remove_dir_all(&dir); +}