//! Weight export: converts trained models into standalone Rust `const` arrays //! and optionally Lean 4 definitions. //! //! The generated Rust source is meant to be placed in //! `src/ensemble/gen/{scanner,ddos}_weights.rs` so the inference side can use //! compile-time weight constants with zero runtime cost. use anyhow::Result; use std::fmt::Write as FmtWrite; use std::path::Path; /// All data needed to emit a standalone inference source file. pub struct ExportedModel { /// Module name used in generated code comments and Lean defs. pub model_name: String, /// Number of input features. pub input_dim: usize, /// Hidden layer width (always 32 in the current ensemble). pub hidden_dim: usize, /// Weight matrix for layer 1: `hidden_dim x input_dim`. pub w1: Vec>, /// Bias vector for layer 1: length `hidden_dim`. pub b1: Vec, /// Weight vector for layer 2: length `hidden_dim`. pub w2: Vec, /// Bias scalar for layer 2. pub b2: f32, /// Packed decision tree nodes. pub tree_nodes: Vec<(u8, f32, u16, u16)>, /// MLP classification threshold. pub threshold: f32, /// Per-feature normalization minimums. pub norm_mins: Vec, /// Per-feature normalization maximums. pub norm_maxs: Vec, } /// Generate a Rust source file with `const` arrays for all model weights. pub fn generate_rust_source(model: &ExportedModel) -> String { let mut s = String::with_capacity(8192); writeln!( s, "//! Auto-generated weights for the {} ensemble.", model.model_name ) .unwrap(); writeln!( s, "//! DO NOT EDIT — regenerate with `cargo run --features training -- train-{}-mlp`.", model.model_name.to_ascii_lowercase() ) .unwrap(); writeln!(s).unwrap(); // Threshold. writeln!(s, "pub const THRESHOLD: f32 = {:.8};", model.threshold).unwrap(); writeln!(s).unwrap(); // Normalization params. write_f32_array(&mut s, "NORM_MINS", &model.norm_mins); write_f32_array(&mut s, "NORM_MAXS", &model.norm_maxs); // W1: hidden_dim x input_dim. writeln!( s, "pub const W1: [[f32; {}]; {}] = [", model.input_dim, model.hidden_dim ) .unwrap(); for row in &model.w1 { write!(s, " [").unwrap(); for (i, v) in row.iter().enumerate() { if i > 0 { write!(s, ", ").unwrap(); } write!(s, "{:.8}", v).unwrap(); } writeln!(s, "],").unwrap(); } writeln!(s, "];").unwrap(); writeln!(s).unwrap(); // B1. write_f32_array(&mut s, "B1", &model.b1); // W2. write_f32_array(&mut s, "W2", &model.w2); // B2. writeln!(s, "pub const B2: f32 = {:.8};", model.b2).unwrap(); writeln!(s).unwrap(); // Tree nodes. writeln!( s, "pub const TREE_NODES: [(u8, f32, u16, u16); {}] = [", model.tree_nodes.len() ) .unwrap(); for &(feat, thresh, left, right) in &model.tree_nodes { writeln!( s, " ({}, {:.8}, {}, {}),", feat, thresh, left, right ) .unwrap(); } writeln!(s, "];").unwrap(); s } /// Generate Lean 4 definitions for formal verification. pub fn generate_lean_source(model: &ExportedModel) -> String { let mut s = String::with_capacity(8192); writeln!( s, "-- Auto-generated Lean 4 definitions for {} ensemble.", model.model_name ) .unwrap(); writeln!( s, "-- DO NOT EDIT — regenerate with the training pipeline." ) .unwrap(); writeln!(s).unwrap(); writeln!( s, "namespace Sunbeam.Ensemble.{}", capitalize(&model.model_name) ) .unwrap(); writeln!(s).unwrap(); writeln!(s, "def inputDim : Nat := {}", model.input_dim).unwrap(); writeln!(s, "def hiddenDim : Nat := {}", model.hidden_dim).unwrap(); writeln!(s, "def threshold : Float := {:.8}", model.threshold).unwrap(); writeln!(s, "def b2 : Float := {:.8}", model.b2).unwrap(); writeln!(s).unwrap(); write_lean_float_list(&mut s, "normMins", &model.norm_mins); write_lean_float_list(&mut s, "normMaxs", &model.norm_maxs); write_lean_float_list(&mut s, "b1", &model.b1); write_lean_float_list(&mut s, "w2", &model.w2); // W1 as list of lists. writeln!(s, "def w1 : List (List Float) := [").unwrap(); for (i, row) in model.w1.iter().enumerate() { let comma = if i + 1 < model.w1.len() { "," } else { "" }; write!(s, " [").unwrap(); for (j, v) in row.iter().enumerate() { if j > 0 { write!(s, ", ").unwrap(); } write!(s, "{:.8}", v).unwrap(); } writeln!(s, "]{}", comma).unwrap(); } writeln!(s, "]").unwrap(); writeln!(s).unwrap(); // Tree nodes as list of tuples. writeln!( s, "def treeNodes : List (Nat × Float × Nat × Nat) := [" ) .unwrap(); for (i, &(feat, thresh, left, right)) in model.tree_nodes.iter().enumerate() { let comma = if i + 1 < model.tree_nodes.len() { "," } else { "" }; writeln!( s, " ({}, {:.8}, {}, {}){}", feat, thresh, left, right, comma ) .unwrap(); } writeln!(s, "]").unwrap(); writeln!(s).unwrap(); writeln!( s, "end Sunbeam.Ensemble.{}", capitalize(&model.model_name) ) .unwrap(); s } /// Write the generated Rust source to a file. pub fn export_to_file(model: &ExportedModel, path: &Path) -> Result<()> { let source = generate_rust_source(model); std::fs::write(path, source.as_bytes()) .map_err(|e| anyhow::anyhow!("writing export to {}: {}", path.display(), e))?; Ok(()) } // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- fn write_f32_array(s: &mut String, name: &str, values: &[f32]) { writeln!(s, "pub const {}: [f32; {}] = [", name, values.len()).unwrap(); write!(s, " ").unwrap(); for (i, v) in values.iter().enumerate() { if i > 0 { write!(s, ", ").unwrap(); } // Line-wrap every 8 values for readability. if i > 0 && i % 8 == 0 { write!(s, "\n ").unwrap(); } write!(s, "{:.8}", v).unwrap(); } writeln!(s, "\n];").unwrap(); writeln!(s).unwrap(); } fn write_lean_float_list(s: &mut String, name: &str, values: &[f32]) { write!(s, "def {} : List Float := [", name).unwrap(); for (i, v) in values.iter().enumerate() { if i > 0 { write!(s, ", ").unwrap(); } write!(s, "{:.8}", v).unwrap(); } writeln!(s, "]").unwrap(); writeln!(s).unwrap(); } fn capitalize(s: &str) -> String { let mut c = s.chars(); match c.next() { None => String::new(), Some(first) => first.to_uppercase().to_string() + c.as_str(), } } #[cfg(test)] mod tests { use super::*; fn make_test_model() -> ExportedModel { ExportedModel { model_name: "scanner".to_string(), input_dim: 2, hidden_dim: 2, w1: vec![vec![0.1, 0.2], vec![0.3, 0.4]], b1: vec![0.01, 0.02], w2: vec![0.5, 0.6], b2: -0.1, tree_nodes: vec![ (0, 0.5, 1, 2), (255, 0.0, 0, 0), (255, 1.0, 0, 0), ], threshold: 0.5, norm_mins: vec![0.0, 0.0], norm_maxs: vec![1.0, 10.0], } } #[test] fn test_rust_source_contains_consts() { let model = make_test_model(); let src = generate_rust_source(&model); assert!(src.contains("pub const THRESHOLD: f32 ="), "missing THRESHOLD"); assert!(src.contains("pub const NORM_MINS:"), "missing NORM_MINS"); assert!(src.contains("pub const NORM_MAXS:"), "missing NORM_MAXS"); assert!(src.contains("pub const W1:"), "missing W1"); assert!(src.contains("pub const B1:"), "missing B1"); assert!(src.contains("pub const W2:"), "missing W2"); assert!(src.contains("pub const B2: f32 ="), "missing B2"); assert!(src.contains("pub const TREE_NODES:"), "missing TREE_NODES"); } #[test] fn test_rust_source_array_dims() { let model = make_test_model(); let src = generate_rust_source(&model); // W1 should be [f32; 2]; 2] assert!(src.contains("[[f32; 2]; 2]"), "W1 dimensions wrong"); // B1 should be [f32; 2] assert!(src.contains("B1: [f32; 2]"), "B1 dimensions wrong"); // W2 should be [f32; 2] assert!(src.contains("W2: [f32; 2]"), "W2 dimensions wrong"); // TREE_NODES should have 3 entries assert!( src.contains("[(u8, f32, u16, u16); 3]"), "TREE_NODES count wrong" ); } #[test] fn test_weight_values_roundtrip() { let model = make_test_model(); let src = generate_rust_source(&model); // The threshold should appear with reasonable precision. assert!(src.contains("0.50000000"), "threshold value missing"); // B2 value. assert!(src.contains("-0.10000000"), "b2 value missing"); } #[test] fn test_lean_source_structure() { let model = make_test_model(); let src = generate_lean_source(&model); assert!(src.contains("namespace Sunbeam.Ensemble.Scanner")); assert!(src.contains("def inputDim : Nat := 2")); assert!(src.contains("def hiddenDim : Nat := 2")); assert!(src.contains("def threshold : Float :=")); assert!(src.contains("def normMins : List Float :=")); assert!(src.contains("def w1 : List (List Float) :=")); assert!(src.contains("def treeNodes :")); assert!(src.contains("end Sunbeam.Ensemble.Scanner")); } #[test] fn test_export_to_file() { let model = make_test_model(); let dir = tempfile::tempdir().unwrap(); let path = dir.path().join("test_weights.rs"); export_to_file(&model, &path).unwrap(); let content = std::fs::read_to_string(&path).unwrap(); assert!(content.contains("pub const THRESHOLD:")); assert!(content.contains("pub const W1:")); } }