Files
proxy/src/main.rs
Sienna Meridian Satterwhite a5810dd8a7 feat: configurable k8s resources, CSIC training pipeline, unified Dockerfile
- Make K8s namespace, TLS secret, and config ConfigMap names configurable
  via [kubernetes] config section (previously hardcoded to "ingress")
- Add CSIC 2010 dataset converter and auto-download for scanner training
- Unify Dockerfile for local and production builds (remove cross-compile path)
- Bake ML models directory into container image
- Update CSIC dataset URL to self-hosted mirror (src.sunbeam.pt)
- Fix rate_limit pipeline log missing fields
- Consolidate docs/README.md into root README.md

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
2026-03-10 23:38:20 +00:00

457 lines
16 KiB
Rust

mod cert;
mod telemetry;
mod watcher;
use sunbeam_proxy::{acme, config};
use sunbeam_proxy::proxy::SunbeamProxy;
use sunbeam_proxy::ddos;
use sunbeam_proxy::rate_limit;
use sunbeam_proxy::scanner;
use std::{collections::HashMap, sync::Arc};
use anyhow::Result;
use clap::{Parser, Subcommand};
use kube::Client;
use pingora::server::{configuration::Opt, Server};
use pingora_proxy::http_proxy_service;
use std::sync::RwLock;
// Top-level command-line interface for the `sunbeam-proxy` binary, parsed by clap.
// NOTE: plain `//` comments are used here deliberately. clap's derive macros turn
// `///` doc comments into help/about text, so adding them would change `--help`.
#[derive(Parser)]
#[command(name = "sunbeam-proxy")]
struct Cli {
// Optional subcommand. When it is absent, `main` substitutes
// `Commands::Serve { upgrade: false }`.
#[command(subcommand)]
command: Option<Commands>,
}
// Subcommands accepted by the binary. Each variant's fields are the flags for
// that subcommand; `main` destructures them and forwards them to the matching
// `run` function.
//
// NOTE: the `///` comments below are not just documentation. clap's derive
// macro uses them as the user-visible help text, so edit them with the same
// care as any other runtime string.
#[derive(Subcommand)]
enum Commands {
    /// Start the proxy server (default if no subcommand given)
    Serve {
        /// Pingora --upgrade flag for zero-downtime reload
        #[arg(long)]
        upgrade: bool,
    },
    /// Replay audit logs through the DDoS detector and rate limiter
    ReplayDdos {
        /// Path to audit log JSONL file
        #[arg(short, long)]
        input: String,
        /// Path to trained model file
        #[arg(short, long, default_value = "ddos_model.bin")]
        model: String,
        /// Optional config file (for rate limit settings)
        #[arg(short, long)]
        config: Option<String>,
        /// KNN k parameter
        #[arg(long, default_value = "5")]
        k: usize,
        /// Attack threshold
        #[arg(long, default_value = "0.6")]
        threshold: f64,
        /// Sliding window size in seconds
        #[arg(long, default_value = "60")]
        window_secs: u64,
        /// Minimum events per IP before classification
        #[arg(long, default_value = "10")]
        min_events: usize,
        /// Also run rate limiter during replay
        #[arg(long)]
        rate_limit: bool,
    },
    /// Train a DDoS detection model from audit logs
    TrainDdos {
        /// Path to audit log JSONL file
        #[arg(short, long)]
        input: String,
        /// Output model file path
        #[arg(short, long)]
        output: String,
        /// File with known-attack IPs (one per line)
        #[arg(long)]
        attack_ips: Option<String>,
        /// File with known-normal IPs (one per line)
        #[arg(long)]
        normal_ips: Option<String>,
        /// TOML file with heuristic auto-labeling thresholds
        #[arg(long)]
        heuristics: Option<String>,
        /// KNN k parameter
        #[arg(long, default_value = "5")]
        k: usize,
        /// Attack threshold (fraction of k neighbors)
        #[arg(long, default_value = "0.6")]
        threshold: f64,
        /// Sliding window size in seconds
        #[arg(long, default_value = "60")]
        window_secs: u64,
        /// Minimum events per IP to include in training
        #[arg(long, default_value = "10")]
        min_events: usize,
    },
    /// Train a per-request scanner detection model from audit logs
    TrainScanner {
        /// Path to audit log JSONL file
        #[arg(short, long)]
        input: String,
        /// Output model file path
        #[arg(short, long, default_value = "scanner_model.bin")]
        output: String,
        /// Directory (or file) containing .txt wordlists of scanner paths
        #[arg(long)]
        wordlists: Option<String>,
        /// Classification threshold
        #[arg(long, default_value = "0.5")]
        threshold: f64,
        // The help text intentionally does not name the download host: the
        // dataset URL lives in the trainer and has changed before, which left
        // an earlier wording ("downloaded from GitHub") stale.
        /// Include CSIC 2010 dataset as base training data (downloaded automatically, cached locally)
        #[arg(long)]
        csic: bool,
    },
}
/// Entry point: parses the command line and hands control to the selected
/// subcommand's runner, returning whatever result that runner produces.
///
/// When no subcommand is given, this behaves as `serve` without `--upgrade`.
fn main() -> Result<()> {
    let command = Cli::parse()
        .command
        .unwrap_or(Commands::Serve { upgrade: false });

    match command {
        Commands::Serve { upgrade } => run_serve(upgrade),

        // The CLI flags `model` / `config` are carried as `model_path` /
        // `config_path` in the replay arguments.
        Commands::ReplayDdos {
            input,
            model: model_path,
            config: config_path,
            k,
            threshold,
            window_secs,
            min_events,
            rate_limit,
        } => {
            let args = ddos::replay::ReplayArgs {
                input,
                model_path,
                config_path,
                k,
                threshold,
                window_secs,
                min_events,
                rate_limit,
            };
            ddos::replay::run(args)
        }

        Commands::TrainDdos {
            input,
            output,
            attack_ips,
            normal_ips,
            heuristics,
            k,
            threshold,
            window_secs,
            min_events,
        } => {
            let args = ddos::train::TrainArgs {
                input,
                output,
                attack_ips,
                normal_ips,
                heuristics,
                k,
                threshold,
                window_secs,
                min_events,
            };
            ddos::train::run(args)
        }

        Commands::TrainScanner {
            input,
            output,
            wordlists,
            threshold,
            csic,
        } => {
            let args = scanner::train::TrainScannerArgs {
                input,
                output,
                wordlists,
                threshold,
                csic,
            };
            scanner::train::run(args)
        }
    }
}
fn run_serve(upgrade: bool) -> Result<()> {
// Install the aws-lc-rs crypto provider for rustls before any TLS init.
// Required because rustls 0.23 no longer auto-selects a provider at compile time.
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.expect("crypto provider already installed");
let config_path = std::env::var("SUNBEAM_CONFIG")
.unwrap_or_else(|_| "/etc/pingora/config.toml".to_string());
let cfg = config::Config::load(&config_path)?;
// 1. Init telemetry (JSON logs + optional OTEL traces).
telemetry::init(&cfg.telemetry.otlp_endpoint);
// 1b. Spawn metrics HTTP server (needs a tokio runtime for the TCP listener).
let metrics_port = cfg.telemetry.metrics_port;
// 2. Load DDoS detection model if configured.
let ddos_detector = if let Some(ddos_cfg) = &cfg.ddos {
if ddos_cfg.enabled {
match ddos::model::TrainedModel::load(
std::path::Path::new(&ddos_cfg.model_path),
Some(ddos_cfg.k),
Some(ddos_cfg.threshold),
) {
Ok(model) => {
let point_count = model.point_count();
let detector = Arc::new(ddos::detector::DDoSDetector::new(model, ddos_cfg));
tracing::info!(
points = point_count,
k = ddos_cfg.k,
threshold = ddos_cfg.threshold,
"DDoS detector loaded"
);
Some(detector)
}
Err(e) => {
tracing::warn!(error = %e, "failed to load DDoS model; detection disabled");
None
}
}
} else {
None
}
} else {
None
};
// 2b. Init rate limiter if configured.
let rate_limiter = if let Some(rl_cfg) = &cfg.rate_limit {
if rl_cfg.enabled {
let limiter = Arc::new(rate_limit::limiter::RateLimiter::new(rl_cfg));
let evict_limiter = limiter.clone();
let interval = rl_cfg.eviction_interval_secs;
std::thread::spawn(move || loop {
std::thread::sleep(std::time::Duration::from_secs(interval));
evict_limiter.evict_stale();
});
tracing::info!(
auth_burst = rl_cfg.authenticated.burst,
auth_rate = rl_cfg.authenticated.rate,
unauth_burst = rl_cfg.unauthenticated.burst,
unauth_rate = rl_cfg.unauthenticated.rate,
"rate limiter enabled"
);
Some(limiter)
} else {
None
}
} else {
None
};
// 2c. Load scanner model if configured.
let (scanner_detector, bot_allowlist) = if let Some(scanner_cfg) = &cfg.scanner {
if scanner_cfg.enabled {
match scanner::model::ScannerModel::load(std::path::Path::new(&scanner_cfg.model_path)) {
Ok(mut model) => {
let fragment_count = model.fragments.len();
model.threshold = scanner_cfg.threshold;
let detector = scanner::detector::ScannerDetector::new(&model, &cfg.routes);
let handle = Arc::new(arc_swap::ArcSwap::from_pointee(detector));
// Start bot allowlist if rules are configured.
let bot_allowlist = if !scanner_cfg.allowlist.is_empty() {
let al = scanner::allowlist::BotAllowlist::spawn(
&scanner_cfg.allowlist,
scanner_cfg.bot_cache_ttl_secs,
);
tracing::info!(
rules = scanner_cfg.allowlist.len(),
"bot allowlist enabled"
);
Some(al)
} else {
None
};
// Start background file watcher for hot-reload.
if scanner_cfg.poll_interval_secs > 0 {
let watcher_handle = handle.clone();
let model_path = std::path::PathBuf::from(&scanner_cfg.model_path);
let threshold = scanner_cfg.threshold;
let routes = cfg.routes.clone();
let interval = std::time::Duration::from_secs(scanner_cfg.poll_interval_secs);
std::thread::spawn(move || {
scanner::watcher::watch_scanner_model(
watcher_handle, model_path, threshold, routes, interval,
);
});
}
tracing::info!(
fragments = fragment_count,
threshold = scanner_cfg.threshold,
poll_interval_secs = scanner_cfg.poll_interval_secs,
"scanner detector loaded"
);
(Some(handle), bot_allowlist)
}
Err(e) => {
tracing::warn!(error = %e, "failed to load scanner model; scanner detection disabled");
(None, None)
}
}
} else {
(None, None)
}
} else {
(None, None)
};
// 3. Fetch the TLS cert from K8s before Pingora binds the TLS port.
// The Client is created and dropped within this temp runtime — we do NOT
// carry it across runtime boundaries, which would kill its tower workers.
// The watcher thread creates its own fresh Client on its own runtime.
let k8s_available = {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
rt.block_on(async {
match Client::try_default().await {
Ok(c) => {
if !upgrade {
if let Err(e) =
cert::fetch_and_write(
&c,
&cfg.kubernetes.namespace,
&cfg.kubernetes.tls_secret,
&cfg.tls.cert_path,
&cfg.tls.key_path,
).await
{
tracing::warn!(error = %e, "cert fetch from K8s failed; using existing files");
}
}
true
}
Err(e) => {
tracing::warn!(error = %e, "no K8s client; cert auto-reload and ACME routing disabled");
false
}
}
})
};
let opt = Opt {
upgrade,
daemon: false,
nocapture: false,
test: false,
conf: None,
};
// 4. Create Pingora server and bootstrap (binds ports or inherits FDs).
let mut server = Server::new(Some(opt))?;
server.bootstrap();
// 5. Shared ACME challenge route table. Populated by the Ingress watcher;
// consulted by the proxy for every /.well-known/acme-challenge/ request.
// Uses std::sync::RwLock so reads are sync and lock-guard-safe across
// Pingora's async proxy calls without cross-runtime waker concerns.
let acme_routes: acme::AcmeRoutes = Arc::new(RwLock::new(HashMap::new()));
let compiled_rewrites = SunbeamProxy::compile_rewrites(&cfg.routes);
let http_client = reqwest::Client::new();
let proxy = SunbeamProxy {
routes: cfg.routes.clone(),
acme_routes: acme_routes.clone(),
ddos_detector,
scanner_detector,
bot_allowlist,
rate_limiter,
compiled_rewrites,
http_client,
};
let mut svc = http_proxy_service(&server.configuration, proxy);
// Port 80: always serve plain HTTP (ACME challenges + redirect to HTTPS).
svc.add_tcp(&cfg.listen.http);
// Port 443: only add the TLS listener if the cert files exist.
let cert_exists = std::path::Path::new(&cfg.tls.cert_path).exists();
if cert_exists {
svc.add_tls(&cfg.listen.https, &cfg.tls.cert_path, &cfg.tls.key_path)?;
tracing::info!("TLS listener added on {}", cfg.listen.https);
} else {
tracing::warn!(
cert_path = %cfg.tls.cert_path,
"cert not found — starting HTTP-only; ACME challenge will complete and trigger upgrade"
);
}
server.add_service(svc);
// 5b. Spawn metrics + health HTTP server on its own thread.
if metrics_port > 0 {
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("metrics runtime");
rt.block_on(async {
sunbeam_proxy::metrics::spawn_metrics_server(metrics_port);
// Keep the runtime alive.
std::future::pending::<()>().await;
});
});
}
// 5c. SSH TCP passthrough (port 22 → Gitea SSH), if configured.
if let Some(ssh_cfg) = &cfg.ssh {
let listen = ssh_cfg.listen.clone();
let backend = ssh_cfg.backend.clone();
tracing::info!(%listen, %backend, "SSH TCP proxy enabled");
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("ssh proxy runtime");
rt.block_on(sunbeam_proxy::ssh::run_tcp_proxy(&listen, &backend));
});
}
// 6. Background K8s watchers on their own OS thread + tokio runtime.
if k8s_available {
let k8s_cfg = cfg.kubernetes.clone();
let cert_path = cfg.tls.cert_path.clone();
let key_path = cfg.tls.key_path.clone();
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("watcher runtime");
rt.block_on(async move {
let client = match Client::try_default().await {
Ok(c) => c,
Err(e) => {
tracing::error!(error = %e, "watcher: failed to create K8s client; watchers disabled");
return;
}
};
tokio::join!(
acme::watch_ingresses(
client.clone(),
k8s_cfg.namespace.clone(),
acme_routes,
),
watcher::run_watcher(
client,
k8s_cfg.namespace,
k8s_cfg.tls_secret,
k8s_cfg.config_configmap,
cert_path,
key_path,
),
);
});
});
}
tracing::info!(upgrade, "sunbeam-proxy starting");
server.run_forever();
}