Behind the `training` feature flag (burn 0.20 + ndarray + autodiff). Trains a single-hidden-layer MLP with Adam optimizer and weighted BCE loss, plus a CART decision tree using Gini impurity. Exports trained weights as Rust const arrays that compile directly into the binary. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
343 lines
10 KiB
Rust
343 lines
10 KiB
Rust
//! Weight export: converts trained models into standalone Rust `const` arrays
//! and, optionally, Lean 4 definitions.
//!
//! The generated Rust source is meant to be placed in
//! `src/ensemble/gen/{scanner,ddos}_weights.rs` so the inference side can use
//! compile-time weight constants with zero runtime cost.
|
||
|
||
use anyhow::Result;
|
||
use std::fmt::Write as FmtWrite;
|
||
use std::path::Path;
|
||
|
||
/// All data needed to emit a standalone inference source file.
///
/// Bundles the parameters of one MLP (`w1`/`b1` for the hidden layer,
/// `w2`/`b2` for the scalar output), the packed nodes of the companion
/// decision tree, the classification threshold, and the per-feature
/// normalization ranges. The exporters in this module read these fields
/// verbatim; none of them validates that the vector lengths agree with
/// `input_dim` / `hidden_dim`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExportedModel {
    /// Module name used in generated code comments and Lean defs.
    pub model_name: String,
    /// Number of input features.
    pub input_dim: usize,
    /// Hidden layer width (always 32 in the current ensemble).
    pub hidden_dim: usize,
    /// Weight matrix for layer 1: `hidden_dim x input_dim`
    /// (one inner `Vec` per hidden unit).
    pub w1: Vec<Vec<f32>>,
    /// Bias vector for layer 1: length `hidden_dim`.
    pub b1: Vec<f32>,
    /// Weight vector for layer 2: length `hidden_dim`.
    pub w2: Vec<f32>,
    /// Bias scalar for layer 2.
    pub b2: f32,
    /// Packed decision tree nodes.
    ///
    /// The exporters destructure each tuple as
    /// `(feature, threshold, left, right)`.
    // NOTE(review): the test fixture uses feature index 255 with zero
    // children for what look like leaf nodes — confirm against the encoder.
    pub tree_nodes: Vec<(u8, f32, u16, u16)>,
    /// MLP classification threshold.
    pub threshold: f32,
    /// Per-feature normalization minimums.
    pub norm_mins: Vec<f32>,
    /// Per-feature normalization maximums.
    pub norm_maxs: Vec<f32>,
}
|
||
|
||
/// Generate a Rust source file with `const` arrays for all model weights.
|
||
pub fn generate_rust_source(model: &ExportedModel) -> String {
|
||
let mut s = String::with_capacity(8192);
|
||
|
||
writeln!(
|
||
s,
|
||
"//! Auto-generated weights for the {} ensemble.",
|
||
model.model_name
|
||
)
|
||
.unwrap();
|
||
writeln!(
|
||
s,
|
||
"//! DO NOT EDIT — regenerate with `cargo run --features training -- train-{}-mlp`.",
|
||
model.model_name.to_ascii_lowercase()
|
||
)
|
||
.unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
// Threshold.
|
||
writeln!(s, "pub const THRESHOLD: f32 = {:.8};", model.threshold).unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
// Normalization params.
|
||
write_f32_array(&mut s, "NORM_MINS", &model.norm_mins);
|
||
write_f32_array(&mut s, "NORM_MAXS", &model.norm_maxs);
|
||
|
||
// W1: hidden_dim x input_dim.
|
||
writeln!(
|
||
s,
|
||
"pub const W1: [[f32; {}]; {}] = [",
|
||
model.input_dim, model.hidden_dim
|
||
)
|
||
.unwrap();
|
||
for row in &model.w1 {
|
||
write!(s, " [").unwrap();
|
||
for (i, v) in row.iter().enumerate() {
|
||
if i > 0 {
|
||
write!(s, ", ").unwrap();
|
||
}
|
||
write!(s, "{:.8}", v).unwrap();
|
||
}
|
||
writeln!(s, "],").unwrap();
|
||
}
|
||
writeln!(s, "];").unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
// B1.
|
||
write_f32_array(&mut s, "B1", &model.b1);
|
||
|
||
// W2.
|
||
write_f32_array(&mut s, "W2", &model.w2);
|
||
|
||
// B2.
|
||
writeln!(s, "pub const B2: f32 = {:.8};", model.b2).unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
// Tree nodes.
|
||
writeln!(
|
||
s,
|
||
"pub const TREE_NODES: [(u8, f32, u16, u16); {}] = [",
|
||
model.tree_nodes.len()
|
||
)
|
||
.unwrap();
|
||
for &(feat, thresh, left, right) in &model.tree_nodes {
|
||
writeln!(
|
||
s,
|
||
" ({}, {:.8}, {}, {}),",
|
||
feat, thresh, left, right
|
||
)
|
||
.unwrap();
|
||
}
|
||
writeln!(s, "];").unwrap();
|
||
|
||
s
|
||
}
|
||
|
||
/// Generate Lean 4 definitions for formal verification.
|
||
pub fn generate_lean_source(model: &ExportedModel) -> String {
|
||
let mut s = String::with_capacity(8192);
|
||
|
||
writeln!(
|
||
s,
|
||
"-- Auto-generated Lean 4 definitions for {} ensemble.",
|
||
model.model_name
|
||
)
|
||
.unwrap();
|
||
writeln!(
|
||
s,
|
||
"-- DO NOT EDIT — regenerate with the training pipeline."
|
||
)
|
||
.unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
writeln!(
|
||
s,
|
||
"namespace Sunbeam.Ensemble.{}",
|
||
capitalize(&model.model_name)
|
||
)
|
||
.unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
writeln!(s, "def inputDim : Nat := {}", model.input_dim).unwrap();
|
||
writeln!(s, "def hiddenDim : Nat := {}", model.hidden_dim).unwrap();
|
||
writeln!(s, "def threshold : Float := {:.8}", model.threshold).unwrap();
|
||
writeln!(s, "def b2 : Float := {:.8}", model.b2).unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
write_lean_float_list(&mut s, "normMins", &model.norm_mins);
|
||
write_lean_float_list(&mut s, "normMaxs", &model.norm_maxs);
|
||
write_lean_float_list(&mut s, "b1", &model.b1);
|
||
write_lean_float_list(&mut s, "w2", &model.w2);
|
||
|
||
// W1 as list of lists.
|
||
writeln!(s, "def w1 : List (List Float) := [").unwrap();
|
||
for (i, row) in model.w1.iter().enumerate() {
|
||
let comma = if i + 1 < model.w1.len() { "," } else { "" };
|
||
write!(s, " [").unwrap();
|
||
for (j, v) in row.iter().enumerate() {
|
||
if j > 0 {
|
||
write!(s, ", ").unwrap();
|
||
}
|
||
write!(s, "{:.8}", v).unwrap();
|
||
}
|
||
writeln!(s, "]{}", comma).unwrap();
|
||
}
|
||
writeln!(s, "]").unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
// Tree nodes as list of tuples.
|
||
writeln!(
|
||
s,
|
||
"def treeNodes : List (Nat × Float × Nat × Nat) := ["
|
||
)
|
||
.unwrap();
|
||
for (i, &(feat, thresh, left, right)) in model.tree_nodes.iter().enumerate() {
|
||
let comma = if i + 1 < model.tree_nodes.len() {
|
||
","
|
||
} else {
|
||
""
|
||
};
|
||
writeln!(
|
||
s,
|
||
" ({}, {:.8}, {}, {}){}",
|
||
feat, thresh, left, right, comma
|
||
)
|
||
.unwrap();
|
||
}
|
||
writeln!(s, "]").unwrap();
|
||
writeln!(s).unwrap();
|
||
|
||
writeln!(
|
||
s,
|
||
"end Sunbeam.Ensemble.{}",
|
||
capitalize(&model.model_name)
|
||
)
|
||
.unwrap();
|
||
|
||
s
|
||
}
|
||
|
||
/// Write the generated Rust source to a file.
|
||
pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> {
|
||
let source = generate_rust_source(model);
|
||
std::fs::write(path, source.as_bytes())
|
||
.map_err(|e| anyhow::anyhow!("writing export to {}: {}", path.display(), e))?;
|
||
Ok(())
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Helpers
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/// Append `pub const <name>: [f32; <len>] = [ ... ];` plus a blank line to `s`.
///
/// Values are printed with eight decimal places, eight per line. Rows are
/// joined with `",\n"` so that no emitted line ends in trailing whitespace
/// (the separator used to be `", "` followed by a newline, which left a
/// dangling space at every wrap point). An empty slice produces an empty
/// bracket pair with no whitespace-only line in between.
///
/// * `s` – buffer the declaration is appended to.
/// * `name` – identifier of the generated constant.
/// * `values` – array contents, in order.
fn write_f32_array(s: &mut String, name: &str, values: &[f32]) {
    writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap();

    // Line-wrap every 8 values for readability.
    for (row_idx, row) in values.chunks(8).enumerate() {
        if row_idx > 0 {
            // Comma terminates the previous row; the newline follows directly.
            s.push_str(",\n");
        }
        s.push(' ');
        for (col_idx, v) in row.iter().enumerate() {
            if col_idx > 0 {
                s.push_str(", ");
            }
            write!(s, "{:.8}", v).unwrap();
        }
    }
    if !values.is_empty() {
        s.push('\n');
    }

    writeln!(s, "];").unwrap();
    writeln!(s).unwrap();
}
|
||
|
||
/// Append `def <name> : List Float := [v0, v1, ...]` plus a blank line to `s`.
///
/// Each value is printed with eight decimal places and the whole list stays
/// on a single line.
fn write_lean_float_list(s: &mut String, name: &str, values: &[f32]) {
    let rendered: Vec<String> = values.iter().map(|x| format!("{:.8}", x)).collect();
    // The embedded "\n" plus writeln!'s own newline give the blank separator.
    writeln!(s, "def {} : List Float := [{}]\n", name, rendered.join(", ")).unwrap();
}
|
||
|
||
/// Return `s` with its first character upper-cased and the rest untouched.
///
/// An empty input yields an empty string. The upper-cased form of the first
/// character may be more than one character long (e.g. `ß` becomes `SS`).
fn capitalize(s: &str) -> String {
    match s.chars().next() {
        Some(head) => {
            let mut out: String = head.to_uppercase().collect();
            // Everything after the first character is copied verbatim.
            out.push_str(&s[head.len_utf8()..]);
            out
        }
        None => String::new(),
    }
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny 2x2 model with a three-node tree, shared by every test below.
    fn make_test_model() -> ExportedModel {
        let tree_nodes = vec![(0, 0.5, 1, 2), (255, 0.0, 0, 0), (255, 1.0, 0, 0)];
        ExportedModel {
            model_name: "scanner".to_string(),
            input_dim: 2,
            hidden_dim: 2,
            w1: vec![vec![0.1, 0.2], vec![0.3, 0.4]],
            b1: vec![0.01, 0.02],
            w2: vec![0.5, 0.6],
            b2: -0.1,
            tree_nodes,
            threshold: 0.5,
            norm_mins: vec![0.0, 0.0],
            norm_maxs: vec![1.0, 10.0],
        }
    }

    /// Every expected `pub const` item is present in the generated Rust.
    #[test]
    fn test_rust_source_contains_consts() {
        let rust = generate_rust_source(&make_test_model());

        let expectations = [
            ("pub const THRESHOLD: f32 =", "missing THRESHOLD"),
            ("pub const NORM_MINS:", "missing NORM_MINS"),
            ("pub const NORM_MAXS:", "missing NORM_MAXS"),
            ("pub const W1:", "missing W1"),
            ("pub const B1:", "missing B1"),
            ("pub const W2:", "missing W2"),
            ("pub const B2: f32 =", "missing B2"),
            ("pub const TREE_NODES:", "missing TREE_NODES"),
        ];
        for (needle, failure) in expectations {
            assert!(rust.contains(needle), "{}", failure);
        }
    }

    /// Declared array types carry the dimensions of the test model.
    #[test]
    fn test_rust_source_array_dims() {
        let rust = generate_rust_source(&make_test_model());

        let expectations = [
            // W1 should be [[f32; 2]; 2]
            ("[[f32; 2]; 2]", "W1 dimensions wrong"),
            // B1 should be [f32; 2]
            ("B1: [f32; 2]", "B1 dimensions wrong"),
            // W2 should be [f32; 2]
            ("W2: [f32; 2]", "W2 dimensions wrong"),
            // TREE_NODES should have 3 entries
            ("[(u8, f32, u16, u16); 3]", "TREE_NODES count wrong"),
        ];
        for (needle, failure) in expectations {
            assert!(rust.contains(needle), "{}", failure);
        }
    }

    /// Scalar values show up with the fixed eight-decimal formatting.
    #[test]
    fn test_weight_values_roundtrip() {
        let rust = generate_rust_source(&make_test_model());

        // The threshold should appear with reasonable precision.
        assert!(rust.contains("0.50000000"), "threshold value missing");
        // B2 value.
        assert!(rust.contains("-0.10000000"), "b2 value missing");
    }

    /// The Lean output has the namespace wrapper and all top-level defs.
    #[test]
    fn test_lean_source_structure() {
        let lean = generate_lean_source(&make_test_model());

        assert!(lean.contains("namespace Sunbeam.Ensemble.Scanner"));
        assert!(lean.contains("def inputDim : Nat := 2"));
        assert!(lean.contains("def hiddenDim : Nat := 2"));
        assert!(lean.contains("def threshold : Float :="));
        assert!(lean.contains("def normMins : List Float :="));
        assert!(lean.contains("def w1 : List (List Float) :="));
        assert!(lean.contains("def treeNodes :"));
        assert!(lean.contains("end Sunbeam.Ensemble.Scanner"));
    }

    /// Exporting writes a readable file containing the generated constants.
    #[test]
    fn test_export_to_file() {
        let scratch = tempfile::tempdir().unwrap();
        let target = scratch.path().join("test_weights.rs");

        export_to_file(&make_test_model(), &target).unwrap();

        let written = std::fs::read_to_string(&target).unwrap();
        assert!(written.contains("pub const THRESHOLD:"));
        assert!(written.contains("pub const W1:"));
    }
}
|