# Training API Reference

**Module:** `ruvector_tiny_dancer_core::training`

Complete API reference for the FastGRNN training pipeline.

## Core Types

### TrainingConfig

Configuration for training hyperparameters.
```rust
pub struct TrainingConfig {
    pub learning_rate: f32,
    pub batch_size: usize,
    pub epochs: usize,
    pub validation_split: f32,
    pub early_stopping_patience: Option<usize>,
    pub lr_decay: f32,
    pub lr_decay_step: usize,
    pub grad_clip: f32,
    pub adam_beta1: f32,
    pub adam_beta2: f32,
    pub adam_epsilon: f32,
    pub l2_reg: f32,
    pub enable_distillation: bool,
    pub distillation_temperature: f32,
    pub distillation_alpha: f32,
}
```
**Default values:**

- `learning_rate`: 0.001
- `batch_size`: 32
- `epochs`: 100
- `validation_split`: 0.2
- `early_stopping_patience`: Some(10)
- `lr_decay`: 0.5
- `lr_decay_step`: 20
- `grad_clip`: 5.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-8
- `l2_reg`: 1e-5
- `enable_distillation`: false
- `distillation_temperature`: 3.0
- `distillation_alpha`: 0.5
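The `lr_decay`/`lr_decay_step` pair suggests a step-decay schedule. A minimal sketch of the effective rate under that assumption (the trainer's actual schedule is private and may differ):

```rust
// Assumed step-decay schedule: multiply the learning rate by `lr_decay`
// once every `lr_decay_step` epochs. This mirrors the config fields but
// is not confirmed by the source.
fn effective_lr(base_lr: f32, lr_decay: f32, lr_decay_step: usize, epoch: usize) -> f32 {
    base_lr * lr_decay.powi((epoch / lr_decay_step) as i32)
}

// With the defaults (0.001, 0.5, 20): epochs 0-19 -> 0.001,
// epochs 20-39 -> 0.0005, epochs 40-59 -> 0.00025, ...
```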
### TrainingDataset

Training dataset with features and labels.

```rust
pub struct TrainingDataset {
    pub features: Vec<Vec<f32>>,
    pub labels: Vec<f32>,
    pub soft_targets: Option<Vec<f32>>,
}
```
**Methods:**

#### new

```rust
pub fn new(features: Vec<Vec<f32>>, labels: Vec<f32>) -> Result<Self>
```

Create a new training dataset.

**Parameters:**

- `features`: Input features (N × input_dim)
- `labels`: Target labels (N)

**Returns:** `Result<Self>`
**Errors:**

- Returns an error if `features` and `labels` have different lengths
- Returns an error if the dataset is empty
**Example:**

```rust
let features = vec![
    vec![0.8, 0.9, 0.7, 0.85, 0.2],
    vec![0.3, 0.2, 0.4, 0.35, 0.9],
];
let labels = vec![1.0, 0.0];
let dataset = TrainingDataset::new(features, labels)?;
```
#### with_soft_targets

```rust
pub fn with_soft_targets(self, soft_targets: Vec<f32>) -> Result<Self>
```

Add soft targets from a teacher model for knowledge distillation.

**Parameters:**

- `soft_targets`: Soft predictions from the teacher model (N)

**Returns:** `Result<Self>`

**Example:**

```rust
let soft_targets = generate_teacher_predictions(&teacher, &features, 3.0)?;
let dataset = dataset.with_soft_targets(soft_targets)?;
```
#### split

```rust
pub fn split(&self, val_ratio: f32) -> Result<(Self, Self)>
```

Split the dataset into train and validation sets.

**Parameters:**

- `val_ratio`: Validation set ratio (0.0 to 1.0)

**Returns:** `Result<(train_dataset, val_dataset)>`

**Example:**

```rust
let (train, val) = dataset.split(0.2)?; // 80% train, 20% val
```
#### normalize

```rust
pub fn normalize(&mut self) -> Result<(Vec<f32>, Vec<f32>)>
```

Normalize features using z-score normalization.

**Returns:** `Result<(means, stds)>`

**Example:**

```rust
let (means, stds) = dataset.normalize()?;
// Save for inference
save_normalization_params("norm.json", &means, &stds)?;
```
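`save_normalization_params` in the example is not part of this module. A minimal sketch of such a helper using the `serde_json` crate (the name and file format are assumptions):

```rust
use std::fs::File;
use std::io::BufWriter;

// Hypothetical helper (not part of this module): persist the z-score
// parameters so inference can apply the same transform to new inputs.
fn save_normalization_params(path: &str, means: &[f32], stds: &[f32]) -> std::io::Result<()> {
    let file = BufWriter::new(File::create(path)?);
    let params = serde_json::json!({ "means": means, "stds": stds });
    serde_json::to_writer_pretty(file, &params)?; // serde_json::Error converts into io::Error
    Ok(())
}
```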
#### len

```rust
pub fn len(&self) -> usize
```

Get the number of samples in the dataset.

#### is_empty

```rust
pub fn is_empty(&self) -> bool
```

Check if the dataset is empty.
### BatchIterator

Iterator for mini-batch training.

```rust
pub struct BatchIterator<'a> {
    // Private fields
}
```

**Methods:**

#### new

```rust
pub fn new(dataset: &'a TrainingDataset, batch_size: usize, shuffle: bool) -> Self
```

Create a new batch iterator.

**Parameters:**

- `dataset`: Reference to the training dataset
- `batch_size`: Size of each batch
- `shuffle`: Whether to shuffle data

**Example:**

```rust
let batch_iter = BatchIterator::new(&dataset, 32, true);
for (features, labels, soft_targets) in batch_iter {
    // Train on batch
}
```
### TrainingMetrics

Metrics recorded during training.

```rust
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub train_loss: f32,
    pub val_loss: f32,
    pub train_accuracy: f32,
    pub val_accuracy: f32,
    pub learning_rate: f32,
}
```
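Because the struct derives `Serialize`/`Deserialize`, individual records round-trip through JSON with `serde_json`; a short sketch:

```rust
// Round-trip a single epoch's metrics through JSON (requires serde_json).
let m = TrainingMetrics {
    epoch: 1,
    train_loss: 0.42,
    val_loss: 0.47,
    train_accuracy: 0.81,
    val_accuracy: 0.78,
    learning_rate: 0.001,
};
let json = serde_json::to_string_pretty(&m)?;
let restored: TrainingMetrics = serde_json::from_str(&json)?;
assert_eq!(restored.epoch, m.epoch);
```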
### Trainer

Main trainer for FastGRNN models.

```rust
pub struct Trainer {
    // Private fields
}
```

**Methods:**

#### new

```rust
pub fn new(model_config: &FastGRNNConfig, config: TrainingConfig) -> Self
```

Create a new trainer.

**Parameters:**

- `model_config`: Model configuration
- `config`: Training configuration

**Example:**

```rust
let trainer = Trainer::new(&model_config, training_config);
```
#### train

```rust
pub fn train(
    &mut self,
    model: &mut FastGRNN,
    dataset: &TrainingDataset,
) -> Result<Vec<TrainingMetrics>>
```

Train the model on the dataset.

**Parameters:**

- `model`: Mutable reference to the model
- `dataset`: Training dataset

**Returns:** `Result<Vec<TrainingMetrics>>` - metrics for each epoch

**Example:**

```rust
let metrics = trainer.train(&mut model, &dataset)?;
// Print results
for m in &metrics {
    println!("Epoch {}: val_loss={:.4}, val_acc={:.2}%",
        m.epoch, m.val_loss, m.val_accuracy * 100.0);
}
```
#### metrics_history

```rust
pub fn metrics_history(&self) -> &[TrainingMetrics]
```

Get the training metrics history.

**Returns:** Slice of training metrics
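**Example:**

```rust
// Read back the per-epoch metrics accumulated by the trainer so far.
if let Some(last) = trainer.metrics_history().last() {
    println!("Last epoch {}: val_loss={:.4}", last.epoch, last.val_loss);
}
```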
#### save_metrics

```rust
pub fn save_metrics<P: AsRef<Path>>(&self, path: P) -> Result<()>
```

Save training metrics to a JSON file.

**Parameters:**

- `path`: Output file path

**Example:**

```rust
trainer.save_metrics("models/metrics.json")?;
```
## Functions

### binary_cross_entropy

```rust
fn binary_cross_entropy(prediction: f32, target: f32) -> f32
```

Compute binary cross-entropy loss.

**Formula:**

```
BCE = -target * log(pred) - (1 - target) * log(1 - pred)
```

**Parameters:**

- `prediction`: Model prediction (0.0 to 1.0)
- `target`: True label (0.0 or 1.0)

**Returns:** Loss value
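A minimal sketch of the documented formula; the epsilon clamp guarding `ln(0)` is an assumption, and the crate's private implementation may guard differently:

```rust
// Sketch of the BCE formula above. EPS keeps ln() away from 0 and 1.
fn binary_cross_entropy(prediction: f32, target: f32) -> f32 {
    const EPS: f32 = 1e-7;
    let p = prediction.clamp(EPS, 1.0 - EPS);
    -target * p.ln() - (1.0 - target) * (1.0 - p).ln()
}

// binary_cross_entropy(0.9, 1.0) ≈ 0.105; binary_cross_entropy(0.1, 1.0) ≈ 2.303
```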
### temperature_softmax

```rust
pub fn temperature_softmax(logit: f32, temperature: f32) -> f32
```

Temperature-scaled sigmoid for knowledge distillation.

**Parameters:**

- `logit`: Model output logit
- `temperature`: Temperature scaling factor (> 1.0 = softer)

**Returns:** Temperature-scaled probability

**Example:**

```rust
let soft_pred = temperature_softmax(logit, 3.0);
```
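The description ("temperature-scaled sigmoid") suggests the logit is divided by the temperature before the sigmoid; a sketch under that assumption:

```rust
// Assumed behavior: sigmoid applied to logit / temperature. Higher
// temperatures flatten the output toward 0.5 ("softer" targets).
fn temperature_softmax(logit: f32, temperature: f32) -> f32 {
    1.0 / (1.0 + (-logit / temperature).exp())
}

// temperature_softmax(2.0, 1.0) ≈ 0.88; temperature_softmax(2.0, 3.0) ≈ 0.66
```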
### generate_teacher_predictions

```rust
pub fn generate_teacher_predictions(
    teacher: &FastGRNN,
    features: &[Vec<f32>],
    temperature: f32,
) -> Result<Vec<f32>>
```

Generate soft predictions from a teacher model.

**Parameters:**

- `teacher`: Teacher model
- `features`: Input features
- `temperature`: Temperature for softening

**Returns:** `Result<Vec<f32>>` - soft predictions

**Example:**

```rust
let teacher = FastGRNN::load("teacher.safetensors")?;
let soft_targets = generate_teacher_predictions(&teacher, &features, 3.0)?;
```
## Usage Examples

### Basic Training

```rust
use ruvector_tiny_dancer_core::{
    model::{FastGRNN, FastGRNNConfig},
    training::{TrainingConfig, TrainingDataset, Trainer},
};

// Prepare data
let features = vec![/* ... */];
let labels = vec![/* ... */];
let mut dataset = TrainingDataset::new(features, labels)?;
dataset.normalize()?;

// Configure
let model_config = FastGRNNConfig::default();
let training_config = TrainingConfig::default();

// Train
let mut model = FastGRNN::new(model_config.clone())?;
let mut trainer = Trainer::new(&model_config, training_config);
let metrics = trainer.train(&mut model, &dataset)?;

// Save
model.save("model.safetensors")?;
```
### Knowledge Distillation

```rust
use ruvector_tiny_dancer_core::training::generate_teacher_predictions;

// Load teacher
let teacher = FastGRNN::load("teacher.safetensors")?;

// Generate soft targets
let temperature = 3.0;
let soft_targets = generate_teacher_predictions(&teacher, &features, temperature)?;

// Add to dataset
let dataset = dataset.with_soft_targets(soft_targets)?;

// Configure distillation
let training_config = TrainingConfig {
    enable_distillation: true,
    distillation_temperature: temperature,
    distillation_alpha: 0.7,
    ..Default::default()
};

// Train with distillation
let mut trainer = Trainer::new(&model_config, training_config);
trainer.train(&mut model, &dataset)?;
```
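`distillation_alpha` presumably weights the soft (teacher) loss against the hard (label) loss. A sketch of the conventional blend, using the `binary_cross_entropy` sketch above; how the trainer actually combines the two terms is private, so treat this as an assumption:

```rust
// Conventional distillation objective (assumption): alpha weights the
// soft-target term, 1 - alpha the hard-label term. With alpha = 0.7
// above, the teacher's signal dominates.
fn distillation_loss(pred: f32, hard_target: f32, soft_target: f32, alpha: f32) -> f32 {
    let hard = binary_cross_entropy(pred, hard_target);
    let soft = binary_cross_entropy(pred, soft_target);
    alpha * soft + (1.0 - alpha) * hard
}
```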
### Custom Training Loop

```rust
use ruvector_tiny_dancer_core::training::BatchIterator;

for epoch in 0..50 {
    let mut epoch_loss = 0.0;
    let mut n_batches = 0;
    let batch_iter = BatchIterator::new(&train_dataset, 32, true);
    // Soft targets are unused here, hence the underscore binding.
    for (features, labels, _soft_targets) in batch_iter {
        // Your training logic here
        epoch_loss += train_batch(&mut model, &features, &labels);
        n_batches += 1;
    }
    let avg_loss = epoch_loss / n_batches as f32;
    println!("Epoch {}: loss={:.4}", epoch, avg_loss);
}
```
### Progressive Training

```rust
// Start with a high learning rate
let mut config = TrainingConfig {
    learning_rate: 0.1,
    epochs: 20,
    ..Default::default()
};
let mut trainer = Trainer::new(&model_config, config.clone());
trainer.train(&mut model, &dataset)?;

// Continue with a lower learning rate
config.learning_rate = 0.01;
config.epochs = 30;
let mut trainer2 = Trainer::new(&model_config, config);
trainer2.train(&mut model, &dataset)?;
```
## Error Handling

All training functions return `Result<T>` with `TinyDancerError`:

```rust
match trainer.train(&mut model, &dataset) {
    Ok(metrics) => {
        println!("Training successful!");
        println!("Final accuracy: {:.2}%",
            metrics.last().unwrap().val_accuracy * 100.0);
    }
    Err(e) => {
        eprintln!("Training failed: {}", e);
        // Handle error appropriately
    }
}
```

Common errors:

- `InvalidInput`: Invalid dataset, configuration, or parameters
- `SerializationError`: Failed to save/load files
- `IoError`: File I/O errors
## Performance Considerations

### Memory Usage

- Dataset: O(N × input_dim) floats
- Model: ~850 parameters for the default config (16 hidden units)
- Optimizer: 2× model size (Adam first- and second-moment estimates)
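As a back-of-envelope check on the dataset term, assuming 4-byte `f32` features:

```rust
// Rough dataset memory estimate: N samples × input_dim × 4 bytes per f32.
// Per-row Vec overhead is ignored, so real usage is somewhat higher.
fn dataset_bytes(n_samples: usize, input_dim: usize) -> usize {
    n_samples * input_dim * std::mem::size_of::<f32>()
}

// 100K samples × 5 features ≈ 2 MB of raw feature data.
assert_eq!(dataset_bytes(100_000, 5), 2_000_000);
```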
For large datasets (>100K samples), consider:
- Batch processing
- Data streaming
- Memory-mapped files
### Training Speed
Typical training times (CPU):
- Small dataset (1K samples): ~10 seconds
- Medium dataset (10K samples): ~1-2 minutes
- Large dataset (100K samples): ~10-20 minutes
Optimization tips:
- Use larger batch sizes (32-128)
- Enable early stopping
- Use knowledge distillation for faster convergence
### Reproducibility
For reproducible results:
- Set random seed before training
- Use deterministic operations
- Save normalization parameters
- Version control all hyperparameters
```rust
// Set seed (note: full reproducibility requires more work)
use rand::SeedableRng;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
```
## See Also
- Training Guide - Complete training walkthrough
- Model API - FastGRNN model implementation
- Examples - Working code examples