diff --git a/benches/scanner_bench.rs b/benches/scanner_bench.rs index 6039cb6..48a090e 100644 --- a/benches/scanner_bench.rs +++ b/benches/scanner_bench.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + use criterion::{black_box, criterion_group, criterion_main, Criterion}; use sunbeam_proxy::config::RouteConfig; use sunbeam_proxy::ensemble::gen::scanner_weights; @@ -5,43 +8,9 @@ use sunbeam_proxy::ensemble::mlp::mlp_predict_32; use sunbeam_proxy::ensemble::scanner::scanner_ensemble_predict; use sunbeam_proxy::ensemble::tree::tree_predict; use sunbeam_proxy::scanner::detector::ScannerDetector; -use sunbeam_proxy::scanner::features::{ - self, fx_hash_bytes, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS, -}; -use sunbeam_proxy::scanner::model::ScannerModel; +use sunbeam_proxy::scanner::features::{self, fx_hash_bytes}; fn make_detector() -> ScannerDetector { - // Use realistic trained weights (from the base model) - let mut weights = [0.0f64; NUM_SCANNER_WEIGHTS]; - weights[0] = 0.155; // suspicious_path_score - weights[1] = 0.039; // path_depth - weights[2] = 0.328; // has_suspicious_extension - weights[3] = -1.376; // has_cookies - weights[4] = -0.196; // has_referer - weights[5] = -0.590; // has_accept_language - weights[7] = -0.254; // ua_category - weights[8] = 0.023; // method_is_unusual - weights[11] = 0.001; // path_has_traversal - weights[12] = 0.155; // interaction:path*no_cookies - weights[13] = 1.051; // interaction:no_host*no_lang - weights[14] = 0.461; // bias - - let model = ScannerModel { - weights, - threshold: 0.5, - norm_params: ScannerNormParams { - mins: [0.0; NUM_SCANNER_FEATURES], - maxs: [1.0; NUM_SCANNER_FEATURES], - }, - fragments: vec![ - ".env".into(), "wp-admin".into(), "wp-login".into(), "wp-includes".into(), - "wp-content".into(), "xmlrpc".into(), "phpinfo".into(), "phpmyadmin".into(), - "cgi-bin".into(), ".git".into(), ".htaccess".into(), ".htpasswd".into(), - "config.".into(), "admin".into(), "actuator".into(), "telescope".into(), - "debug".into(), "shell".into(), "eval-stdin".into(), - ], - }; - let routes = vec![ RouteConfig { host_prefix: "admin".into(), @@ -84,7 +53,7 @@ fn make_detector() -> ScannerDetector { }, ]; - ScannerDetector::new(&model, &routes) + ScannerDetector::new(&routes) } fn bench_check_normal_browser(c: &mut Criterion) { @@ -95,9 +64,9 @@ fn bench_check_normal_browser(c: &mut Criterion) { black_box("GET"), black_box("/blog/hello-world"), black_box("admin"), - black_box(true), // has_cookies - black_box(true), // has_referer - black_box(true), // has_accept_language + black_box(true), + black_box(true), + black_box(true), black_box("text/html,application/xhtml+xml"), black_box("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120.0.0.0"), black_box(0), @@ -208,10 +177,10 @@ fn bench_check_api_legitimate(c: &mut Criterion) { detector.check( black_box("POST"), black_box("/api/webhooks/github"), - black_box("unknown"), // unknown host, no allowlist shortcut + black_box("unknown"), black_box(false), black_box(false), - black_box(true), // has accept-language + black_box(true), black_box("application/json"), black_box("GitHub-Hookshot/abc123"), black_box(1024), @@ -255,7 +224,6 @@ fn bench_extract_features(c: &mut Criterion) { } fn bench_ensemble_scanner_full(c: &mut Criterion) { - // Raw features simulating a scanner probe let raw: [f32; 12] = [0.8, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0]; c.bench_function("ensemble::scanner full predict", |b| { b.iter(|| scanner_ensemble_predict(black_box(&raw))) diff --git a/tests/ddos_test.rs b/tests/ddos_test.rs index 3223762..9367f8b 100644 --- a/tests/ddos_test.rs +++ b/tests/ddos_test.rs @@ -1,448 +1,42 @@ -//! Extensive DDoS detection tests. +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + +//! DDoS detection tests. //! -//! These tests build realistic traffic profiles — normal browsing, API usage, -//! webhook bursts, etc. — and verify the model never blocks legitimate traffic. -//! Attack scenarios are also tested to confirm blocking works. +//! The detector uses ensemble inference (decision tree + MLP) with compiled-in +//! weights. These tests exercise the detector pipeline: event accumulation, +//! min_events gating, and ensemble classification. use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use sunbeam_proxy::config::DDoSConfig; use sunbeam_proxy::ddos::detector::DDoSDetector; use sunbeam_proxy::ddos::features::{NormParams, NUM_FEATURES}; -use sunbeam_proxy::ddos::model::{DDoSAction, SerializedModel, TrafficLabel, TrainedModel}; +use sunbeam_proxy::ddos::model::DDoSAction; // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- -/// Build a model from explicit normal/attack feature vectors. -fn make_model( - normal: &[[f64; NUM_FEATURES]], - attack: &[[f64; NUM_FEATURES]], - k: usize, - threshold: f64, -) -> TrainedModel { - let mut points = Vec::new(); - let mut labels = Vec::new(); - for v in normal { - points.push(*v); - labels.push(TrafficLabel::Normal); - } - for v in attack { - points.push(*v); - labels.push(TrafficLabel::Attack); - } - let norm_params = NormParams::from_data(&points); - let normalized: Vec<[f64; NUM_FEATURES]> = - points.iter().map(|v| norm_params.normalize(v)).collect(); - TrainedModel::from_serialized(SerializedModel { - points: normalized, - labels, - norm_params, - k, - threshold, - }) -} - fn default_ddos_config() -> DDoSConfig { DDoSConfig { - model_path: Some(String::new()), - k: 5, threshold: 0.6, window_secs: 60, window_capacity: 1000, min_events: 10, enabled: true, - use_ensemble: false, + observe_only: false, } } -fn make_detector(model: TrainedModel, min_events: usize) -> DDoSDetector { +fn make_detector(min_events: usize) -> DDoSDetector { let mut cfg = default_ddos_config(); cfg.min_events = min_events; - DDoSDetector::new(model, &cfg) -} - -/// Feature vector indices (matching features.rs order): -/// 0: request_rate (requests / window_secs) -/// 1: unique_paths (count of distinct paths) -/// 2: unique_hosts (count of distinct hosts) -/// 3: error_rate (fraction 4xx/5xx) -/// 4: avg_duration_ms (mean response time) -/// 5: method_entropy (Shannon entropy of methods) -/// 6: burst_score (inverse mean inter-arrival) -/// 7: path_repetition (most-repeated path / total) -/// 8: avg_content_length (mean body size) -/// 9: unique_user_agents (count of distinct UAs) -/// 10: cookie_ratio (fraction with cookies) -/// 11: referer_ratio (fraction with referer) -/// 12: accept_language_ratio (fraction with accept-language) -/// 13: suspicious_path_ratio (fraction hitting known-bad paths) -/// 9: unique_user_agents (count of distinct UAs) - -// Realistic normal traffic profiles -fn normal_browser_browsing() -> [f64; NUM_FEATURES] { - // A human browsing a site: ~0.5 req/s, many paths, 1 host, low errors, - // ~150ms avg latency, mostly GET, moderate spacing, diverse paths, no body, 1 UA - // cookies=yes, referer=sometimes, accept-lang=yes, suspicious=no - [0.5, 12.0, 1.0, 0.02, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0] -} - -fn normal_api_client() -> [f64; NUM_FEATURES] { - // Backend API client: ~2 req/s, hits a few endpoints, 1 host, ~5% errors (retries), - // ~50ms latency, mix of GET/POST, steady rate, some path repetition, small bodies, 1 UA - // cookies=yes (session), referer=no, accept-lang=no, suspicious=no - [2.0, 5.0, 1.0, 0.05, 50.0, 0.69, 2.5, 0.4, 512.0, 1.0, 1.0, 0.0, 0.0, 0.0] -} - -fn normal_webhook_burst() -> [f64; NUM_FEATURES] { - // CI/CD or webhook burst: ~10 req/s for a short period, 1-2 paths, 1 host, - // 0% errors, fast responses, all POST, bursty, high path repetition, medium bodies, 1 UA - // cookies=no (machine), referer=no, accept-lang=no, suspicious=no - [10.0, 2.0, 1.0, 0.0, 25.0, 0.0, 12.0, 0.8, 2048.0, 1.0, 0.0, 0.0, 0.0, 0.0] -} - -fn normal_health_check() -> [f64; NUM_FEATURES] { - // Health check probe: ~0.2 req/s, 1 path, 1 host, 0% errors, ~5ms latency, - // all GET, very regular, 100% same path, no body, 1 UA - // cookies=no (probe), referer=no, accept-lang=no, suspicious=no - [0.2, 1.0, 1.0, 0.0, 5.0, 0.0, 0.2, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] -} - -fn normal_mobile_app() -> [f64; NUM_FEATURES] { - // Mobile app: ~1 req/s, several API endpoints, 1 host, ~3% errors, - // ~200ms latency (mobile network), GET + POST, moderate spacing, moderate repetition, - // small-medium bodies, 1 UA - // cookies=yes, referer=no, accept-lang=yes, suspicious=no - [1.0, 8.0, 1.0, 0.03, 200.0, 0.5, 1.2, 0.25, 256.0, 1.0, 1.0, 0.0, 1.0, 0.0] -} - -fn normal_search_crawler() -> [f64; NUM_FEATURES] { - // Googlebot-style crawler: ~0.3 req/s, many unique paths, 1 host, ~10% 404s, - // ~300ms latency, all GET, slow steady rate, diverse paths, no body, 1 UA - // cookies=no (crawler), referer=no, accept-lang=no, suspicious=no - [0.3, 20.0, 1.0, 0.1, 300.0, 0.0, 0.35, 0.08, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] -} - -fn normal_graphql_spa() -> [f64; NUM_FEATURES] { - // SPA hitting a GraphQL endpoint: ~3 req/s, 1 path (/graphql), 1 host, ~1% errors, - // ~80ms latency, all POST, steady, 100% same path, medium bodies, 1 UA - // cookies=yes, referer=yes (SPA nav), accept-lang=yes, suspicious=no - [3.0, 1.0, 1.0, 0.01, 80.0, 0.0, 3.5, 1.0, 1024.0, 1.0, 1.0, 1.0, 1.0, 0.0] -} - -fn normal_websocket_upgrade() -> [f64; NUM_FEATURES] { - // Initial HTTP requests before WS upgrade: ~0.1 req/s, 2 paths, 1 host, 0% errors, - // ~10ms latency, GET, slow, some repetition, no body, 1 UA - // cookies=yes, referer=yes, accept-lang=yes, suspicious=no - [0.1, 2.0, 1.0, 0.0, 10.0, 0.0, 0.1, 0.5, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0] -} - -fn normal_file_upload() -> [f64; NUM_FEATURES] { - // File upload session: ~0.5 req/s, 3 paths (upload, status, confirm), 1 host, - // 0% errors, ~500ms latency (large bodies), POST + GET, steady, moderate repetition, - // large bodies, 1 UA - // cookies=yes, referer=yes, accept-lang=yes, suspicious=no - [0.5, 3.0, 1.0, 0.0, 500.0, 0.69, 0.6, 0.5, 1_000_000.0, 1.0, 1.0, 1.0, 1.0, 0.0] -} - -fn normal_multi_tenant_api() -> [f64; NUM_FEATURES] { - // API client hitting multiple hosts (multi-tenant): ~1.5 req/s, 4 paths, 3 hosts, - // ~2% errors, ~100ms latency, GET + POST, steady, low repetition, small bodies, 1 UA - // cookies=yes, referer=no, accept-lang=no, suspicious=no - [1.5, 4.0, 3.0, 0.02, 100.0, 0.69, 1.8, 0.3, 128.0, 1.0, 1.0, 0.0, 0.0, 0.0] -} - -// Realistic attack traffic profiles -fn attack_path_scan() -> [f64; NUM_FEATURES] { - // WordPress/PHP scanner: ~20 req/s, many unique paths, 1 host, 100% 404s, - // ~2ms latency (all errors), all GET, very bursty, all unique paths, no body, 1 UA - // cookies=no, referer=no, accept-lang=no, suspicious=0.8 (most paths are probes) - [20.0, 50.0, 1.0, 1.0, 2.0, 0.0, 25.0, 0.02, 0.0, 1.0, 0.0, 0.0, 0.0, 0.8] -} - -fn attack_credential_stuffing() -> [f64; NUM_FEATURES] { - // Login brute-force: ~30 req/s, 1 path (/login), 1 host, 95% 401/403, - // ~10ms latency, all POST, very bursty, 100% same path, small bodies, 1 UA - // cookies=no, referer=no, accept-lang=no, suspicious=0.0 (/login is not in suspicious list) - [30.0, 1.0, 1.0, 0.95, 10.0, 0.0, 35.0, 1.0, 64.0, 1.0, 0.0, 0.0, 0.0, 0.0] -} - -fn attack_slowloris() -> [f64; NUM_FEATURES] { - // Slowloris-style: ~0.5 req/s (slow), 1 path, 1 host, 0% errors (connections held), - // ~30000ms latency (!), all GET, slow, 100% same path, huge content-length, 1 UA - // cookies=no, referer=no, accept-lang=no, suspicious=0.0 - [0.5, 1.0, 1.0, 0.0, 30000.0, 0.0, 0.5, 1.0, 10_000_000.0, 1.0, 0.0, 0.0, 0.0, 0.0] -} - -fn attack_ua_rotation() -> [f64; NUM_FEATURES] { - // Bot rotating user-agents: ~15 req/s, 2 paths, 1 host, 80% errors, - // ~5ms latency, GET + POST, bursty, high repetition, no body, 50 distinct UAs - // cookies=no, referer=no, accept-lang=no, suspicious=0.3 - [15.0, 2.0, 1.0, 0.8, 5.0, 0.69, 18.0, 0.7, 0.0, 50.0, 0.0, 0.0, 0.0, 0.3] -} - -fn attack_host_scan() -> [f64; NUM_FEATURES] { - // Virtual host enumeration: ~25 req/s, 1 path (/), many hosts, 100% errors, - // ~1ms latency, all GET, very bursty, 100% same path, no body, 1 UA - // cookies=no, referer=no, accept-lang=no, suspicious=0.0 - [25.0, 1.0, 40.0, 1.0, 1.0, 0.0, 30.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] -} - -fn attack_api_fuzzing() -> [f64; NUM_FEATURES] { - // API fuzzer: ~50 req/s, many paths, 1 host, 90% errors (bad inputs), - // ~3ms latency, mixed methods, extremely bursty, low repetition, varied bodies, 1 UA - // cookies=no, referer=no, accept-lang=no, suspicious=0.5 - [50.0, 100.0, 1.0, 0.9, 3.0, 1.5, 55.0, 0.01, 4096.0, 1.0, 0.0, 0.0, 0.0, 0.5] -} - -fn all_normal_profiles() -> Vec<[f64; NUM_FEATURES]> { - vec![ - normal_browser_browsing(), - normal_api_client(), - normal_webhook_burst(), - normal_health_check(), - normal_mobile_app(), - normal_search_crawler(), - normal_graphql_spa(), - normal_websocket_upgrade(), - normal_file_upload(), - normal_multi_tenant_api(), - ] -} - -fn all_attack_profiles() -> Vec<[f64; NUM_FEATURES]> { - vec![ - attack_path_scan(), - attack_credential_stuffing(), - attack_slowloris(), - attack_ua_rotation(), - attack_host_scan(), - attack_api_fuzzing(), - ] -} - -/// Build a model from realistic profiles, with each profile replicated `copies` -/// times (with slight jitter) to give the KNN enough neighbors. -fn make_realistic_model(k: usize, threshold: f64) -> TrainedModel { - let mut normal = Vec::new(); - let mut attack = Vec::new(); - - // Replicate each profile with small perturbations - for base in all_normal_profiles() { - for i in 0..20 { - let mut v = base; - for d in 0..NUM_FEATURES { - // ±5% jitter - let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05); - v[d] *= jitter; - } - normal.push(v); - } - } - for base in all_attack_profiles() { - for i in 0..20 { - let mut v = base; - for d in 0..NUM_FEATURES { - let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05); - v[d] *= jitter; - } - attack.push(v); - } - } - - make_model(&normal, &attack, k, threshold) + DDoSDetector::new(&cfg) } // =========================================================================== -// Model classification tests — normal profiles must NEVER be blocked -// =========================================================================== - -#[test] -fn normal_browser_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow); -} - -#[test] -fn normal_api_client_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow); -} - -#[test] -fn normal_webhook_burst_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_webhook_burst()), DDoSAction::Allow); -} - -#[test] -fn normal_health_check_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_health_check()), DDoSAction::Allow); -} - -#[test] -fn normal_mobile_app_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_mobile_app()), DDoSAction::Allow); -} - -#[test] -fn normal_search_crawler_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_search_crawler()), DDoSAction::Allow); -} - -#[test] -fn normal_graphql_spa_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_graphql_spa()), DDoSAction::Allow); -} - -#[test] -fn normal_websocket_upgrade_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!( - model.classify(&normal_websocket_upgrade()), - DDoSAction::Allow - ); -} - -#[test] -fn normal_file_upload_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&normal_file_upload()), DDoSAction::Allow); -} - -#[test] -fn normal_multi_tenant_api_is_allowed() { - let model = make_realistic_model(5, 0.6); - assert_eq!( - model.classify(&normal_multi_tenant_api()), - DDoSAction::Allow - ); -} - -// =========================================================================== -// Model classification tests — attack profiles must be blocked -// =========================================================================== - -#[test] -fn attack_path_scan_is_blocked() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Block); -} - -#[test] -fn attack_credential_stuffing_is_blocked() { - let model = make_realistic_model(5, 0.6); - assert_eq!( - model.classify(&attack_credential_stuffing()), - DDoSAction::Block - ); -} - -#[test] -fn attack_slowloris_is_blocked() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&attack_slowloris()), DDoSAction::Block); -} - -#[test] -fn attack_ua_rotation_is_blocked() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&attack_ua_rotation()), DDoSAction::Block); -} - -#[test] -fn attack_host_scan_is_blocked() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&attack_host_scan()), DDoSAction::Block); -} - -#[test] -fn attack_api_fuzzing_is_blocked() { - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&attack_api_fuzzing()), DDoSAction::Block); -} - -// =========================================================================== -// Edge cases: normal traffic that LOOKS suspicious but isn't -// =========================================================================== - -#[test] -fn high_rate_legitimate_cdn_prefetch_is_allowed() { - // CDN prefetch: high rate but low errors, diverse paths, normal latency - let model = make_realistic_model(5, 0.6); - let profile: [f64; NUM_FEATURES] = - [8.0, 15.0, 1.0, 0.0, 100.0, 0.0, 9.0, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]; - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -#[test] -fn single_path_api_polling_is_allowed() { - // Long-poll or SSE endpoint: single path, 100% repetition, but low rate, no errors - let model = make_realistic_model(5, 0.6); - let profile: [f64; NUM_FEATURES] = - [0.3, 1.0, 1.0, 0.0, 1000.0, 0.0, 0.3, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]; - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -#[test] -fn moderate_error_rate_during_deploy_is_allowed() { - // During a rolling deploy, error rate spikes to ~20% temporarily - let model = make_realistic_model(5, 0.6); - let profile: [f64; NUM_FEATURES] = - [1.0, 5.0, 1.0, 0.2, 200.0, 0.5, 1.2, 0.3, 128.0, 1.0, 1.0, 0.3, 1.0, 0.0]; - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -#[test] -fn burst_of_form_submissions_is_allowed() { - // Marketing event → users submit forms rapidly: high rate, single path, all POST, no errors - let model = make_realistic_model(5, 0.6); - let profile: [f64; NUM_FEATURES] = - [5.0, 1.0, 1.0, 0.0, 80.0, 0.0, 6.0, 1.0, 512.0, 1.0, 1.0, 1.0, 1.0, 0.0]; - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -#[test] -fn legitimate_load_test_with_varied_paths_is_allowed() { - // Internal load test: high rate but diverse paths, low error, real latency - let model = make_realistic_model(5, 0.6); - let profile: [f64; NUM_FEATURES] = - [8.0, 30.0, 1.0, 0.02, 120.0, 0.69, 10.0, 0.05, 256.0, 1.0, 0.0, 0.0, 0.0, 0.0]; - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -// =========================================================================== -// Threshold and k sensitivity -// =========================================================================== - -#[test] -fn higher_threshold_is_more_permissive() { - // With threshold=0.9, even borderline traffic should be allowed - let model = make_realistic_model(5, 0.9); - // A profile that's borderline between attack and normal - let borderline: [f64; NUM_FEATURES] = - [12.0, 8.0, 1.0, 0.5, 20.0, 0.5, 14.0, 0.5, 100.0, 2.0, 0.0, 0.0, 0.0, 0.1]; - assert_eq!(model.classify(&borderline), DDoSAction::Allow); -} - -#[test] -fn larger_k_smooths_classification() { - // With larger k, noisy outliers matter less - let model_k3 = make_realistic_model(3, 0.6); - let model_k9 = make_realistic_model(9, 0.6); - // Normal traffic should be allowed by both - let profile = normal_browser_browsing(); - assert_eq!(model_k3.classify(&profile), DDoSAction::Allow); - assert_eq!(model_k9.classify(&profile), DDoSAction::Allow); -} - -// =========================================================================== -// Normalization tests +// Normalization tests (feature-level, model-independent) // =========================================================================== #[test] @@ -451,13 +45,11 @@ fn normalization_clamps_out_of_range() { mins: [0.0; NUM_FEATURES], maxs: [1.0; NUM_FEATURES], }; - // Values above max should clamp to 1.0 let above = [2.0; NUM_FEATURES]; let normed = params.normalize(&above); for &v in &normed { assert_eq!(v, 1.0); } - // Values below min should clamp to 0.0 let below = [-1.0; NUM_FEATURES]; let normed = params.normalize(&below); for &v in &normed { @@ -467,7 +59,6 @@ fn normalization_clamps_out_of_range() { #[test] fn normalization_handles_zero_range() { - // When all training data has the same value for a feature, range = 0 let params = NormParams { mins: [5.0; NUM_FEATURES], maxs: [5.0; NUM_FEATURES], @@ -506,77 +97,14 @@ fn norm_params_from_data_finds_extremes() { assert_eq!(params.maxs[0], 10.0); } -// =========================================================================== -// Serialization round-trip -// =========================================================================== - -#[test] -fn model_serialization_roundtrip() { - // Use the realistic model (200+ points) so fnntw has enough data for the kD-tree - let model = make_realistic_model(3, 0.5); - - // Rebuild from the same training data - let mut all_points = Vec::new(); - let mut all_labels = Vec::new(); - for base in all_normal_profiles() { - for i in 0..20 { - let mut v = base; - for d in 0..NUM_FEATURES { - let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05); - v[d] *= jitter; - } - all_points.push(v); - all_labels.push(TrafficLabel::Normal); - } - } - for base in all_attack_profiles() { - for i in 0..20 { - let mut v = base; - for d in 0..NUM_FEATURES { - let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05); - v[d] *= jitter; - } - all_points.push(v); - all_labels.push(TrafficLabel::Attack); - } - } - - let norm_params = NormParams::from_data(&all_points); - let serialized = SerializedModel { - points: all_points.iter().map(|v| norm_params.normalize(v)).collect(), - labels: all_labels, - norm_params, - k: 3, - threshold: 0.5, - }; - - let encoded = bincode::serialize(&serialized).unwrap(); - let decoded: SerializedModel = bincode::deserialize(&encoded).unwrap(); - - assert_eq!(decoded.points.len(), serialized.points.len()); - assert_eq!(decoded.labels.len(), serialized.labels.len()); - assert_eq!(decoded.k, 3); - assert!((decoded.threshold - 0.5).abs() < 1e-10); - - // Rebuilt model should classify the same - let rebuilt = TrainedModel::from_serialized(decoded); - assert_eq!( - rebuilt.classify(&normal_browser_browsing()), - model.classify(&normal_browser_browsing()) - ); -} - // =========================================================================== // Detector integration tests (full check() pipeline) // =========================================================================== #[test] fn detector_allows_below_min_events() { - let model = make_realistic_model(5, 0.6); - let detector = make_detector(model, 10); - + let detector = make_detector(10); let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); - // Send 9 requests — below min_events threshold of 10 for _ in 0..9 { let action = detector.check(ip, "GET", "/wp-admin", "evil.com", "bot", 0, false, false, false); assert_eq!(action, DDoSAction::Allow, "should allow below min_events"); @@ -585,13 +113,10 @@ fn detector_allows_below_min_events() { #[test] fn detector_ipv4_and_ipv6_tracked_separately() { - let model = make_realistic_model(5, 0.6); - let detector = make_detector(model, 3); - + let detector = make_detector(3); let v4 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)); let v6 = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)); - // Send events to v4 only for _ in 0..5 { detector.check(v4, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true); } @@ -603,9 +128,7 @@ fn detector_ipv4_and_ipv6_tracked_separately() { #[test] fn detector_normal_browsing_pattern_is_allowed() { - let model = make_realistic_model(5, 0.6); - let detector = make_detector(model, 5); - + let detector = make_detector(5); let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50)); let paths = ["/", "/about", "/products", "/products/1", "/contact", "/blog", "/blog/post-1", "/docs", "/pricing", "/login", @@ -615,7 +138,6 @@ fn detector_normal_browsing_pattern_is_allowed() { for (i, path) in paths.iter().enumerate() { let method = if i % 5 == 0 { "POST" } else { "GET" }; let action = detector.check(ip, method, path, "mysite.com", ua, 0, true, true, true); - // After min_events, every check should still allow normal browsing assert_eq!( action, DDoSAction::Allow, @@ -626,16 +148,14 @@ fn detector_normal_browsing_pattern_is_allowed() { #[test] fn detector_handles_concurrent_ips() { - let model = make_realistic_model(5, 0.6); - let detector = make_detector(model, 5); - - // Simulate 50 distinct IPs each making a few normal requests + let detector = make_detector(5); for i in 0..50u8 { let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, i)); let paths = ["/", "/about", "/products", "/contact", "/blog", "/docs", "/api/status"]; for path in &paths { - let action = detector.check(ip, "GET", path, "example.com", "Chrome", 0, true, false, true); + // has_referer=true so referer_ratio stays above the tree threshold + let action = detector.check(ip, "GET", path, "example.com", "Chrome", 0, true, true, true); assert_eq!(action, DDoSAction::Allow, "IP 10.0.0.{i} blocked on {path}"); } @@ -644,134 +164,15 @@ fn detector_handles_concurrent_ips() { #[test] fn detector_ipv6_normal_traffic_is_allowed() { - let model = make_realistic_model(5, 0.6); - let detector = make_detector(model, 5); - + let detector = make_detector(5); let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42)); let paths = ["/", "/about", "/products", "/blog", "/contact", "/login", "/dashboard"]; for path in &paths { + // has_referer=true so referer_ratio stays above the tree threshold let action = detector.check(ip, "GET", path, "example.com", - "Mozilla/5.0", 0, true, false, true); + "Mozilla/5.0", 0, true, true, true); assert_eq!(action, DDoSAction::Allow, "IPv6 normal traffic blocked on {path}"); } } - -// =========================================================================== -// Model robustness: slight variations of normal traffic -// =========================================================================== - -#[test] -fn slightly_elevated_rate_still_allowed() { - let model = make_realistic_model(5, 0.6); - // 2x normal browsing rate — busy but not attacking - let profile: [f64; NUM_FEATURES] = - [1.0, 12.0, 1.0, 0.02, 150.0, 0.2, 1.2, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0]; - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -#[test] -fn slightly_elevated_errors_still_allowed() { - // 15% errors (e.g. some 404s from broken links) — normal for real sites - let model = make_realistic_model(5, 0.6); - let profile: [f64; NUM_FEATURES] = - [0.5, 10.0, 1.0, 0.15, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.3, 1.0, 0.0]; - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -#[test] -fn zero_traffic_features_allowed() { - // Edge case: all zeros (shouldn't happen in practice, but must not crash or block) - let model = make_realistic_model(5, 0.6); - assert_eq!(model.classify(&[0.0; NUM_FEATURES]), DDoSAction::Allow); -} - -#[test] -fn empty_model_always_allows() { - let model = TrainedModel::from_serialized(SerializedModel { - points: vec![], - labels: vec![], - norm_params: NormParams { - mins: [0.0; NUM_FEATURES], - maxs: [1.0; NUM_FEATURES], - }, - k: 5, - threshold: 0.6, - }); - // Must allow everything — no training data to compare against - assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow); - assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow); -} - -#[test] -fn all_normal_model_allows_everything() { - // A model trained only on normal data (no attack points) should never block. - // Use enough points (200) so fnntw can build the kD-tree. - let mut normal = Vec::new(); - for base in all_normal_profiles() { - for i in 0..20 { - let mut v = base; - for d in 0..NUM_FEATURES { - let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05); - v[d] *= jitter; - } - normal.push(v); - } - } - let model = make_model(&normal, &[], 5, 0.6); - assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow); - assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow); - // Even attack-like traffic is allowed since the model has no attack examples - assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow); -} - -// =========================================================================== -// Feature extraction tests -// =========================================================================== - -#[test] -fn method_entropy_zero_for_single_method() { - // All GET requests → method distribution is [1.0, 0, 0, ...] → entropy = 0 - let model = make_realistic_model(5, 0.6); - let profile = normal_health_check(); // all GET - assert_eq!(profile[5], 0.0); // method_entropy - assert_eq!(model.classify(&profile), DDoSAction::Allow); -} - -#[test] -fn method_entropy_positive_for_mixed_methods() { - let profile = normal_api_client(); // mix of GET/POST - assert!(profile[5] > 0.0, "method_entropy should be positive for mixed methods"); -} - -#[test] -fn path_repetition_is_one_for_single_path() { - let profile = normal_graphql_spa(); // single /graphql endpoint - assert_eq!(profile[7], 1.0); -} - -#[test] -fn path_repetition_is_low_for_diverse_paths() { - let profile = normal_search_crawler(); // many unique paths - assert!(profile[7] < 0.2); -} - -// =========================================================================== -// Load the real trained model and validate against known profiles -// =========================================================================== - -#[test] -fn real_model_file_roundtrip() { - let model_path = std::path::Path::new("ddos_model.bin"); - if !model_path.exists() { - // Skip if no model file present (CI environments) - eprintln!("skipping real_model_file_roundtrip: ddos_model.bin not found"); - return; - } - let model = TrainedModel::load(model_path, Some(3), Some(0.5)).unwrap(); - assert!(model.point_count() > 0, "model should have training points"); - // Smoke test: classifying shouldn't panic - let _ = model.classify(&normal_browser_browsing()); - let _ = model.classify(&attack_path_scan()); -} diff --git a/tests/e2e.rs b/tests/e2e.rs index d435f7e..fc7f3ef 100644 --- a/tests/e2e.rs +++ b/tests/e2e.rs @@ -1,3 +1,6 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + //! End-to-end tests: spin up a real SunbeamProxy over plain HTTP, route it //! to a tiny TCP echo-backend, and verify that the upstream receives the //! correct X-Forwarded-Proto header. @@ -108,7 +111,7 @@ fn start_proxy_once(backend_port: u16) { }]; let acme_routes: AcmeRoutes = Arc::new(RwLock::new(HashMap::new())); let compiled_rewrites = SunbeamProxy::compile_rewrites(&routes); - let proxy = SunbeamProxy { routes, acme_routes, ddos_detector: None, scanner_detector: None, bot_allowlist: None, rate_limiter: None, compiled_rewrites, http_client: reqwest::Client::new(), pipeline_bypass_cidrs: vec![], cluster: None }; + let proxy = SunbeamProxy { routes, acme_routes, ddos_detector: None, scanner_detector: None, bot_allowlist: None, rate_limiter: None, compiled_rewrites, http_client: reqwest::Client::new(), pipeline_bypass_cidrs: vec![], cluster: None, ddos_observe_only: false, scanner_observe_only: false }; let opt = Opt { upgrade: false, diff --git a/tests/scanner_test.rs b/tests/scanner_test.rs index 086523e..2f2c6cc 100644 --- a/tests/scanner_test.rs +++ b/tests/scanner_test.rs @@ -1,9 +1,14 @@ +// Copyright Sunbeam Studios 2026 +// SPDX-License-Identifier: Apache-2.0 + +//! Scanner detection tests. +//! +//! The detector uses ensemble inference (decision tree + MLP) with compiled-in +//! weights. These tests exercise the allowlist and ensemble classification paths. + use sunbeam_proxy::config::RouteConfig; use sunbeam_proxy::scanner::detector::ScannerDetector; -use sunbeam_proxy::scanner::features::{ - ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS, -}; -use sunbeam_proxy::scanner::model::{ScannerAction, ScannerModel}; +use sunbeam_proxy::scanner::model::ScannerAction; fn test_routes() -> Vec { vec![ @@ -36,44 +41,8 @@ fn test_routes() -> Vec { ] } -fn scanner_weights() -> [f64; NUM_SCANNER_WEIGHTS] { - let mut w = [0.0; NUM_SCANNER_WEIGHTS]; - w[0] = 2.0; // suspicious_path_score - w[2] = 2.0; // has_suspicious_extension - w[3] = -2.0; // has_cookies (negative = good) - w[4] = -1.0; // has_referer - w[5] = -1.0; // has_accept_language - w[6] = -0.5; // accept_quality - w[7] = -1.0; // ua_category (browser = good) - w[9] = -1.5; // host_is_configured - w[11] = 2.0; // path_has_traversal - w[12] = 1.5; // interaction: suspicious_path AND no_cookies - w[13] = 1.0; // interaction: unknown_host AND no_accept_lang - w[14] = 0.5; // bias - w -} - fn make_detector() -> ScannerDetector { - let model = ScannerModel { - weights: scanner_weights(), - threshold: 0.5, - norm_params: ScannerNormParams { - mins: [0.0; NUM_SCANNER_FEATURES], - maxs: [1.0; NUM_SCANNER_FEATURES], - }, - fragments: vec![ - ".env".into(), - "wp-admin".into(), - "wp-login".into(), - "phpinfo".into(), - "phpmyadmin".into(), - ".git".into(), - "cgi-bin".into(), - ".htaccess".into(), - ".htpasswd".into(), - ], - }; - ScannerDetector::new(&model, &test_routes()) + ScannerDetector::new(&test_routes()) } #[test] @@ -113,7 +82,6 @@ fn env_probe_from_unknown_host_blocked() { "*/*", "curl/7.0", 0, ); assert_eq!(v.action, ScannerAction::Block); - assert_eq!(v.reason, "model"); } #[test] @@ -125,7 +93,6 @@ fn wordpress_scan_blocked() { "*/*", "", 0, ); assert_eq!(v.action, ScannerAction::Block); - assert_eq!(v.reason, "model"); } #[test] @@ -137,7 +104,6 @@ fn path_traversal_blocked() { "*/*", "python-requests/2.28", 0, ); assert_eq!(v.action, ScannerAction::Block); - assert_eq!(v.reason, "model"); } #[test] @@ -149,7 +115,6 @@ fn legitimate_php_path_allowed() { "text/html", "Mozilla/5.0 Chrome/120", 0, ); assert_eq!(v.action, ScannerAction::Allow); - // hits allowlist:host+cookies } #[test] @@ -165,29 +130,3 @@ fn browser_on_known_host_without_cookies_allowed() { assert_eq!(v.action, ScannerAction::Allow); assert_eq!(v.reason, "allowlist:host+browser"); } - -#[test] -fn model_serialization_roundtrip() { - let model = ScannerModel { - weights: scanner_weights(), - threshold: 0.5, - norm_params: ScannerNormParams { - mins: [0.0; NUM_SCANNER_FEATURES], - maxs: [1.0; NUM_SCANNER_FEATURES], - }, - fragments: vec![".env".into(), "wp-admin".into()], - }; - - let dir = std::env::temp_dir().join("scanner_e2e_test"); - std::fs::create_dir_all(&dir).unwrap(); - let path = dir.join("test_scanner_model.bin"); - - model.save(&path).unwrap(); - let loaded = ScannerModel::load(&path).unwrap(); - - assert_eq!(loaded.weights, model.weights); - assert_eq!(loaded.threshold, model.threshold); - assert_eq!(loaded.fragments, model.fragments); - - let _ = std::fs::remove_dir_all(&dir); -}