Files
proxy/src/dataset/sample.rs

160 lines
5.7 KiB
Rust
Raw Normal View History

// Copyright Sunbeam Studios 2026
// SPDX-License-Identifier: Apache-2.0
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use anyhow::{Context, Result};
/// Provenance of a training sample.
///
/// Stored in `TrainingSample::source` and used as the key type of
/// `DatasetStats::samples_by_source`, which is why it derives `Eq` and `Hash`.
///
/// NOTE(review): datasets are persisted with bincode (see `save_dataset`).
/// bincode presumably identifies variants by their declaration position, so
/// reordering variants or inserting one anywhere but the end may invalidate
/// previously saved files — confirm before changing this list.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DataSource {
/// Presumably samples derived from this proxy's own production logs — confirm.
ProductionLogs,
/// Presumably the CSIC 2010 HTTP dataset (an external dataset) — confirm.
Csic2010,
/// Presumably samples derived from OWASP ModSecurity material (external) — confirm.
OwaspModSec,
/// Presumably synthetic samples with CIC-style timing features — confirm.
SyntheticCicTiming,
/// Presumably synthetic samples generated from a wordlist — confirm.
SyntheticWordlist,
}
impl std::fmt::Display for DataSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DataSource::ProductionLogs => write!(f, "ProductionLogs"),
DataSource::Csic2010 => write!(f, "Csic2010"),
DataSource::OwaspModSec => write!(f, "OwaspModSec"),
DataSource::SyntheticCicTiming => write!(f, "SyntheticCicTiming"),
DataSource::SyntheticWordlist => write!(f, "SyntheticWordlist"),
}
}
}
/// A single labeled training sample with its feature vector, label, provenance,
/// and weight (used during training to down-weight synthetic/external data).
///
/// This type performs no validation: any `f32` is accepted for `label` and
/// `weight`, and `features` may have any length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingSample {
/// Feature vector for this sample. The length is not fixed by this type;
/// presumably it depends on which model (scanner vs. DDoS) the sample
/// feeds — confirm against the feature extractors.
pub features: Vec<f32>,
/// 0.0 = normal, 1.0 = attack
pub label: f32,
/// Where this sample came from.
pub source: DataSource,
/// Sample weight: 1.0 for production, 0.8 for external datasets, 0.5 for synthetic.
pub weight: f32,
}
/// Aggregate statistics about a prepared dataset.
///
/// These are plain stored values: nothing visible in this file computes them
/// or checks them against the sample lists held by `DatasetManifest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
/// Total number of samples; presumably `scanner_samples + ddos_samples`
/// — confirm against the code that builds the stats.
pub total_samples: usize,
/// Number of scanner samples.
pub scanner_samples: usize,
/// Number of DDoS samples.
pub ddos_samples: usize,
/// Sample count per `DataSource`. The map may be empty or omit sources.
/// NOTE(review): whether counts cover both sample sets is not visible here.
pub samples_by_source: HashMap<DataSource, usize>,
/// Presumably the fraction of scanner samples labeled as attack, in
/// `[0.0, 1.0]` — confirm.
pub attack_ratio_scanner: f64,
/// Presumably the fraction of DDoS samples labeled as attack, in
/// `[0.0, 1.0]` — confirm.
pub attack_ratio_ddos: f64,
}
/// The full serializable dataset: scanner samples, DDoS samples, and stats.
///
/// This is the root value written by `save_dataset` and returned by
/// `load_dataset`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetManifest {
/// Samples for the scanner model.
pub scanner_samples: Vec<TrainingSample>,
/// Samples for the DDoS model.
pub ddos_samples: Vec<TrainingSample>,
/// Summary statistics stored alongside the samples; not recomputed on load.
pub stats: DatasetStats,
}
/// Writes `manifest` to `path` in bincode format.
///
/// The manifest is encoded fully in memory first and then handed to a single
/// `std::fs::write` call.
///
/// # Errors
///
/// Returns an error with context `"serializing dataset manifest"` if encoding
/// fails, or `"writing dataset to <path>"` if the file cannot be written.
pub fn save_dataset(manifest: &DatasetManifest, path: &Path) -> Result<()> {
    bincode::serialize(manifest)
        .context("serializing dataset manifest")
        .and_then(|bytes| {
            std::fs::write(path, bytes)
                .with_context(|| format!("writing dataset to {}", path.display()))
        })
}
/// Reads the file at `path` and decodes it as a bincode `DatasetManifest`.
///
/// The whole file is read into memory before decoding.
///
/// # Errors
///
/// Returns an error with context `"reading dataset from <path>"` if the file
/// cannot be read, or `"deserializing dataset manifest"` if its contents do
/// not decode as a `DatasetManifest`.
pub fn load_dataset(path: &Path) -> Result<DatasetManifest> {
    let bytes = std::fs::read(path)
        .with_context(|| format!("reading dataset from {}", path.display()))?;
    // The target type is taken from this function's return type.
    bincode::deserialize(&bytes).context("deserializing dataset manifest")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor that keeps each fixture sample on one line.
    fn sample(features: Vec<f32>, label: f32, source: DataSource, weight: f32) -> TrainingSample {
        TrainingSample {
            features,
            label,
            source,
            weight,
        }
    }

    /// An in-memory bincode encode/decode keeps counts, stats, feature values,
    /// labels and sources intact.
    #[test]
    fn test_bincode_roundtrip() {
        let by_source: HashMap<DataSource, usize> = vec![
            (DataSource::ProductionLogs, 2),
            (DataSource::Csic2010, 1),
            (DataSource::SyntheticCicTiming, 1),
        ]
        .into_iter()
        .collect();

        let scanner = vec![
            sample(vec![0.1, 0.2, 0.3], 0.0, DataSource::ProductionLogs, 1.0),
            sample(vec![0.9, 0.8, 0.7], 1.0, DataSource::Csic2010, 0.8),
        ];
        let ddos = vec![
            sample(vec![0.5; 14], 1.0, DataSource::SyntheticCicTiming, 0.5),
            sample(vec![0.1; 14], 0.0, DataSource::ProductionLogs, 1.0),
        ];

        let original = DatasetManifest {
            scanner_samples: scanner,
            ddos_samples: ddos,
            stats: DatasetStats {
                total_samples: 4,
                scanner_samples: 2,
                ddos_samples: 2,
                samples_by_source: by_source,
                attack_ratio_scanner: 0.5,
                attack_ratio_ddos: 0.5,
            },
        };

        let bytes = bincode::serialize(&original).unwrap();
        let restored: DatasetManifest = bincode::deserialize(&bytes).unwrap();

        assert_eq!(restored.scanner_samples.len(), 2);
        assert_eq!(restored.ddos_samples.len(), 2);
        assert_eq!(restored.stats.total_samples, 4);
        assert_eq!(restored.stats.samples_by_source[&DataSource::ProductionLogs], 2);

        // Verify feature values survive the roundtrip.
        let first_scanner = &restored.scanner_samples[0];
        assert!((first_scanner.features[0] - 0.1).abs() < 1e-6);
        assert_eq!(restored.scanner_samples[1].label, 1.0);
        assert_eq!(restored.ddos_samples[0].source, DataSource::SyntheticCicTiming);
    }

    /// `save_dataset` followed by `load_dataset` on a temporary file returns
    /// the same samples that were written.
    #[test]
    fn test_save_load_roundtrip() {
        let stats = DatasetStats {
            total_samples: 2,
            scanner_samples: 1,
            ddos_samples: 1,
            samples_by_source: HashMap::new(),
            attack_ratio_scanner: 0.0,
            attack_ratio_ddos: 1.0,
        };
        let original = DatasetManifest {
            scanner_samples: vec![sample(vec![1.0, 2.0], 0.0, DataSource::ProductionLogs, 1.0)],
            ddos_samples: vec![sample(vec![3.0, 4.0], 1.0, DataSource::OwaspModSec, 0.8)],
            stats,
        };

        // Keep the directory guard alive until the end of the test so the
        // file still exists when it is read back.
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("test_dataset.bin");

        save_dataset(&original, &file).unwrap();
        let restored = load_dataset(&file).unwrap();

        assert_eq!(restored.scanner_samples.len(), 1);
        assert_eq!(restored.ddos_samples.len(), 1);
        assert!((restored.scanner_samples[0].features[0] - 1.0).abs() < 1e-6);
        assert_eq!(restored.ddos_samples[0].source, DataSource::OwaspModSec);
    }
}