feat: complete ensemble integration and remove legacy model code
- Remove legacy KNN DDoS replay and scanner model file watcher
- Wire ensemble inference into detector check() paths
- Update config: remove model_path/k/poll_interval_secs, add observe_only
- Add cookie_weight sweep CLI command for hyperparameter exploration
- Update training pipeline: batch iterator, weight export improvements
- Retrain ensemble weights (scanner 99.73%, DDoS 99.99% val accuracy)
- Add unified audit log module
- Update dataset parsers with copyright headers and minor fixes

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
112
src/training/batch.rs
Normal file
112
src/training/batch.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright Sunbeam Studios 2026
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! Shared training infrastructure: Dataset adapter, Batcher, and batch types
|
||||
//! for use with burn's SupervisedTraining.
|
||||
|
||||
use crate::dataset::sample::TrainingSample;
|
||||
|
||||
use burn::data::dataloader::batcher::Batcher;
|
||||
use burn::data::dataloader::Dataset;
|
||||
use burn::prelude::*;
|
||||
|
||||
/// A single normalized training item ready for batching.
///
/// Produced by [`SampleDataset::new`], which min-max normalizes raw
/// `TrainingSample` features before wrapping them in this type.
#[derive(Clone, Debug)]
pub struct TrainingItem {
    /// Feature vector, each component normalized into `[0.0, 1.0]`.
    pub features: Vec<f32>,
    /// Binary class label: 0 or 1.
    pub label: i32,
    /// Per-sample loss weight carried through from the source sample.
    pub weight: f32,
}
|
||||
|
||||
/// A batch of training items as tensors.
///
/// Built by [`SampleBatcher::batch`] from a `Vec<TrainingItem>`.
#[derive(Clone, Debug)]
pub struct TrainingBatch<B: Backend> {
    /// Feature matrix of shape `[batch_size, num_features]`.
    pub features: Tensor<B, 2>,
    /// Class labels of shape `[batch_size]`.
    pub labels: Tensor<B, 1, Int>,
    /// Per-sample weights as a column vector of shape `[batch_size, 1]`.
    pub weights: Tensor<B, 2>,
}
|
||||
|
||||
/// Wraps a `Vec<TrainingSample>` as a burn `Dataset`, applying min-max
/// normalization to features at construction time.
#[derive(Clone)]
pub struct SampleDataset {
    // Items are fully normalized up front, so `get()` is a plain clone
    // with no per-access math.
    items: Vec<TrainingItem>,
}
|
||||
|
||||
impl SampleDataset {
|
||||
pub fn new(samples: &[TrainingSample], mins: &[f32], maxs: &[f32]) -> Self {
|
||||
let items = samples
|
||||
.iter()
|
||||
.map(|s| {
|
||||
let features: Vec<f32> = s
|
||||
.features
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &v)| {
|
||||
let range = maxs[i] - mins[i];
|
||||
if range > f32::EPSILON {
|
||||
((v - mins[i]) / range).clamp(0.0, 1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
TrainingItem {
|
||||
features,
|
||||
label: if s.label >= 0.5 { 1 } else { 0 },
|
||||
weight: s.weight,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
Self { items }
|
||||
}
|
||||
}
|
||||
|
||||
impl Dataset<TrainingItem> for SampleDataset {
    /// Returns a clone of the item at `index`, or `None` when out of range.
    /// Cloning is required because burn's dataloader takes items by value.
    fn get(&self, index: usize) -> Option<TrainingItem> {
        self.items.get(index).cloned()
    }

    /// Total number of items in the dataset.
    fn len(&self) -> usize {
        self.items.len()
    }
}
|
||||
|
||||
/// Converts a `Vec<TrainingItem>` into a `TrainingBatch` of tensors.
///
/// Stateless unit struct; all work happens in the `Batcher` impl.
#[derive(Clone)]
pub struct SampleBatcher;
|
||||
|
||||
impl SampleBatcher {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Batcher<B, TrainingItem, TrainingBatch<B>> for SampleBatcher {
|
||||
fn batch(&self, items: Vec<TrainingItem>, device: &B::Device) -> TrainingBatch<B> {
|
||||
let batch_size = items.len();
|
||||
let num_features = items[0].features.len();
|
||||
|
||||
let flat_features: Vec<f32> = items
|
||||
.iter()
|
||||
.flat_map(|item| item.features.iter().copied())
|
||||
.collect();
|
||||
|
||||
let labels: Vec<i32> = items.iter().map(|item| item.label).collect();
|
||||
let weights: Vec<f32> = items.iter().map(|item| item.weight).collect();
|
||||
|
||||
let features = Tensor::<B, 1>::from_floats(flat_features.as_slice(), device)
|
||||
.reshape([batch_size, num_features]);
|
||||
|
||||
let labels = Tensor::<B, 1, Int>::from_ints(labels.as_slice(), device);
|
||||
|
||||
let weights = Tensor::<B, 1>::from_floats(weights.as_slice(), device)
|
||||
.reshape([batch_size, 1]);
|
||||
|
||||
TrainingBatch {
|
||||
features,
|
||||
labels,
|
||||
weights,
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user