// Copyright Sunbeam Studios 2026 // SPDX-License-Identifier: Apache-2.0 //! Shared training infrastructure: Dataset adapter, Batcher, and batch types //! for use with burn's SupervisedTraining. use crate::dataset::sample::TrainingSample; use burn::data::dataloader::batcher::Batcher; use burn::data::dataloader::Dataset; use burn::prelude::*; /// A single normalized training item ready for batching. #[derive(Clone, Debug)] pub struct TrainingItem { pub features: Vec, pub label: i32, pub weight: f32, } /// A batch of training items as tensors. #[derive(Clone, Debug)] pub struct TrainingBatch { pub features: Tensor, pub labels: Tensor, pub weights: Tensor, } /// Wraps a `Vec` as a burn `Dataset`, applying min-max /// normalization to features at construction time. #[derive(Clone)] pub struct SampleDataset { items: Vec, } impl SampleDataset { pub fn new(samples: &[TrainingSample], mins: &[f32], maxs: &[f32]) -> Self { let items = samples .iter() .map(|s| { let features: Vec = s .features .iter() .enumerate() .map(|(i, &v)| { let range = maxs[i] - mins[i]; if range > f32::EPSILON { ((v - mins[i]) / range).clamp(0.0, 1.0) } else { 0.0 } }) .collect(); TrainingItem { features, label: if s.label >= 0.5 { 1 } else { 0 }, weight: s.weight, } }) .collect(); Self { items } } } impl Dataset for SampleDataset { fn get(&self, index: usize) -> Option { self.items.get(index).cloned() } fn len(&self) -> usize { self.items.len() } } /// Converts a `Vec` into a `TrainingBatch` of tensors. #[derive(Clone)] pub struct SampleBatcher; impl SampleBatcher { pub fn new() -> Self { Self } } impl Batcher> for SampleBatcher { fn batch(&self, items: Vec, device: &B::Device) -> TrainingBatch { let batch_size = items.len(); let num_features = items[0].features.len(); let flat_features: Vec = items .iter() .flat_map(|item| item.features.iter().copied()) .collect(); let labels: Vec = items.iter().map(|item| item.label).collect(); let weights: Vec = items.iter().map(|item| item.weight).collect(); let features = Tensor::::from_floats(flat_features.as_slice(), device) .reshape([batch_size, num_features]); let labels = Tensor::::from_ints(labels.as_slice(), device); let weights = Tensor::::from_floats(weights.as_slice(), device) .reshape([batch_size, 1]); TrainingBatch { features, labels, weights, } } }