# Training API Reference

## Module: `ruvector_tiny_dancer_core::training`

Complete API reference for the FastGRNN training pipeline.

## Core Types

### TrainingConfig

Configuration for training hyperparameters.

```rust
pub struct TrainingConfig {
    pub learning_rate: f32,
    pub batch_size: usize,
    pub epochs: usize,
    pub validation_split: f32,
    pub early_stopping_patience: Option<usize>,
    pub lr_decay: f32,
    pub lr_decay_step: usize,
    pub grad_clip: f32,
    pub adam_beta1: f32,
    pub adam_beta2: f32,
    pub adam_epsilon: f32,
    pub l2_reg: f32,
    pub enable_distillation: bool,
    pub distillation_temperature: f32,
    pub distillation_alpha: f32,
}
```

**Default values:**
- `learning_rate`: 0.001
- `batch_size`: 32
- `epochs`: 100
- `validation_split`: 0.2
- `early_stopping_patience`: Some(10)
- `lr_decay`: 0.5
- `lr_decay_step`: 20
- `grad_clip`: 5.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-8
- `l2_reg`: 1e-5
- `enable_distillation`: false
- `distillation_temperature`: 3.0
- `distillation_alpha`: 0.5
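
Only the fields you care about need to be set explicitly; the struct-update syntax used in the examples later in this document fills in the rest from the defaults above (this assumes `TrainingConfig` implements `Default`, which the default values imply):

```rust
// Override a few hyperparameters; everything else keeps its default.
let config = TrainingConfig {
    learning_rate: 0.01,
    epochs: 50,
    ..Default::default()
};
```
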
### TrainingDataset

Training dataset with features and labels.

```rust
pub struct TrainingDataset {
    pub features: Vec<Vec<f32>>,
    pub labels: Vec<f32>,
    pub soft_targets: Option<Vec<f32>>,
}
```

**Methods:**

#### `new`
```rust
pub fn new(features: Vec<Vec<f32>>, labels: Vec<f32>) -> Result<Self>
```
Create a new training dataset.

**Parameters:**
- `features`: Input features (N × input_dim)
- `labels`: Target labels (N)

**Returns:** `Result<TrainingDataset>`

**Errors:**
- Returns an error if `features` and `labels` have different lengths
- Returns an error if the dataset is empty

**Example:**
```rust
let features = vec![
    vec![0.8, 0.9, 0.7, 0.85, 0.2],
    vec![0.3, 0.2, 0.4, 0.35, 0.9],
];
let labels = vec![1.0, 0.0];
let dataset = TrainingDataset::new(features, labels)?;
```

#### `with_soft_targets`
```rust
pub fn with_soft_targets(self, soft_targets: Vec<f32>) -> Result<Self>
```
Add soft targets from a teacher model for knowledge distillation.

**Parameters:**
- `soft_targets`: Soft predictions from the teacher model (N)

**Returns:** `Result<TrainingDataset>`

**Example:**
```rust
let soft_targets = generate_teacher_predictions(&teacher, &features, 3.0)?;
let dataset = dataset.with_soft_targets(soft_targets)?;
```

#### `split`
```rust
pub fn split(&self, val_ratio: f32) -> Result<(Self, Self)>
```
Split the dataset into train and validation sets.

**Parameters:**
- `val_ratio`: Validation set ratio (0.0 to 1.0)

**Returns:** `Result<(train_dataset, val_dataset)>`

**Example:**
```rust
let (train, val) = dataset.split(0.2)?; // 80% train, 20% val
```

#### `normalize`
```rust
pub fn normalize(&mut self) -> Result<(Vec<f32>, Vec<f32>)>
```
Normalize features using z-score normalization.

**Returns:** `Result<(means, stds)>`

**Example:**
```rust
let (means, stds) = dataset.normalize()?;
// Save for inference
save_normalization_params("norm.json", &means, &stds)?;
```
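
Note that `save_normalization_params` above is not part of this module; a minimal sketch of such a helper, assuming `serde_json` is available and using an illustrative JSON layout:

```rust
use std::fs;
use serde_json::json;

// Hypothetical helper: persist the z-score parameters so inference can
// apply the same transform. The file layout here is illustrative.
fn save_normalization_params(
    path: &str,
    means: &[f32],
    stds: &[f32],
) -> Result<(), Box<dyn std::error::Error>> {
    let params = json!({ "means": means, "stds": stds });
    fs::write(path, serde_json::to_string_pretty(&params)?)?;
    Ok(())
}
```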

#### `len`
```rust
pub fn len(&self) -> usize
```
Get the number of samples in the dataset.

#### `is_empty`
```rust
pub fn is_empty(&self) -> bool
```
Check if the dataset is empty.
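
Both are useful as guards before training:

```rust
// Guard against an empty dataset before spending time on training.
assert!(!dataset.is_empty(), "dataset must not be empty");
println!("training on {} samples", dataset.len());
```
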
### BatchIterator

Iterator for mini-batch training.

```rust
pub struct BatchIterator<'a> {
    // Private fields
}
```

**Methods:**

#### `new`
```rust
pub fn new(dataset: &'a TrainingDataset, batch_size: usize, shuffle: bool) -> Self
```
Create a new batch iterator.

**Parameters:**
- `dataset`: Reference to the training dataset
- `batch_size`: Size of each batch
- `shuffle`: Whether to shuffle the data

**Example:**
```rust
let batch_iter = BatchIterator::new(&dataset, 32, true);
for (features, labels, soft_targets) in batch_iter {
    // Train on batch
}
```

### TrainingMetrics

Metrics recorded during training.

```rust
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub train_loss: f32,
    pub val_loss: f32,
    pub train_accuracy: f32,
    pub val_accuracy: f32,
    pub learning_rate: f32,
}
```
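
Because the fields are public, post-training analysis is straightforward; for example, picking the epoch with the lowest validation loss from the history returned by `Trainer::train` (a sketch using only the fields declared above):

```rust
// Find the epoch with the best (lowest) validation loss.
let best = metrics
    .iter()
    .min_by(|a, b| a.val_loss.partial_cmp(&b.val_loss).unwrap())
    .expect("metrics history is empty");
println!("best epoch: {} (val_loss={:.4})", best.epoch, best.val_loss);
```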

### Trainer

Main trainer for FastGRNN models.

```rust
pub struct Trainer {
    // Private fields
}
```

**Methods:**

#### `new`
```rust
pub fn new(model_config: &FastGRNNConfig, config: TrainingConfig) -> Self
```
Create a new trainer.

**Parameters:**
- `model_config`: Model configuration
- `config`: Training configuration

**Example:**
```rust
let trainer = Trainer::new(&model_config, training_config);
```

#### `train`
```rust
pub fn train(
    &mut self,
    model: &mut FastGRNN,
    dataset: &TrainingDataset,
) -> Result<Vec<TrainingMetrics>>
```
Train the model on the dataset.

**Parameters:**
- `model`: Mutable reference to the model
- `dataset`: Training dataset

**Returns:** `Result<Vec<TrainingMetrics>>` - Metrics for each epoch

**Example:**
```rust
let metrics = trainer.train(&mut model, &dataset)?;

// Print results
for m in &metrics {
    println!("Epoch {}: val_loss={:.4}, val_acc={:.2}%",
        m.epoch, m.val_loss, m.val_accuracy * 100.0);
}
```

#### `metrics_history`
```rust
pub fn metrics_history(&self) -> &[TrainingMetrics]
```
Get training metrics history.

**Returns:** Slice of training metrics
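
For example, to inspect the most recent epoch from the recorded history:

```rust
let history = trainer.metrics_history();
if let Some(last) = history.last() {
    println!("latest val_acc: {:.2}%", last.val_accuracy * 100.0);
}
```
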
#### `save_metrics`
```rust
pub fn save_metrics<P: AsRef<Path>>(&self, path: P) -> Result<()>
```
Save training metrics to a JSON file.

**Parameters:**
- `path`: Output file path

**Example:**
```rust
trainer.save_metrics("models/metrics.json")?;
```
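
Because `TrainingMetrics` derives `Serialize`/`Deserialize`, a saved file can be read back for later analysis; this sketch assumes the file contains the history as a JSON array (the layout a `Vec<TrainingMetrics>` serializes to):

```rust
use std::fs::File;
use ruvector_tiny_dancer_core::training::TrainingMetrics;

let file = File::open("models/metrics.json")?;
let history: Vec<TrainingMetrics> = serde_json::from_reader(file)?;
println!("loaded metrics for {} epochs", history.len());
```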

## Functions

### binary_cross_entropy
```rust
fn binary_cross_entropy(prediction: f32, target: f32) -> f32
```
Compute binary cross-entropy loss.

**Formula:**
```
BCE = -target * log(pred) - (1 - target) * log(1 - pred)
```

**Parameters:**
- `prediction`: Model prediction (0.0 to 1.0)
- `target`: True label (0.0 or 1.0)

**Returns:** Loss value
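
The crate's implementation is private, but a numerically safe version of the formula above might look like this sketch (the epsilon clamp is an assumption, added so `ln()` never sees 0):

```rust
// Binary cross-entropy with the prediction clamped away from 0 and 1
// so the logarithms stay finite.
fn binary_cross_entropy(prediction: f32, target: f32) -> f32 {
    let p = prediction.clamp(1e-7, 1.0 - 1e-7);
    -target * p.ln() - (1.0 - target) * (1.0 - p).ln()
}
```
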
### temperature_softmax
```rust
pub fn temperature_softmax(logit: f32, temperature: f32) -> f32
```
Temperature-scaled sigmoid for knowledge distillation.

**Parameters:**
- `logit`: Model output logit
- `temperature`: Temperature scaling factor (> 1.0 = softer)

**Returns:** Temperature-scaled probability

**Example:**
```rust
let soft_pred = temperature_softmax(logit, 3.0);
```
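
Despite the "softmax" in the name, on a single logit this is a temperature-scaled sigmoid; a plausible sketch of the behavior described above (the actual body may differ):

```rust
// Dividing the logit by T > 1 flattens the sigmoid, yielding softer
// probabilities for use as distillation targets.
fn temperature_sigmoid(logit: f32, temperature: f32) -> f32 {
    1.0 / (1.0 + (-logit / temperature).exp())
}
```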

### generate_teacher_predictions
```rust
pub fn generate_teacher_predictions(
    teacher: &FastGRNN,
    features: &[Vec<f32>],
    temperature: f32,
) -> Result<Vec<f32>>
```
Generate soft predictions from a teacher model.

**Parameters:**
- `teacher`: Teacher model
- `features`: Input features
- `temperature`: Temperature for softening

**Returns:** `Result<Vec<f32>>` - Soft predictions

**Example:**
```rust
let teacher = FastGRNN::load("teacher.safetensors")?;
let soft_targets = generate_teacher_predictions(&teacher, &features, 3.0)?;
```

## Usage Examples

### Basic Training

```rust
use ruvector_tiny_dancer_core::{
    model::{FastGRNN, FastGRNNConfig},
    training::{TrainingConfig, TrainingDataset, Trainer},
};

// Prepare data
let features = vec![/* ... */];
let labels = vec![/* ... */];
let mut dataset = TrainingDataset::new(features, labels)?;
dataset.normalize()?;

// Configure
let model_config = FastGRNNConfig::default();
let training_config = TrainingConfig::default();

// Train
let mut model = FastGRNN::new(model_config.clone())?;
let mut trainer = Trainer::new(&model_config, training_config);
let metrics = trainer.train(&mut model, &dataset)?;

// Save
model.save("model.safetensors")?;
```

### Knowledge Distillation

```rust
use ruvector_tiny_dancer_core::training::generate_teacher_predictions;

// Load teacher
let teacher = FastGRNN::load("teacher.safetensors")?;

// Generate soft targets
let temperature = 3.0;
let soft_targets = generate_teacher_predictions(&teacher, &features, temperature)?;

// Add to dataset
let dataset = dataset.with_soft_targets(soft_targets)?;

// Configure distillation
let training_config = TrainingConfig {
    enable_distillation: true,
    distillation_temperature: temperature,
    distillation_alpha: 0.7,
    ..Default::default()
};

// Train with distillation
let mut trainer = Trainer::new(&model_config, training_config);
trainer.train(&mut model, &dataset)?;
```

### Custom Training Loop

```rust
use ruvector_tiny_dancer_core::training::BatchIterator;

for epoch in 0..50 {
    let mut epoch_loss = 0.0;
    let mut n_batches = 0;

    let batch_iter = BatchIterator::new(&train_dataset, 32, true);
    for (features, labels, _soft_targets) in batch_iter {
        // Your training logic here
        epoch_loss += train_batch(&mut model, &features, &labels);
        n_batches += 1;
    }

    let avg_loss = epoch_loss / n_batches as f32;
    println!("Epoch {}: loss={:.4}", epoch, avg_loss);
}
```

### Progressive Training

```rust
// Start with high LR
let mut config = TrainingConfig {
    learning_rate: 0.1,
    epochs: 20,
    ..Default::default()
};

let mut trainer = Trainer::new(&model_config, config.clone());
trainer.train(&mut model, &dataset)?;

// Continue with lower LR
config.learning_rate = 0.01;
config.epochs = 30;

let mut trainer2 = Trainer::new(&model_config, config);
trainer2.train(&mut model, &dataset)?;
```

## Error Handling

All training functions return `Result<T>` with `TinyDancerError`:

```rust
match trainer.train(&mut model, &dataset) {
    Ok(metrics) => {
        println!("Training successful!");
        println!("Final accuracy: {:.2}%",
            metrics.last().unwrap().val_accuracy * 100.0);
    }
    Err(e) => {
        eprintln!("Training failed: {}", e);
        // Handle error appropriately
    }
}
```

Common errors:
- `InvalidInput`: Invalid dataset, configuration, or parameters
- `SerializationError`: Failed to save/load files
- `IoError`: File I/O errors

## Performance Considerations

### Memory Usage

- **Dataset**: O(N × input_dim) floats
- **Model**: ~850 parameters for the default config (16 hidden units)
- **Optimizer**: 2× model size (Adam keeps first- and second-moment buffers)

For large datasets (>100K samples), consider the following; a minimal streaming sketch follows the list:
- Batch processing
- Data streaming
- Memory-mapped files
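
A sketch of the data-streaming option, assuming a simple CSV-like file of comma-separated feature rows (the file format and chunk handling here are illustrative, not part of this crate):

```rust
use std::fs::File;
use std::io::{BufRead, BufReader};

// Stream feature rows in fixed-size chunks instead of materializing all
// N × input_dim floats at once.
fn for_each_chunk(path: &str, chunk: usize) -> std::io::Result<()> {
    let reader = BufReader::new(File::open(path)?);
    let mut buf: Vec<Vec<f32>> = Vec::with_capacity(chunk);
    for line in reader.lines() {
        let row: Vec<f32> = line?
            .split(',')
            .filter_map(|v| v.trim().parse().ok())
            .collect();
        buf.push(row);
        if buf.len() == chunk {
            // Build a TrainingDataset from `buf` and train on it here.
            buf.clear();
        }
    }
    Ok(())
}
```
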
### Training Speed

Typical training times (CPU):
- Small dataset (1K samples): ~10 seconds
- Medium dataset (10K samples): ~1-2 minutes
- Large dataset (100K samples): ~10-20 minutes

Optimization tips:
- Use larger batch sizes (32-128)
- Enable early stopping
- Use knowledge distillation for faster convergence

### Reproducibility

For reproducible results:
1. Set a random seed before training
2. Use deterministic operations
3. Save normalization parameters
4. Version-control all hyperparameters

```rust
// Set seed (note: full reproducibility requires more work)
use rand::SeedableRng;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
```

## See Also

- [Training Guide](./training-guide.md) - Complete training walkthrough
- [Model API](../src/model.rs) - FastGRNN model implementation
- [Examples](../examples/train-model.rs) - Working code examples