Files
proxy/src/training/export.rs
Sienna Meridian Satterwhite 067d822244 feat(training): add burn MLP and CART tree trainers with weight export
Behind the `training` feature flag (burn 0.20 + ndarray + autodiff).
Trains a single-hidden-layer MLP with Adam optimizer and weighted BCE
loss, plus a CART decision tree using Gini impurity. Exports trained
weights as Rust const arrays that compile directly into the binary.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
2026-03-10 23:38:21 +00:00

343 lines
10 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Weight export: converts trained models into standalone Rust `const` arrays
//! and optionally Lean 4 definitions.
//!
//! The generated Rust source is meant to be placed in
//! `src/ensemble/gen/{scanner,ddos}_weights.rs` so the inference side can use
//! compile-time weight constants with zero runtime cost.
use anyhow::Result;
use std::fmt::Write as FmtWrite;
use std::path::Path;
/// All data needed to emit a standalone inference source file.
///
/// The generators in this module do not validate dimensions. In the emitted
/// Rust, the declared type of `W1` is built from `input_dim`/`hidden_dim` while
/// its contents come from `w1`, so a mismatch produces source that does not
/// compile. `B1`, `W2`, `NORM_MINS`, `NORM_MAXS` and `TREE_NODES` take their
/// lengths from the vectors themselves.
#[derive(Debug, Clone, PartialEq)]
pub struct ExportedModel {
    /// Module name used in generated code comments and Lean defs.
    pub model_name: String,
    /// Number of input features.
    pub input_dim: usize,
    /// Hidden layer width (expected to be 32 for the current ensemble; not
    /// enforced here).
    pub hidden_dim: usize,
    /// Weight matrix for layer 1: `hidden_dim` rows of `input_dim` values.
    pub w1: Vec<Vec<f32>>,
    /// Bias vector for layer 1: length `hidden_dim`.
    pub b1: Vec<f32>,
    /// Weight vector for layer 2: length `hidden_dim`.
    pub w2: Vec<f32>,
    /// Bias scalar for layer 2.
    pub b2: f32,
    /// Packed decision tree nodes, emitted as `(feature, threshold, left, right)`.
    /// NOTE(review): the tests use feature `255` with zero children, which looks
    /// like a leaf sentinel — confirm against the tree trainer.
    pub tree_nodes: Vec<(u8, f32, u16, u16)>,
    /// MLP classification threshold.
    pub threshold: f32,
    /// Per-feature normalization minimums.
    pub norm_mins: Vec<f32>,
    /// Per-feature normalization maximums.
    pub norm_maxs: Vec<f32>,
}
/// Render `model` as a standalone Rust source file of `pub const` items.
///
/// Emits, in order: a header comment, `THRESHOLD`, `NORM_MINS`, `NORM_MAXS`,
/// `W1`, `B1`, `W2`, `B2` and `TREE_NODES`. Every float is printed with fixed
/// 8-decimal precision (`{:.8}`).
pub fn generate_rust_source(model: &ExportedModel) -> String {
    // Formatting into a `String` cannot fail, so the `fmt::Result`s are unwrapped.
    let mut out = String::with_capacity(8192);

    // Header: provenance plus the command that regenerates this file.
    writeln!(out, "//! Auto-generated weights for the {} ensemble.", model.model_name).unwrap();
    writeln!(
        out,
        "//! DO NOT EDIT — regenerate with `cargo run --features training -- train-{}-mlp`.",
        model.model_name.to_ascii_lowercase()
    )
    .unwrap();
    out.push('\n');

    // Scalar classification threshold, followed by a blank line.
    writeln!(out, "pub const THRESHOLD: f32 = {:.8};\n", model.threshold).unwrap();

    // Per-feature normalization bounds.
    write_f32_array(&mut out, "NORM_MINS", &model.norm_mins);
    write_f32_array(&mut out, "NORM_MAXS", &model.norm_maxs);

    // First layer: the declared type uses the model's dims, the rows come from `w1`.
    writeln!(
        out,
        "pub const W1: [[f32; {}]; {}] = [",
        model.input_dim, model.hidden_dim
    )
    .unwrap();
    for row in &model.w1 {
        let cells: Vec<String> = row.iter().map(|v| format!("{:.8}", v)).collect();
        writeln!(out, " [{}],", cells.join(", ")).unwrap();
    }
    out.push_str("];\n\n");

    // Hidden biases, output weights and output bias.
    write_f32_array(&mut out, "B1", &model.b1);
    write_f32_array(&mut out, "W2", &model.w2);
    writeln!(out, "pub const B2: f32 = {:.8};\n", model.b2).unwrap();

    // Decision tree as `(feature, threshold, left, right)` tuples.
    writeln!(
        out,
        "pub const TREE_NODES: [(u8, f32, u16, u16); {}] = [",
        model.tree_nodes.len()
    )
    .unwrap();
    for (feat, thresh, left, right) in model.tree_nodes.iter().copied() {
        writeln!(out, " ({}, {:.8}, {}, {}),", feat, thresh, left, right).unwrap();
    }
    out.push_str("];\n");

    out
}
/// Render `model` as Lean 4 definitions inside the
/// `Sunbeam.Ensemble.<CapitalizedName>` namespace.
///
/// Scalars become `Nat`/`Float` defs, vectors become `List Float`, `w1`
/// becomes `List (List Float)` and tree nodes become a list of
/// `Nat × Float × Nat × Nat` tuples. Floats use fixed 8-decimal precision.
pub fn generate_lean_source(model: &ExportedModel) -> String {
    /// Comma-separate floats with the exporter's `{:.8}` precision.
    fn join_floats(values: &[f32]) -> String {
        values
            .iter()
            .map(|v| format!("{:.8}", v))
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Append list items one per line with `,` separators (none after the
    /// last item); writes nothing for an empty list.
    fn push_items(out: &mut String, items: &[String]) {
        if !items.is_empty() {
            out.push_str(&items.join(",\n"));
            out.push('\n');
        }
    }

    let namespace = capitalize(&model.model_name);
    let mut out = String::with_capacity(8192);

    // Header and namespace opening. `String` formatting is infallible.
    writeln!(
        out,
        "-- Auto-generated Lean 4 definitions for {} ensemble.",
        model.model_name
    )
    .unwrap();
    out.push_str("-- DO NOT EDIT — regenerate with the training pipeline.\n\n");
    writeln!(out, "namespace Sunbeam.Ensemble.{}\n", namespace).unwrap();

    // Scalar definitions.
    writeln!(out, "def inputDim : Nat := {}", model.input_dim).unwrap();
    writeln!(out, "def hiddenDim : Nat := {}", model.hidden_dim).unwrap();
    writeln!(out, "def threshold : Float := {:.8}", model.threshold).unwrap();
    writeln!(out, "def b2 : Float := {:.8}\n", model.b2).unwrap();

    // Flat float vectors.
    for (name, values) in [
        ("normMins", &model.norm_mins),
        ("normMaxs", &model.norm_maxs),
        ("b1", &model.b1),
        ("w2", &model.w2),
    ] {
        write_lean_float_list(&mut out, name, values);
    }

    // First-layer weights: one inner list per row of `w1`.
    let rows: Vec<String> = model
        .w1
        .iter()
        .map(|row| format!(" [{}]", join_floats(row)))
        .collect();
    out.push_str("def w1 : List (List Float) := [\n");
    push_items(&mut out, &rows);
    out.push_str("]\n\n");

    // Tree nodes as `(feature, threshold, left, right)` tuples.
    let nodes: Vec<String> = model
        .tree_nodes
        .iter()
        .map(|&(feat, thresh, left, right)| {
            format!(" ({}, {:.8}, {}, {})", feat, thresh, left, right)
        })
        .collect();
    out.push_str("def treeNodes : List (Nat × Float × Nat × Nat) := [\n");
    push_items(&mut out, &nodes);
    out.push_str("]\n\n");

    writeln!(out, "end Sunbeam.Ensemble.{}", namespace).unwrap();
    out
}
/// Write the generated Rust source to a file.
pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> {
let source = generate_rust_source(model);
std::fs::write(path, source.as_bytes())
.map_err(|e| anyhow::anyhow!("writing export to {}: {}", path.display(), e))?;
Ok(())
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Append `pub const <name>: [f32; N] = [ ... ];` plus a trailing blank line to `s`.
///
/// Values are printed with fixed 8-decimal precision and wrapped eight per
/// line. Each wrapped line ends directly after its comma, so the generated
/// source has no trailing whitespace. An empty slice yields `[\n];` instead
/// of a whitespace-only line.
fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
    // Formatting into a `String` cannot fail.
    writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap();
    for (line_idx, chunk) in values.chunks(8).enumerate() {
        if line_idx > 0 {
            // End the previous line with a bare comma; the indent follows below.
            s.push_str(",\n");
        }
        let cells: Vec<String> = chunk.iter().map(|v| format!("{:.8}", v)).collect();
        s.push(' ');
        s.push_str(&cells.join(", "));
    }
    if !values.is_empty() {
        s.push('\n');
    }
    s.push_str("];\n\n");
}
/// Append a single-line Lean definition `def <name> : List Float := [..]`
/// followed by a blank line. Floats use fixed 8-decimal precision.
fn write_lean_float_list(s: &mut String, name: &str, values: &[f32]) {
    let items = values
        .iter()
        .map(|v| format!("{:.8}", v))
        .collect::<Vec<_>>()
        .join(", ");
    s.push_str("def ");
    s.push_str(name);
    s.push_str(" : List Float := [");
    s.push_str(&items);
    s.push_str("]\n\n");
}
/// Upper-case the first character of `s` (using full Unicode case mapping,
/// which may expand to several chars) and keep the rest unchanged.
fn capitalize(s: &str) -> String {
    let mut rest = s.chars();
    let Some(head) = rest.next() else {
        return String::new();
    };
    head.to_uppercase().chain(rest).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny fixture: 2 inputs, 2 hidden units, and a 3-node tree (one split
    /// plus two nodes with feature 255 and no children).
    fn make_test_model() -> ExportedModel {
        ExportedModel {
            model_name: "scanner".to_string(),
            input_dim: 2,
            hidden_dim: 2,
            w1: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
            b1: vec![0.01, 0.02],
            w2: vec![0.5, 0.6],
            b2: -0.1,
            tree_nodes: vec![(0, 0.5, 1, 2), (255, 0.0, 0, 0), (255, 1.0, 0, 0)],
            threshold: 0.5,
            norm_mins: vec![0.0, 0.0],
            norm_maxs: vec![1.0, 10.0],
        }
    }

    /// Assert every `(needle, message)` pair's needle occurs in `haystack`,
    /// failing with `message` otherwise.
    #[track_caller]
    fn assert_all_present(haystack: &str, expectations: &[(&str, &str)]) {
        for &(needle, message) in expectations {
            assert!(haystack.contains(needle), "{}", message);
        }
    }

    #[test]
    fn test_rust_source_contains_consts() {
        let src = generate_rust_source(&make_test_model());
        assert_all_present(
            &src,
            &[
                ("pub const THRESHOLD: f32 =", "missing THRESHOLD"),
                ("pub const NORM_MINS:", "missing NORM_MINS"),
                ("pub const NORM_MAXS:", "missing NORM_MAXS"),
                ("pub const W1:", "missing W1"),
                ("pub const B1:", "missing B1"),
                ("pub const W2:", "missing W2"),
                ("pub const B2: f32 =", "missing B2"),
                ("pub const TREE_NODES:", "missing TREE_NODES"),
            ],
        );
    }

    #[test]
    fn test_rust_source_array_dims() {
        let src = generate_rust_source(&make_test_model());
        assert_all_present(
            &src,
            &[
                // W1 is hidden_dim rows of input_dim values.
                ("[[f32; 2]; 2]", "W1 dimensions wrong"),
                ("B1: [f32; 2]", "B1 dimensions wrong"),
                ("W2: [f32; 2]", "W2 dimensions wrong"),
                // One tuple per tree node.
                ("[(u8, f32, u16, u16); 3]", "TREE_NODES count wrong"),
            ],
        );
    }

    #[test]
    fn test_weight_values_roundtrip() {
        let src = generate_rust_source(&make_test_model());
        // Scalars are printed with fixed 8-decimal precision.
        assert_all_present(
            &src,
            &[
                ("0.50000000", "threshold value missing"),
                ("-0.10000000", "b2 value missing"),
            ],
        );
    }

    #[test]
    fn test_lean_source_structure() {
        let src = generate_lean_source(&make_test_model());
        for needle in [
            "namespace Sunbeam.Ensemble.Scanner",
            "def inputDim : Nat := 2",
            "def hiddenDim : Nat := 2",
            "def threshold : Float :=",
            "def normMins : List Float :=",
            "def w1 : List (List Float) :=",
            "def treeNodes :",
            "end Sunbeam.Ensemble.Scanner",
        ] {
            assert!(src.contains(needle), "missing {:?}", needle);
        }
    }

    #[test]
    fn test_export_to_file() {
        // `dir` must outlive the read below; dropping it deletes the directory.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test_weights.rs");
        export_to_file(&make_test_model(), &path).unwrap();
        let written = std::fs::read_to_string(&path).unwrap();
        for needle in ["pub const THRESHOLD:", "pub const W1:"] {
            assert!(written.contains(needle), "missing {:?}", needle);
        }
    }
}