Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
313
crates/ruvector-tiny-dancer-core/examples/train-model.rs
Normal file
@@ -0,0 +1,313 @@
//! Example: Training a FastGRNN model for routing decisions
//!
//! This example demonstrates:
//! - Synthetic data generation for routing tasks
//! - Training a FastGRNN model with validation
//! - Knowledge distillation from a teacher model
//! - Early stopping and learning rate scheduling
//! - Model evaluation and saving
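//!
//! Run it from this crate's directory with `cargo run --example train-model`
//! (assuming the standard Cargo `examples/` layout this file lives in).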

use rand::Rng;
use ruvector_tiny_dancer_core::{
    model::{FastGRNN, FastGRNNConfig},
    training::{generate_teacher_predictions, Trainer, TrainingConfig, TrainingDataset},
    Result,
};
use std::path::PathBuf;

fn main() -> Result<()> {
    println!("=== FastGRNN Training Example ===\n");

    // 1. Generate synthetic training data
    println!("Generating synthetic training data...");
    let (features, labels) = generate_synthetic_data(1000);
    let mut dataset = TrainingDataset::new(features, labels)?;

    // Normalize features
    println!("Normalizing features...");
    let (means, stds) = dataset.normalize()?;
    println!("Feature means: {:?}", means);
    println!("Feature stds: {:?}\n", stds);

    // 2. Create model configuration
    let model_config = FastGRNNConfig {
        input_dim: 5,  // the five routing features (see generate_synthetic_data)
        hidden_dim: 16,
        output_dim: 1, // single routing score
        nu: 0.8,
        zeta: 1.2,
        rank: Some(8), // low-rank factorization to shrink the weight matrices
    };
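    // For reference, FastGRNN's update (Kusupati et al., 2018) is:
    //   z_t  = sigmoid(W x_t + U h_{t-1} + b_z)
    //   h~_t = tanh(W x_t + U h_{t-1} + b_h)
    //   h_t  = (zeta * (1 - z_t) + nu) .* h~_t + z_t .* h_{t-1}
    // which is where `nu` and `zeta` above enter; `rank` factorizes W and U.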

    // 3. Create and initialize model
    println!("Creating FastGRNN model...");
    let mut model = FastGRNN::new(model_config.clone())?;
    println!("Model size: {} bytes\n", model.size_bytes());

    // 4. Optional: knowledge distillation setup
    println!("Setting up knowledge distillation...");
    let teacher_model = create_pretrained_teacher(&model_config)?;
    let temperature = 3.0;
    let soft_targets =
        generate_teacher_predictions(&teacher_model, &dataset.features, temperature)?;
    dataset = dataset.with_soft_targets(soft_targets)?;
    println!("Generated soft targets from teacher model\n");

    // 5. Configure training
    let training_config = TrainingConfig {
        learning_rate: 0.01,
        batch_size: 32,
        epochs: 50,
        validation_split: 0.2,
        early_stopping_patience: Some(5), // stop after 5 epochs without val improvement
        lr_decay: 0.8,                    // multiply the learning rate by 0.8...
        lr_decay_step: 10,                // ...every 10 epochs
        grad_clip: 5.0,
        adam_beta1: 0.9,
        adam_beta2: 0.999,
        adam_epsilon: 1e-8,
        l2_reg: 1e-4,
        enable_distillation: true,
        distillation_temperature: temperature,
        distillation_alpha: 0.7, // weight on the soft-target loss vs. hard labels
    };
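    // With distillation enabled, the training loss is presumably a blend of
    //   alpha * soft-target loss + (1 - alpha) * hard-label loss
    // (the exact formulation inside the crate is an assumption here).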

    // 6. Create trainer and train model
    println!("Starting training...\n");
    let mut trainer = Trainer::new(&model_config, training_config);
    let metrics = trainer.train(&mut model, &dataset)?;

    // 7. Print training summary
    println!("\n=== Training Summary ===");
    println!("Total epochs: {}", metrics.len());
    if let Some(last_metrics) = metrics.last() {
        println!("Final train loss: {:.4}", last_metrics.train_loss);
        println!("Final val loss: {:.4}", last_metrics.val_loss);
        println!(
            "Final train accuracy: {:.2}%",
            last_metrics.train_accuracy * 100.0
        );
        println!(
            "Final val accuracy: {:.2}%",
            last_metrics.val_accuracy * 100.0
        );
    }

    // 8. Find best epoch
    // `total_cmp` avoids the panic that `partial_cmp(...).unwrap()` would hit
    // if a validation loss were NaN.
    if let Some(best) = metrics.iter().min_by(|a, b| a.val_loss.total_cmp(&b.val_loss)) {
        println!(
            "\nBest validation loss: {:.4} at epoch {}",
            best.val_loss,
            best.epoch + 1
        );
        println!(
            "Best validation accuracy: {:.2}%",
            best.val_accuracy * 100.0
        );
    }

    // 9. Test inference on sample data
    println!("\n=== Testing Inference ===");
    test_inference(&model)?;

    // 10. Save model and metrics
    println!("\n=== Saving Model ===");
    let model_path = PathBuf::from("models/fastgrnn_trained.safetensors");
    let metrics_path = PathBuf::from("models/training_metrics.json");

    // Create the models directory if it doesn't exist; an error here is
    // deliberately ignored because `save` below will surface any real I/O problem.
    std::fs::create_dir_all("models").ok();

    model.save(&model_path)?;
    trainer.save_metrics(&metrics_path)?;

    println!("Model saved to: {}", model_path.display());
    println!("Metrics saved to: {}", metrics_path.display());

    // 11. Demonstrate model optimization
    println!("\n=== Model Optimization ===");
    let original_size = model.size_bytes();
    println!("Original model size: {} bytes", original_size);

    model.quantize()?;
    let quantized_size = model.size_bytes();
    println!("Quantized model size: {} bytes", quantized_size);
    println!(
        "Size reduction: {:.1}%",
        (1.0 - quantized_size as f32 / original_size as f32) * 100.0
    );

    println!("\n=== Training Complete ===");

    Ok(())
}
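
/// A minimal sketch of the temperature-scaled squashing commonly used to
/// produce soft targets for distillation. Illustrative only; whether
/// `generate_teacher_predictions` does exactly this is an assumption.
#[allow(dead_code)]
fn soften_logit(logit: f32, temperature: f32) -> f32 {
    // Dividing the logit by T > 1 flattens the sigmoid, so soft targets sit
    // further from 0/1 and carry the teacher's relative confidence.
    1.0 / (1.0 + (-logit / temperature).exp())
}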

/// Generate synthetic training data for routing decisions
///
/// Features represent:
/// - [0]: Semantic similarity (0.0 to 1.0)
/// - [1]: Recency score (0.0 to 1.0)
/// - [2]: Popularity score (0.0 to 1.0)
/// - [3]: Historical success rate (0.0 to 1.0)
/// - [4]: Query complexity (0.0 to 1.0)
///
/// Label: 1.0 = route to lightweight model, 0.0 = route to powerful model
fn generate_synthetic_data(n_samples: usize) -> (Vec<Vec<f32>>, Vec<f32>) {
    let mut rng = rand::thread_rng();
    let mut features = Vec::with_capacity(n_samples);
    let mut labels = Vec::with_capacity(n_samples);

    for _ in 0..n_samples {
        // Generate random features
        let similarity: f32 = rng.gen();
        let recency: f32 = rng.gen();
        let popularity: f32 = rng.gen();
        let success_rate: f32 = rng.gen();
        let complexity: f32 = rng.gen();

        let feature_vec = vec![similarity, recency, popularity, success_rate, complexity];

        // Generate label from heuristic rules:
        // high similarity + high success rate + low complexity -> lightweight (1.0);
        // low similarity + low success rate + high complexity -> powerful (0.0).
        let lightweight_score = similarity * 0.4 + success_rate * 0.3 + (1.0 - complexity) * 0.3;
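        // Worked example: similarity 0.9, success_rate 0.8, complexity 0.2
        // gives 0.9*0.4 + 0.8*0.3 + 0.8*0.3 = 0.84 > 0.6, hence label 1.0.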

        // Add some noise and threshold
        let noise: f32 = rng.gen_range(-0.1..0.1);
        let label = if lightweight_score + noise > 0.6 { 1.0 } else { 0.0 };

        features.push(feature_vec);
        labels.push(label);
    }

    (features, labels)
}

/// Create a pretrained teacher model (simulated)
///
/// In practice this would be a larger, more accurate model; for this example
/// we create a model with a similar architecture and pretend it has been
/// trained to high accuracy.
fn create_pretrained_teacher(config: &FastGRNNConfig) -> Result<FastGRNN> {
    // Create a teacher model with larger capacity
    let teacher_config = FastGRNNConfig {
        input_dim: config.input_dim,
        hidden_dim: config.hidden_dim * 2, // larger model
        output_dim: config.output_dim,
        nu: config.nu,
        zeta: config.zeta,
        rank: config.rank.map(|r| r * 2),
    };

    let teacher = FastGRNN::new(teacher_config)?;
    // In practice, you would load pretrained weights here:
    // teacher.load("path/to/teacher/model.safetensors")?;

    Ok(teacher)
}

/// Test model inference on sample inputs
fn test_inference(model: &FastGRNN) -> Result<()> {
    // Test case 1: high confidence -> lightweight
    let high_confidence = vec![0.9, 0.8, 0.7, 0.9, 0.2]; // high similarity, low complexity
    let pred1 = model.forward(&high_confidence, None)?;
    println!("High confidence case: prediction = {:.4}", pred1);

    // Test case 2: low confidence -> powerful
    let low_confidence = vec![0.3, 0.2, 0.1, 0.4, 0.9]; // low similarity, high complexity
    let pred2 = model.forward(&low_confidence, None)?;
    println!("Low confidence case: prediction = {:.4}", pred2);

    // Test case 3: medium confidence
    let medium_confidence = vec![0.5, 0.5, 0.5, 0.5, 0.5];
    let pred3 = model.forward(&medium_confidence, None)?;
    println!("Medium confidence case: prediction = {:.4}", pred3);

    // Batch inference
    let batch = vec![high_confidence, low_confidence, medium_confidence];
    let batch_preds = model.forward_batch(&batch)?;
    println!("\nBatch predictions: {:?}", batch_preds);

    Ok(())
}
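
/// A minimal sketch of how a caller might act on the routing score, assuming
/// the model outputs the probability of the lightweight route (this follows
/// the label convention in `generate_synthetic_data`; the threshold is an
/// illustrative choice, not part of the crate).
#[allow(dead_code)]
fn route_to_lightweight(prediction: f32, threshold: f32) -> bool {
    // e.g. route_to_lightweight(pred1, 0.5)
    prediction >= threshold
}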

/// Example: Custom training loop with manual control
#[allow(dead_code, unused_variables, unused_mut)]
fn example_custom_training_loop() -> Result<()> {
    println!("=== Custom Training Loop Example ===\n");

    // Setup
    let (features, labels) = generate_synthetic_data(500);
    let dataset = TrainingDataset::new(features, labels)?;
    let (train_dataset, val_dataset) = dataset.split(0.2)?;

    let config = FastGRNNConfig::default();
    let mut model = FastGRNN::new(config.clone())?;

    let training_config = TrainingConfig {
        batch_size: 16,
        learning_rate: 0.005,
        epochs: 20,
        ..Default::default()
    };

    let mut trainer = Trainer::new(&config, training_config);

    // Custom training with per-epoch callbacks
    println!("Training with custom callbacks...");
    for epoch in 0..10 {
        // You could implement custom logic here,
        // for example: dynamic batch size, custom metrics, etc.
        println!("Epoch {}: Custom preprocessing...", epoch + 1);

        // Train for one epoch. In practice you'd call trainer.train_epoch()
        // here; this loop only demonstrates the pattern.
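        // A hedged sketch of that per-epoch call (the method name comes from
        // the comment above; the arguments are assumptions, so it stays
        // commented out):
        // let epoch_metrics = trainer.train_epoch(&mut model, &train_dataset, &val_dataset)?;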
    }

    println!("Custom training complete!");

    Ok(())
}

/// Example: Continual learning scenario
#[allow(dead_code)]
fn example_continual_learning() -> Result<()> {
    println!("=== Continual Learning Example ===\n");

    let config = FastGRNNConfig::default();
    let mut model = FastGRNN::new(config.clone())?;

    // Train on initial dataset
    println!("Phase 1: Training on initial data...");
    let (features1, labels1) = generate_synthetic_data(500);
    let dataset1 = TrainingDataset::new(features1, labels1)?;

    let training_config = TrainingConfig {
        epochs: 20,
        ..Default::default()
    };

    let mut trainer = Trainer::new(&config, training_config.clone());
    trainer.train(&mut model, &dataset1)?;

    // Continue training on new data
    println!("\nPhase 2: Continual learning on new data...");
    let (features2, labels2) = generate_synthetic_data(300);
    let dataset2 = TrainingDataset::new(features2, labels2)?;

    let mut trainer2 = Trainer::new(&config, training_config);
    trainer2.train(&mut model, &dataset2)?;

    println!("\nContinual learning complete!");

    Ok(())
}