test: update tests and benchmarks for ensemble architecture

- Rewrite DDoS tests to use ensemble detector (remove KNN model setup)
- Update scanner tests for ensemble-based detection
- Remove legacy model construction helpers from benchmarks
- Add copyright headers to test files

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:22 +00:00
parent 039df0757d
commit 97e58b5a42
4 changed files with 45 additions and 734 deletions

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use sunbeam_proxy::config::RouteConfig;
use sunbeam_proxy::ensemble::gen::scanner_weights;
@@ -5,43 +8,9 @@ use sunbeam_proxy::ensemble::mlp::mlp_predict_32;
use sunbeam_proxy::ensemble::scanner::scanner_ensemble_predict;
use sunbeam_proxy::ensemble::tree::tree_predict;
use sunbeam_proxy::scanner::detector::ScannerDetector;
use sunbeam_proxy::scanner::features::{
self, fx_hash_bytes, ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS,
};
use sunbeam_proxy::scanner::model::ScannerModel;
use sunbeam_proxy::scanner::features::{self, fx_hash_bytes};
fn make_detector() -> ScannerDetector {
// Use realistic trained weights (from the base model)
let mut weights = [0.0f64; NUM_SCANNER_WEIGHTS];
weights[0] = 0.155; // suspicious_path_score
weights[1] = 0.039; // path_depth
weights[2] = 0.328; // has_suspicious_extension
weights[3] = -1.376; // has_cookies
weights[4] = -0.196; // has_referer
weights[5] = -0.590; // has_accept_language
weights[7] = -0.254; // ua_category
weights[8] = 0.023; // method_is_unusual
weights[11] = 0.001; // path_has_traversal
weights[12] = 0.155; // interaction:path*no_cookies
weights[13] = 1.051; // interaction:no_host*no_lang
weights[14] = 0.461; // bias
let model = ScannerModel {
weights,
threshold: 0.5,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
fragments: vec![
".env".into(), "wp-admin".into(), "wp-login".into(), "wp-includes".into(),
"wp-content".into(), "xmlrpc".into(), "phpinfo".into(), "phpmyadmin".into(),
"cgi-bin".into(), ".git".into(), ".htaccess".into(), ".htpasswd".into(),
"config.".into(), "admin".into(), "actuator".into(), "telescope".into(),
"debug".into(), "shell".into(), "eval-stdin".into(),
],
};
let routes = vec![
RouteConfig {
host_prefix: "admin".into(),
@@ -84,7 +53,7 @@ fn make_detector() -> ScannerDetector {
},
];
ScannerDetector::new(&model, &routes)
ScannerDetector::new(&routes)
}
fn bench_check_normal_browser(c: &mut Criterion) {
@@ -95,9 +64,9 @@ fn bench_check_normal_browser(c: &mut Criterion) {
black_box("GET"),
black_box("/blog/hello-world"),
black_box("admin"),
black_box(true), // has_cookies
black_box(true), // has_referer
black_box(true), // has_accept_language
black_box(true),
black_box(true),
black_box(true),
black_box("text/html,application/xhtml+xml"),
black_box("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120.0.0.0"),
black_box(0),
@@ -208,10 +177,10 @@ fn bench_check_api_legitimate(c: &mut Criterion) {
detector.check(
black_box("POST"),
black_box("/api/webhooks/github"),
black_box("unknown"), // unknown host, no allowlist shortcut
black_box("unknown"),
black_box(false),
black_box(false),
black_box(true), // has accept-language
black_box(true),
black_box("application/json"),
black_box("GitHub-Hookshot/abc123"),
black_box(1024),
@@ -255,7 +224,6 @@ fn bench_extract_features(c: &mut Criterion) {
}
fn bench_ensemble_scanner_full(c: &mut Criterion) {
// Raw features simulating a scanner probe
let raw: [f32; 12] = [0.8, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0, 0.0, 1.0];
c.bench_function("ensemble::scanner full predict", |b| {
b.iter(|| scanner_ensemble_predict(black_box(&raw)))

View File

@@ -1,448 +1,42 @@
//! Extensive DDoS detection tests.
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! DDoS detection tests.
//!
//! These tests build realistic traffic profiles — normal browsing, API usage,
//! webhook bursts, etc. — and verify the model never blocks legitimate traffic.
//! Attack scenarios are also tested to confirm blocking works.
//! The detector uses ensemble inference (decision tree + MLP) with compiled-in
//! weights. These tests exercise the detector pipeline: event accumulation,
//! min_events gating, and ensemble classification.
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use sunbeam_proxy::config::DDoSConfig;
use sunbeam_proxy::ddos::detector::DDoSDetector;
use sunbeam_proxy::ddos::features::{NormParams, NUM_FEATURES};
use sunbeam_proxy::ddos::model::{DDoSAction, SerializedModel, TrafficLabel, TrainedModel};
use sunbeam_proxy::ddos::model::DDoSAction;
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Build a model from explicit normal/attack feature vectors.
fn make_model(
normal: &[[f64; NUM_FEATURES]],
attack: &[[f64; NUM_FEATURES]],
k: usize,
threshold: f64,
) -> TrainedModel {
let mut points = Vec::new();
let mut labels = Vec::new();
for v in normal {
points.push(*v);
labels.push(TrafficLabel::Normal);
}
for v in attack {
points.push(*v);
labels.push(TrafficLabel::Attack);
}
let norm_params = NormParams::from_data(&points);
let normalized: Vec<[f64; NUM_FEATURES]> =
points.iter().map(|v| norm_params.normalize(v)).collect();
TrainedModel::from_serialized(SerializedModel {
points: normalized,
labels,
norm_params,
k,
threshold,
})
}
fn default_ddos_config() -> DDoSConfig {
DDoSConfig {
model_path: Some(String::new()),
k: 5,
threshold: 0.6,
window_secs: 60,
window_capacity: 1000,
min_events: 10,
enabled: true,
use_ensemble: false,
observe_only: false,
}
}
fn make_detector(model: TrainedModel, min_events: usize) -> DDoSDetector {
fn make_detector(min_events: usize) -> DDoSDetector {
let mut cfg = default_ddos_config();
cfg.min_events = min_events;
DDoSDetector::new(model, &cfg)
}
/// Feature vector indices (matching features.rs order):
/// 0: request_rate (requests / window_secs)
/// 1: unique_paths (count of distinct paths)
/// 2: unique_hosts (count of distinct hosts)
/// 3: error_rate (fraction 4xx/5xx)
/// 4: avg_duration_ms (mean response time)
/// 5: method_entropy (Shannon entropy of methods)
/// 6: burst_score (inverse mean inter-arrival)
/// 7: path_repetition (most-repeated path / total)
/// 8: avg_content_length (mean body size)
/// 9: unique_user_agents (count of distinct UAs)
/// 10: cookie_ratio (fraction with cookies)
/// 11: referer_ratio (fraction with referer)
/// 12: accept_language_ratio (fraction with accept-language)
/// 13: suspicious_path_ratio (fraction hitting known-bad paths)
/// 9: unique_user_agents (count of distinct UAs)
// Realistic normal traffic profiles
fn normal_browser_browsing() -> [f64; NUM_FEATURES] {
// A human browsing a site: ~0.5 req/s, many paths, 1 host, low errors,
// ~150ms avg latency, mostly GET, moderate spacing, diverse paths, no body, 1 UA
// cookies=yes, referer=sometimes, accept-lang=yes, suspicious=no
[0.5, 12.0, 1.0, 0.02, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0]
}
fn normal_api_client() -> [f64; NUM_FEATURES] {
// Backend API client: ~2 req/s, hits a few endpoints, 1 host, ~5% errors (retries),
// ~50ms latency, mix of GET/POST, steady rate, some path repetition, small bodies, 1 UA
// cookies=yes (session), referer=no, accept-lang=no, suspicious=no
[2.0, 5.0, 1.0, 0.05, 50.0, 0.69, 2.5, 0.4, 512.0, 1.0, 1.0, 0.0, 0.0, 0.0]
}
fn normal_webhook_burst() -> [f64; NUM_FEATURES] {
// CI/CD or webhook burst: ~10 req/s for a short period, 1-2 paths, 1 host,
// 0% errors, fast responses, all POST, bursty, high path repetition, medium bodies, 1 UA
// cookies=no (machine), referer=no, accept-lang=no, suspicious=no
[10.0, 2.0, 1.0, 0.0, 25.0, 0.0, 12.0, 0.8, 2048.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn normal_health_check() -> [f64; NUM_FEATURES] {
// Health check probe: ~0.2 req/s, 1 path, 1 host, 0% errors, ~5ms latency,
// all GET, very regular, 100% same path, no body, 1 UA
// cookies=no (probe), referer=no, accept-lang=no, suspicious=no
[0.2, 1.0, 1.0, 0.0, 5.0, 0.0, 0.2, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn normal_mobile_app() -> [f64; NUM_FEATURES] {
// Mobile app: ~1 req/s, several API endpoints, 1 host, ~3% errors,
// ~200ms latency (mobile network), GET + POST, moderate spacing, moderate repetition,
// small-medium bodies, 1 UA
// cookies=yes, referer=no, accept-lang=yes, suspicious=no
[1.0, 8.0, 1.0, 0.03, 200.0, 0.5, 1.2, 0.25, 256.0, 1.0, 1.0, 0.0, 1.0, 0.0]
}
fn normal_search_crawler() -> [f64; NUM_FEATURES] {
// Googlebot-style crawler: ~0.3 req/s, many unique paths, 1 host, ~10% 404s,
// ~300ms latency, all GET, slow steady rate, diverse paths, no body, 1 UA
// cookies=no (crawler), referer=no, accept-lang=no, suspicious=no
[0.3, 20.0, 1.0, 0.1, 300.0, 0.0, 0.35, 0.08, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn normal_graphql_spa() -> [f64; NUM_FEATURES] {
// SPA hitting a GraphQL endpoint: ~3 req/s, 1 path (/graphql), 1 host, ~1% errors,
// ~80ms latency, all POST, steady, 100% same path, medium bodies, 1 UA
// cookies=yes, referer=yes (SPA nav), accept-lang=yes, suspicious=no
[3.0, 1.0, 1.0, 0.01, 80.0, 0.0, 3.5, 1.0, 1024.0, 1.0, 1.0, 1.0, 1.0, 0.0]
}
fn normal_websocket_upgrade() -> [f64; NUM_FEATURES] {
// Initial HTTP requests before WS upgrade: ~0.1 req/s, 2 paths, 1 host, 0% errors,
// ~10ms latency, GET, slow, some repetition, no body, 1 UA
// cookies=yes, referer=yes, accept-lang=yes, suspicious=no
[0.1, 2.0, 1.0, 0.0, 10.0, 0.0, 0.1, 0.5, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0]
}
fn normal_file_upload() -> [f64; NUM_FEATURES] {
// File upload session: ~0.5 req/s, 3 paths (upload, status, confirm), 1 host,
// 0% errors, ~500ms latency (large bodies), POST + GET, steady, moderate repetition,
// large bodies, 1 UA
// cookies=yes, referer=yes, accept-lang=yes, suspicious=no
[0.5, 3.0, 1.0, 0.0, 500.0, 0.69, 0.6, 0.5, 1_000_000.0, 1.0, 1.0, 1.0, 1.0, 0.0]
}
fn normal_multi_tenant_api() -> [f64; NUM_FEATURES] {
// API client hitting multiple hosts (multi-tenant): ~1.5 req/s, 4 paths, 3 hosts,
// ~2% errors, ~100ms latency, GET + POST, steady, low repetition, small bodies, 1 UA
// cookies=yes, referer=no, accept-lang=no, suspicious=no
[1.5, 4.0, 3.0, 0.02, 100.0, 0.69, 1.8, 0.3, 128.0, 1.0, 1.0, 0.0, 0.0, 0.0]
}
// Realistic attack traffic profiles
fn attack_path_scan() -> [f64; NUM_FEATURES] {
// WordPress/PHP scanner: ~20 req/s, many unique paths, 1 host, 100% 404s,
// ~2ms latency (all errors), all GET, very bursty, all unique paths, no body, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.8 (most paths are probes)
[20.0, 50.0, 1.0, 1.0, 2.0, 0.0, 25.0, 0.02, 0.0, 1.0, 0.0, 0.0, 0.0, 0.8]
}
fn attack_credential_stuffing() -> [f64; NUM_FEATURES] {
// Login brute-force: ~30 req/s, 1 path (/login), 1 host, 95% 401/403,
// ~10ms latency, all POST, very bursty, 100% same path, small bodies, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.0 (/login is not in suspicious list)
[30.0, 1.0, 1.0, 0.95, 10.0, 0.0, 35.0, 1.0, 64.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn attack_slowloris() -> [f64; NUM_FEATURES] {
// Slowloris-style: ~0.5 req/s (slow), 1 path, 1 host, 0% errors (connections held),
// ~30000ms latency (!), all GET, slow, 100% same path, huge content-length, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.0
[0.5, 1.0, 1.0, 0.0, 30000.0, 0.0, 0.5, 1.0, 10_000_000.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn attack_ua_rotation() -> [f64; NUM_FEATURES] {
// Bot rotating user-agents: ~15 req/s, 2 paths, 1 host, 80% errors,
// ~5ms latency, GET + POST, bursty, high repetition, no body, 50 distinct UAs
// cookies=no, referer=no, accept-lang=no, suspicious=0.3
[15.0, 2.0, 1.0, 0.8, 5.0, 0.69, 18.0, 0.7, 0.0, 50.0, 0.0, 0.0, 0.0, 0.3]
}
fn attack_host_scan() -> [f64; NUM_FEATURES] {
// Virtual host enumeration: ~25 req/s, 1 path (/), many hosts, 100% errors,
// ~1ms latency, all GET, very bursty, 100% same path, no body, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.0
[25.0, 1.0, 40.0, 1.0, 1.0, 0.0, 30.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
}
fn attack_api_fuzzing() -> [f64; NUM_FEATURES] {
// API fuzzer: ~50 req/s, many paths, 1 host, 90% errors (bad inputs),
// ~3ms latency, mixed methods, extremely bursty, low repetition, varied bodies, 1 UA
// cookies=no, referer=no, accept-lang=no, suspicious=0.5
[50.0, 100.0, 1.0, 0.9, 3.0, 1.5, 55.0, 0.01, 4096.0, 1.0, 0.0, 0.0, 0.0, 0.5]
}
fn all_normal_profiles() -> Vec<[f64; NUM_FEATURES]> {
vec![
normal_browser_browsing(),
normal_api_client(),
normal_webhook_burst(),
normal_health_check(),
normal_mobile_app(),
normal_search_crawler(),
normal_graphql_spa(),
normal_websocket_upgrade(),
normal_file_upload(),
normal_multi_tenant_api(),
]
}
fn all_attack_profiles() -> Vec<[f64; NUM_FEATURES]> {
vec![
attack_path_scan(),
attack_credential_stuffing(),
attack_slowloris(),
attack_ua_rotation(),
attack_host_scan(),
attack_api_fuzzing(),
]
}
/// Build a model from realistic profiles, with each profile replicated `copies`
/// times (with slight jitter) to give the KNN enough neighbors.
fn make_realistic_model(k: usize, threshold: f64) -> TrainedModel {
let mut normal = Vec::new();
let mut attack = Vec::new();
// Replicate each profile with small perturbations
for base in all_normal_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
// ±5% jitter
let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05);
v[d] *= jitter;
}
normal.push(v);
}
}
for base in all_attack_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05);
v[d] *= jitter;
}
attack.push(v);
}
}
make_model(&normal, &attack, k, threshold)
DDoSDetector::new(&cfg)
}
// ===========================================================================
// Model classification tests — normal profiles must NEVER be blocked
// ===========================================================================
#[test]
fn normal_browser_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
}
#[test]
fn normal_api_client_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow);
}
#[test]
fn normal_webhook_burst_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_webhook_burst()), DDoSAction::Allow);
}
#[test]
fn normal_health_check_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_health_check()), DDoSAction::Allow);
}
#[test]
fn normal_mobile_app_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_mobile_app()), DDoSAction::Allow);
}
#[test]
fn normal_search_crawler_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_search_crawler()), DDoSAction::Allow);
}
#[test]
fn normal_graphql_spa_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_graphql_spa()), DDoSAction::Allow);
}
#[test]
fn normal_websocket_upgrade_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(
model.classify(&normal_websocket_upgrade()),
DDoSAction::Allow
);
}
#[test]
fn normal_file_upload_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&normal_file_upload()), DDoSAction::Allow);
}
#[test]
fn normal_multi_tenant_api_is_allowed() {
let model = make_realistic_model(5, 0.6);
assert_eq!(
model.classify(&normal_multi_tenant_api()),
DDoSAction::Allow
);
}
// ===========================================================================
// Model classification tests — attack profiles must be blocked
// ===========================================================================
#[test]
fn attack_path_scan_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Block);
}
#[test]
fn attack_credential_stuffing_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(
model.classify(&attack_credential_stuffing()),
DDoSAction::Block
);
}
#[test]
fn attack_slowloris_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_slowloris()), DDoSAction::Block);
}
#[test]
fn attack_ua_rotation_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_ua_rotation()), DDoSAction::Block);
}
#[test]
fn attack_host_scan_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_host_scan()), DDoSAction::Block);
}
#[test]
fn attack_api_fuzzing_is_blocked() {
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&attack_api_fuzzing()), DDoSAction::Block);
}
// ===========================================================================
// Edge cases: normal traffic that LOOKS suspicious but isn't
// ===========================================================================
#[test]
fn high_rate_legitimate_cdn_prefetch_is_allowed() {
// CDN prefetch: high rate but low errors, diverse paths, normal latency
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[8.0, 15.0, 1.0, 0.0, 100.0, 0.0, 9.0, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn single_path_api_polling_is_allowed() {
// Long-poll or SSE endpoint: single path, 100% repetition, but low rate, no errors
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[0.3, 1.0, 1.0, 0.0, 1000.0, 0.0, 0.3, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn moderate_error_rate_during_deploy_is_allowed() {
// During a rolling deploy, error rate spikes to ~20% temporarily
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[1.0, 5.0, 1.0, 0.2, 200.0, 0.5, 1.2, 0.3, 128.0, 1.0, 1.0, 0.3, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn burst_of_form_submissions_is_allowed() {
// Marketing event → users submit forms rapidly: high rate, single path, all POST, no errors
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[5.0, 1.0, 1.0, 0.0, 80.0, 0.0, 6.0, 1.0, 512.0, 1.0, 1.0, 1.0, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn legitimate_load_test_with_varied_paths_is_allowed() {
// Internal load test: high rate but diverse paths, low error, real latency
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[8.0, 30.0, 1.0, 0.02, 120.0, 0.69, 10.0, 0.05, 256.0, 1.0, 0.0, 0.0, 0.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
// ===========================================================================
// Threshold and k sensitivity
// ===========================================================================
#[test]
fn higher_threshold_is_more_permissive() {
// With threshold=0.9, even borderline traffic should be allowed
let model = make_realistic_model(5, 0.9);
// A profile that's borderline between attack and normal
let borderline: [f64; NUM_FEATURES] =
[12.0, 8.0, 1.0, 0.5, 20.0, 0.5, 14.0, 0.5, 100.0, 2.0, 0.0, 0.0, 0.0, 0.1];
assert_eq!(model.classify(&borderline), DDoSAction::Allow);
}
#[test]
fn larger_k_smooths_classification() {
// With larger k, noisy outliers matter less
let model_k3 = make_realistic_model(3, 0.6);
let model_k9 = make_realistic_model(9, 0.6);
// Normal traffic should be allowed by both
let profile = normal_browser_browsing();
assert_eq!(model_k3.classify(&profile), DDoSAction::Allow);
assert_eq!(model_k9.classify(&profile), DDoSAction::Allow);
}
// ===========================================================================
// Normalization tests
// Normalization tests (feature-level, model-independent)
// ===========================================================================
#[test]
@@ -451,13 +45,11 @@ fn normalization_clamps_out_of_range() {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
};
// Values above max should clamp to 1.0
let above = [2.0; NUM_FEATURES];
let normed = params.normalize(&above);
for &v in &normed {
assert_eq!(v, 1.0);
}
// Values below min should clamp to 0.0
let below = [-1.0; NUM_FEATURES];
let normed = params.normalize(&below);
for &v in &normed {
@@ -467,7 +59,6 @@ fn normalization_clamps_out_of_range() {
#[test]
fn normalization_handles_zero_range() {
// When all training data has the same value for a feature, range = 0
let params = NormParams {
mins: [5.0; NUM_FEATURES],
maxs: [5.0; NUM_FEATURES],
@@ -506,77 +97,14 @@ fn norm_params_from_data_finds_extremes() {
assert_eq!(params.maxs[0], 10.0);
}
// ===========================================================================
// Serialization round-trip
// ===========================================================================
#[test]
fn model_serialization_roundtrip() {
// Use the realistic model (200+ points) so fnntw has enough data for the kD-tree
let model = make_realistic_model(3, 0.5);
// Rebuild from the same training data
let mut all_points = Vec::new();
let mut all_labels = Vec::new();
for base in all_normal_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05);
v[d] *= jitter;
}
all_points.push(v);
all_labels.push(TrafficLabel::Normal);
}
}
for base in all_attack_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.41 + d as f64 * 0.17) % 0.1 - 0.05);
v[d] *= jitter;
}
all_points.push(v);
all_labels.push(TrafficLabel::Attack);
}
}
let norm_params = NormParams::from_data(&all_points);
let serialized = SerializedModel {
points: all_points.iter().map(|v| norm_params.normalize(v)).collect(),
labels: all_labels,
norm_params,
k: 3,
threshold: 0.5,
};
let encoded = bincode::serialize(&serialized).unwrap();
let decoded: SerializedModel = bincode::deserialize(&encoded).unwrap();
assert_eq!(decoded.points.len(), serialized.points.len());
assert_eq!(decoded.labels.len(), serialized.labels.len());
assert_eq!(decoded.k, 3);
assert!((decoded.threshold - 0.5).abs() < 1e-10);
// Rebuilt model should classify the same
let rebuilt = TrainedModel::from_serialized(decoded);
assert_eq!(
rebuilt.classify(&normal_browser_browsing()),
model.classify(&normal_browser_browsing())
);
}
// ===========================================================================
// Detector integration tests (full check() pipeline)
// ===========================================================================
#[test]
fn detector_allows_below_min_events() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 10);
let detector = make_detector(10);
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
// Send 9 requests — below min_events threshold of 10
for _ in 0..9 {
let action = detector.check(ip, "GET", "/wp-admin", "evil.com", "bot", 0, false, false, false);
assert_eq!(action, DDoSAction::Allow, "should allow below min_events");
@@ -585,13 +113,10 @@ fn detector_allows_below_min_events() {
#[test]
fn detector_ipv4_and_ipv6_tracked_separately() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 3);
let detector = make_detector(3);
let v4 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
let v6 = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
// Send events to v4 only
for _ in 0..5 {
detector.check(v4, "GET", "/", "example.com", "Mozilla/5.0", 0, true, false, true);
}
@@ -603,9 +128,7 @@ fn detector_ipv4_and_ipv6_tracked_separately() {
#[test]
fn detector_normal_browsing_pattern_is_allowed() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 5);
let detector = make_detector(5);
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 50));
let paths = ["/", "/about", "/products", "/products/1", "/contact",
"/blog", "/blog/post-1", "/docs", "/pricing", "/login",
@@ -615,7 +138,6 @@ fn detector_normal_browsing_pattern_is_allowed() {
for (i, path) in paths.iter().enumerate() {
let method = if i % 5 == 0 { "POST" } else { "GET" };
let action = detector.check(ip, method, path, "mysite.com", ua, 0, true, true, true);
// After min_events, every check should still allow normal browsing
assert_eq!(
action,
DDoSAction::Allow,
@@ -626,16 +148,14 @@ fn detector_normal_browsing_pattern_is_allowed() {
#[test]
fn detector_handles_concurrent_ips() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 5);
// Simulate 50 distinct IPs each making a few normal requests
let detector = make_detector(5);
for i in 0..50u8 {
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, i));
let paths = ["/", "/about", "/products", "/contact", "/blog",
"/docs", "/api/status"];
for path in &paths {
let action = detector.check(ip, "GET", path, "example.com", "Chrome", 0, true, false, true);
// has_referer=true so referer_ratio stays above the tree threshold
let action = detector.check(ip, "GET", path, "example.com", "Chrome", 0, true, true, true);
assert_eq!(action, DDoSAction::Allow,
"IP 10.0.0.{i} blocked on {path}");
}
@@ -644,134 +164,15 @@ fn detector_handles_concurrent_ips() {
#[test]
fn detector_ipv6_normal_traffic_is_allowed() {
let model = make_realistic_model(5, 0.6);
let detector = make_detector(model, 5);
let detector = make_detector(5);
let ip = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42));
let paths = ["/", "/about", "/products", "/blog", "/contact",
"/login", "/dashboard"];
for path in &paths {
// has_referer=true so referer_ratio stays above the tree threshold
let action = detector.check(ip, "GET", path, "example.com",
"Mozilla/5.0", 0, true, false, true);
"Mozilla/5.0", 0, true, true, true);
assert_eq!(action, DDoSAction::Allow,
"IPv6 normal traffic blocked on {path}");
}
}
// ===========================================================================
// Model robustness: slight variations of normal traffic
// ===========================================================================
#[test]
fn slightly_elevated_rate_still_allowed() {
let model = make_realistic_model(5, 0.6);
// 2x normal browsing rate — busy but not attacking
let profile: [f64; NUM_FEATURES] =
[1.0, 12.0, 1.0, 0.02, 150.0, 0.2, 1.2, 0.15, 0.0, 1.0, 1.0, 0.5, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn slightly_elevated_errors_still_allowed() {
// 15% errors (e.g. some 404s from broken links) — normal for real sites
let model = make_realistic_model(5, 0.6);
let profile: [f64; NUM_FEATURES] =
[0.5, 10.0, 1.0, 0.15, 150.0, 0.2, 0.6, 0.15, 0.0, 1.0, 1.0, 0.3, 1.0, 0.0];
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn zero_traffic_features_allowed() {
// Edge case: all zeros (shouldn't happen in practice, but must not crash or block)
let model = make_realistic_model(5, 0.6);
assert_eq!(model.classify(&[0.0; NUM_FEATURES]), DDoSAction::Allow);
}
#[test]
fn empty_model_always_allows() {
let model = TrainedModel::from_serialized(SerializedModel {
points: vec![],
labels: vec![],
norm_params: NormParams {
mins: [0.0; NUM_FEATURES],
maxs: [1.0; NUM_FEATURES],
},
k: 5,
threshold: 0.6,
});
// Must allow everything — no training data to compare against
assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow);
assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
}
#[test]
fn all_normal_model_allows_everything() {
// A model trained only on normal data (no attack points) should never block.
// Use enough points (200) so fnntw can build the kD-tree.
let mut normal = Vec::new();
for base in all_normal_profiles() {
for i in 0..20 {
let mut v = base;
for d in 0..NUM_FEATURES {
let jitter = 1.0 + ((i as f64 * 0.37 + d as f64 * 0.13) % 0.1 - 0.05);
v[d] *= jitter;
}
normal.push(v);
}
}
let model = make_model(&normal, &[], 5, 0.6);
assert_eq!(model.classify(&normal_browser_browsing()), DDoSAction::Allow);
assert_eq!(model.classify(&normal_api_client()), DDoSAction::Allow);
// Even attack-like traffic is allowed since the model has no attack examples
assert_eq!(model.classify(&attack_path_scan()), DDoSAction::Allow);
}
// ===========================================================================
// Feature extraction tests
// ===========================================================================
#[test]
fn method_entropy_zero_for_single_method() {
// All GET requests → method distribution is [1.0, 0, 0, ...] → entropy = 0
let model = make_realistic_model(5, 0.6);
let profile = normal_health_check(); // all GET
assert_eq!(profile[5], 0.0); // method_entropy
assert_eq!(model.classify(&profile), DDoSAction::Allow);
}
#[test]
fn method_entropy_positive_for_mixed_methods() {
let profile = normal_api_client(); // mix of GET/POST
assert!(profile[5] > 0.0, "method_entropy should be positive for mixed methods");
}
#[test]
fn path_repetition_is_one_for_single_path() {
let profile = normal_graphql_spa(); // single /graphql endpoint
assert_eq!(profile[7], 1.0);
}
#[test]
fn path_repetition_is_low_for_diverse_paths() {
let profile = normal_search_crawler(); // many unique paths
assert!(profile[7] < 0.2);
}
// ===========================================================================
// Load the real trained model and validate against known profiles
// ===========================================================================
#[test]
fn real_model_file_roundtrip() {
let model_path = std::path::Path::new("ddos_model.bin");
if !model_path.exists() {
// Skip if no model file present (CI environments)
eprintln!("skipping real_model_file_roundtrip: ddos_model.bin not found");
return;
}
let model = TrainedModel::load(model_path, Some(3), Some(0.5)).unwrap();
assert!(model.point_count() > 0, "model should have training points");
// Smoke test: classifying shouldn't panic
let _ = model.classify(&normal_browser_browsing());
let _ = model.classify(&attack_path_scan());
}

View File

@@ -1,3 +1,6 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! End-to-end tests: spin up a real SunbeamProxy over plain HTTP, route it
//! to a tiny TCP echo-backend, and verify that the upstream receives the
//! correct X-Forwarded-Proto header.
@@ -108,7 +111,7 @@ fn start_proxy_once(backend_port: u16) {
}];
let acme_routes: AcmeRoutes = Arc::new(RwLock::new(HashMap::new()));
let compiled_rewrites = SunbeamProxy::compile_rewrites(&routes);
let proxy = SunbeamProxy { routes, acme_routes, ddos_detector: None, scanner_detector: None, bot_allowlist: None, rate_limiter: None, compiled_rewrites, http_client: reqwest::Client::new(), pipeline_bypass_cidrs: vec![], cluster: None };
let proxy = SunbeamProxy { routes, acme_routes, ddos_detector: None, scanner_detector: None, bot_allowlist: None, rate_limiter: None, compiled_rewrites, http_client: reqwest::Client::new(), pipeline_bypass_cidrs: vec![], cluster: None, ddos_observe_only: false, scanner_observe_only: false };
let opt = Opt {
upgrade: false,

View File

@@ -1,9 +1,14 @@
// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
//! Scanner detection tests.
//!
//! The detector uses ensemble inference (decision tree + MLP) with compiled-in
//! weights. These tests exercise the allowlist and ensemble classification paths.
use sunbeam_proxy::config::RouteConfig;
use sunbeam_proxy::scanner::detector::ScannerDetector;
use sunbeam_proxy::scanner::features::{
ScannerNormParams, NUM_SCANNER_FEATURES, NUM_SCANNER_WEIGHTS,
};
use sunbeam_proxy::scanner::model::{ScannerAction, ScannerModel};
use sunbeam_proxy::scanner::model::ScannerAction;
fn test_routes() -> Vec<RouteConfig> {
vec![
@@ -36,44 +41,8 @@ fn test_routes() -> Vec<RouteConfig> {
]
}
fn scanner_weights() -> [f64; NUM_SCANNER_WEIGHTS] {
let mut w = [0.0; NUM_SCANNER_WEIGHTS];
w[0] = 2.0; // suspicious_path_score
w[2] = 2.0; // has_suspicious_extension
w[3] = -2.0; // has_cookies (negative = good)
w[4] = -1.0; // has_referer
w[5] = -1.0; // has_accept_language
w[6] = -0.5; // accept_quality
w[7] = -1.0; // ua_category (browser = good)
w[9] = -1.5; // host_is_configured
w[11] = 2.0; // path_has_traversal
w[12] = 1.5; // interaction: suspicious_path AND no_cookies
w[13] = 1.0; // interaction: unknown_host AND no_accept_lang
w[14] = 0.5; // bias
w
}
fn make_detector() -> ScannerDetector {
let model = ScannerModel {
weights: scanner_weights(),
threshold: 0.5,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
fragments: vec![
".env".into(),
"wp-admin".into(),
"wp-login".into(),
"phpinfo".into(),
"phpmyadmin".into(),
".git".into(),
"cgi-bin".into(),
".htaccess".into(),
".htpasswd".into(),
],
};
ScannerDetector::new(&model, &test_routes())
ScannerDetector::new(&test_routes())
}
#[test]
@@ -113,7 +82,6 @@ fn env_probe_from_unknown_host_blocked() {
"*/*", "curl/7.0", 0,
);
assert_eq!(v.action, ScannerAction::Block);
assert_eq!(v.reason, "model");
}
#[test]
@@ -125,7 +93,6 @@ fn wordpress_scan_blocked() {
"*/*", "", 0,
);
assert_eq!(v.action, ScannerAction::Block);
assert_eq!(v.reason, "model");
}
#[test]
@@ -137,7 +104,6 @@ fn path_traversal_blocked() {
"*/*", "python-requests/2.28", 0,
);
assert_eq!(v.action, ScannerAction::Block);
assert_eq!(v.reason, "model");
}
#[test]
@@ -149,7 +115,6 @@ fn legitimate_php_path_allowed() {
"text/html", "Mozilla/5.0 Chrome/120", 0,
);
assert_eq!(v.action, ScannerAction::Allow);
// hits allowlist:host+cookies
}
#[test]
@@ -165,29 +130,3 @@ fn browser_on_known_host_without_cookies_allowed() {
assert_eq!(v.action, ScannerAction::Allow);
assert_eq!(v.reason, "allowlist:host+browser");
}
#[test]
fn model_serialization_roundtrip() {
let model = ScannerModel {
weights: scanner_weights(),
threshold: 0.5,
norm_params: ScannerNormParams {
mins: [0.0; NUM_SCANNER_FEATURES],
maxs: [1.0; NUM_SCANNER_FEATURES],
},
fragments: vec![".env".into(), "wp-admin".into()],
};
let dir = std::env::temp_dir().join("scanner_e2e_test");
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("test_scanner_model.bin");
model.save(&path).unwrap();
let loaded = ScannerModel::load(&path).unwrap();
assert_eq!(loaded.weights, model.weights);
assert_eq!(loaded.threshold, model.threshold);
assert_eq!(loaded.fragments, model.fragments);
let _ = std::fs::remove_dir_all(&dir);
}