feat(proxy): integrate DDoS, scanner, and rate limiter into request pipeline

Wire up all three detection layers in request_filter with pipeline
logging at each stage for unfiltered training data. Add DDoS, scanner,
and rate_limit config sections. Bot allowlist check before scanner
model on the hot path. CLI subcommands for train/replay.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:20 +00:00
parent ae18b00fa4
commit 867b6b2489
7 changed files with 1160 additions and 36 deletions

View File

@@ -4,16 +4,167 @@ mod watcher;
use sunbeam_proxy::{acme, config};
use sunbeam_proxy::proxy::SunbeamProxy;
use sunbeam_proxy::ddos;
use sunbeam_proxy::rate_limit;
use sunbeam_proxy::scanner;
use std::{collections::HashMap, sync::Arc};
use anyhow::Result;
use clap::{Parser, Subcommand};
use kube::Client;
use pingora::server::{configuration::Opt, Server};
use pingora_proxy::http_proxy_service;
use std::sync::RwLock;
#[derive(Parser)]
#[command(name = "sunbeam-proxy")]
struct Cli {
#[command(subcommand)]
command: Option<Commands>,
}
#[derive(Subcommand)]
enum Commands {
/// Start the proxy server (default if no subcommand given)
Serve {
/// Pingora --upgrade flag for zero-downtime reload
#[arg(long)]
upgrade: bool,
},
/// Replay audit logs through the DDoS detector and rate limiter
Replay {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Path to trained model file
#[arg(short, long, default_value = "ddos_model.bin")]
model: String,
/// Optional config file (for rate limit settings)
#[arg(short, long)]
config: Option<String>,
/// KNN k parameter
#[arg(long, default_value = "5")]
k: usize,
/// Attack threshold
#[arg(long, default_value = "0.6")]
threshold: f64,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP before classification
#[arg(long, default_value = "10")]
min_events: usize,
/// Also run rate limiter during replay
#[arg(long)]
rate_limit: bool,
},
/// Train a DDoS detection model from audit logs
Train {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output model file path
#[arg(short, long)]
output: String,
/// File with known-attack IPs (one per line)
#[arg(long)]
attack_ips: Option<String>,
/// File with known-normal IPs (one per line)
#[arg(long)]
normal_ips: Option<String>,
/// TOML file with heuristic auto-labeling thresholds
#[arg(long)]
heuristics: Option<String>,
/// KNN k parameter
#[arg(long, default_value = "5")]
k: usize,
/// Attack threshold (fraction of k neighbors)
#[arg(long, default_value = "0.6")]
threshold: f64,
/// Sliding window size in seconds
#[arg(long, default_value = "60")]
window_secs: u64,
/// Minimum events per IP to include in training
#[arg(long, default_value = "10")]
min_events: usize,
},
/// Train a per-request scanner detection model from audit logs
TrainScanner {
/// Path to audit log JSONL file
#[arg(short, long)]
input: String,
/// Output model file path
#[arg(short, long, default_value = "scanner_model.bin")]
output: String,
/// Directory (or file) containing .txt wordlists of scanner paths
#[arg(long)]
wordlists: Option<String>,
/// Classification threshold
#[arg(long, default_value = "0.5")]
threshold: f64,
},
}
fn main() -> Result<()> {
let cli = Cli::parse();
match cli.command.unwrap_or(Commands::Serve { upgrade: false }) {
Commands::Serve { upgrade } => run_serve(upgrade),
Commands::Replay {
input,
model,
config,
k,
threshold,
window_secs,
min_events,
rate_limit,
} => ddos::replay::run(ddos::replay::ReplayArgs {
input,
model_path: model,
config_path: config,
k,
threshold,
window_secs,
min_events,
rate_limit,
}),
Commands::Train {
input,
output,
attack_ips,
normal_ips,
heuristics,
k,
threshold,
window_secs,
min_events,
} => ddos::train::run(ddos::train::TrainArgs {
input,
output,
attack_ips,
normal_ips,
heuristics,
k,
threshold,
window_secs,
min_events,
}),
Commands::TrainScanner {
input,
output,
wordlists,
threshold,
} => scanner::train::run(scanner::train::TrainScannerArgs {
input,
output,
wordlists,
threshold,
}),
}
}
fn run_serve(upgrade: bool) -> Result<()> {
// Install the aws-lc-rs crypto provider for rustls before any TLS init.
// Required because rustls 0.23 no longer auto-selects a provider at compile time.
rustls::crypto::aws_lc_rs::default_provider()
@@ -27,10 +178,120 @@ fn main() -> Result<()> {
// 1. Init telemetry (JSON logs + optional OTEL traces).
telemetry::init(&cfg.telemetry.otlp_endpoint);
// 2. Detect --upgrade flag. When present, Pingora inherits listening socket
// FDs from the upgrade Unix socket instead of binding fresh ports, enabling
// zero-downtime cert/config reloads triggered by the K8s watcher below.
let upgrade = std::env::args().any(|a| a == "--upgrade");
// 2. Load DDoS detection model if configured.
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
if ddos_cfg.enabled {
match ddos::model::TrainedModel::load(
std::path::Path::new(&ddos_cfg.model_path),
Some(ddos_cfg.k),
Some(ddos_cfg.threshold),
) {
Ok(model) => {
let point_count = model.point_count();
let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg));
tracing::info!(
points = point_count,
k = ddos_cfg.k,
threshold = ddos_cfg.threshold,
"DDoS detector loaded"
);
Some(detector)
}
Err(e) => {
tracing::warn!(error = %e, "failed to load DDoS model; detection disabled");
None
}
}
} else {
None
}
} else {
None
};
// 2b. Init rate limiter if configured.
let rate_limiter = if let Some(rl_cfg) = &cfg.rate_limit {
if rl_cfg.enabled {
let limiter = Arc::new(rate_limit::limiter::RateLimiter::new(rl_cfg));
let evict_limiter = limiter.clone();
let interval = rl_cfg.eviction_interval_secs;
std::thread::spawn(move || loop {
std::thread::sleep(std::time::Duration::from_secs(interval));
evict_limiter.evict_stale();
});
tracing::info!(
auth_burst = rl_cfg.authenticated.burst,
auth_rate = rl_cfg.authenticated.rate,
unauth_burst = rl_cfg.unauthenticated.burst,
unauth_rate = rl_cfg.unauthenticated.rate,
"rate limiter enabled"
);
Some(limiter)
} else {
None
}
} else {
None
};
// 2c. Load scanner model if configured.
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
if scanner_cfg.enabled {
match scanner::model::ScannerModel::load(std::path::Path::new(&scanner_cfg.model_path)) {
Ok(mut model) => {
let fragment_count = model.fragments.len();
model.threshold = scanner_cfg.threshold;
let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes);
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
// Start bot allowlist if rules are configured.
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
let al = scanner::allowlist::BotAllowlist::spawn(
&scanner_cfg.allowlist,
scanner_cfg.bot_cache_ttl_secs,
);
tracing::info!(
rules = scanner_cfg.allowlist.len(),
"bot allowlist enabled"
);
Some(al)
} else {
None
};
// Start background file watcher for hot-reload.
if scanner_cfg.poll_interval_secs > 0 {
let watcher_handle = handle.clone();
let model_path = std::path::PathBuf::from(&scanner_cfg.model_path);
let threshold = scanner_cfg.threshold;
let routes = cfg.routes.clone();
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
std::thread::spawn(move || {
scanner::watcher::watch_scanner_model(
watcher_handle, model_path, threshold, routes, interval,
);
});
}
tracing::info!(
fragments = fragment_count,
threshold = scanner_cfg.threshold,
poll_interval_secs = scanner_cfg.poll_interval_secs,
"scanner detector loaded"
);
(Some(handle), bot_allowlist)
}
Err(e) => {
tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled");
(None, None)
}
}
} else {
(None, None)
}
} else {
(None, None)
};
// 3. Fetch the TLS cert from K8s before Pingora binds the TLS port.
// The Client is created and dropped within this temp runtime — we do NOT
@@ -47,8 +308,6 @@ fn main() -> Result<()> {
if let Err(e) =
cert::fetch_and_write(&c, &cfg.tls.cert_path, &cfg.tls.key_path).await
{
// Non-fatal: Secret may not exist yet on first deploy (cert-manager
// is still issuing), or the Secret name may differ in dev.
tracing::warn!(error = %e, "cert fetch from K8s failed; using existing files");
}
}
@@ -83,6 +342,10 @@ fn main() -> Result<()> {
let proxy = SunbeamProxy {
routes: cfg.routes.clone(),
acme_routes: acme_routes.clone(),
ddos_detector,
scanner_detector,
bot_allowlist,
rate_limiter,
};
let mut svc = http_proxy_service(&server.configuration, proxy);
@@ -90,11 +353,6 @@ fn main() -> Result<()> {
svc.add_tcp(&cfg.listen.http);
// Port 443: only add the TLS listener if the cert files exist.
// On first deploy cert-manager hasn't issued the cert yet, so we start
// HTTP-only. Once the pingora-tls Secret is created (ACME challenge
// completes), the watcher in step 6 writes the cert files and triggers
// a graceful upgrade. The upgrade process finds the cert files and adds
// the TLS listener, inheriting the port-80 socket from the old process.
let cert_exists = std::path::Path::new(&cfg.tls.cert_path).exists();
if cert_exists {
svc.add_tls(&cfg.listen.https, &cfg.tls.cert_path, &cfg.tls.key_path)?;
@@ -109,7 +367,6 @@ fn main() -> Result<()> {
server.add_service(svc);
// 5b. SSH TCP passthrough (port 22 → Gitea SSH), if configured.
// Runs on its own OS thread + Tokio runtime — same pattern as the cert/ingress watcher.
if let Some(ssh_cfg) = &cfg.ssh {
let listen = ssh_cfg.listen.clone();
let backend = ssh_cfg.backend.clone();
@@ -123,10 +380,7 @@ fn main() -> Result<()> {
});
}
// 6. Background K8s watchers on their own OS thread + tokio runtime so they
// don't interfere with Pingora's internal runtime. A fresh Client is
// created here so its tower workers live on this runtime (not the
// now-dropped temp runtime from step 3).
// 6. Background K8s watchers on their own OS thread + tokio runtime.
if k8s_available {
let cert_path = cfg.tls.cert_path.clone();
let key_path = cfg.tls.key_path.clone();