FastGRNN Training Pipeline Guide

This guide covers the complete training pipeline for the FastGRNN model used in Tiny Dancer's neural routing system.

Table of Contents

  1. Overview
  2. Architecture
  3. Quick Start
  4. Training Configuration
  5. Data Preparation
  6. Training Loop
  7. Knowledge Distillation
  8. Advanced Features
  9. Production Deployment

Overview

The FastGRNN training pipeline provides a complete solution for training lightweight recurrent neural networks for AI agent routing decisions. Key features include:

  • Adam Optimizer: State-of-the-art adaptive learning rate optimization
  • Mini-batch Training: Efficient batch processing with configurable batch sizes
  • Early Stopping: Automatic stopping when validation loss stops improving
  • Learning Rate Scheduling: Exponential decay for better convergence
  • Knowledge Distillation: Learn from larger teacher models
  • Gradient Clipping: Prevent exploding gradients
  • L2 Regularization: Prevent overfitting

Architecture

FastGRNN Cell

The FastGRNN (Fast Gated Recurrent Neural Network) uses a simplified gating mechanism:

r_t = σ(W_r × x_t + b_r)                    [Reset gate]
u_t = σ(W_u × x_t + b_u)                    [Update gate]
c_t = tanh(W_c × x_t + W × (r_t ⊙ h_t-1))  [Candidate state]
h_t = u_t ⊙ h_t-1 + (1 - u_t) ⊙ c_t         [Hidden state]
y_t = σ(W_out × h_t + b_out)                [Output]

Where:

  • σ is the sigmoid activation with scaling parameter nu
  • tanh is the hyperbolic tangent with scaling parameter zeta
  • ⊙ denotes element-wise multiplication
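
The same equations as a minimal, dependency-free forward step. This is an illustrative sketch rather than the crate's internal layout; the weight names mirror the equations above, with w_h standing in for the recurrent weight W in the candidate-state equation.

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Dot product of one weight row with a vector.
fn dot(row: &[f32], v: &[f32]) -> f32 {
    row.iter().zip(v).map(|(a, b)| a * b).sum()
}

// One FastGRNN step for a single sample.
// w_r, w_u, w_c map input_dim -> hidden_dim; w_h maps hidden_dim -> hidden_dim.
fn fastgrnn_step(
    x: &[f32],
    h_prev: &[f32],
    w_r: &[Vec<f32>], b_r: &[f32],
    w_u: &[Vec<f32>], b_u: &[f32],
    w_c: &[Vec<f32>], w_h: &[Vec<f32>],
) -> Vec<f32> {
    let n = h_prev.len();
    // r_t and u_t: gates computed from the current input
    let r: Vec<f32> = (0..n).map(|i| sigmoid(dot(&w_r[i], x) + b_r[i])).collect();
    let u: Vec<f32> = (0..n).map(|i| sigmoid(dot(&w_u[i], x) + b_u[i])).collect();
    // r_t ⊙ h_{t-1}: reset-gated previous hidden state
    let gated: Vec<f32> = r.iter().zip(h_prev).map(|(&ri, &hi)| ri * hi).collect();
    // c_t and h_t, element by element
    (0..n)
        .map(|i| {
            let c = (dot(&w_c[i], x) + dot(&w_h[i], &gated)).tanh();
            u[i] * h_prev[i] + (1.0 - u[i]) * c
        })
        .collect()
}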

Training Pipeline

┌─────────────────┐
│  Raw Features   │
│  + Labels       │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Normalization  │
│  (z-score)      │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Train/Val      │
│  Split          │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Mini-batch     │
│  Training       │
│  (BPTT)         │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Adam Update    │
│  + Grad Clip    │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Validation     │
│  + Early Stop   │
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│  Trained Model  │
└─────────────────┘

Quick Start

Basic Training

use ruvector_tiny_dancer_core::{
    model::{FastGRNN, FastGRNNConfig},
    training::{TrainingConfig, TrainingDataset, Trainer},
};

// 1. Prepare your data
let features = vec![
    vec![0.8, 0.9, 0.7, 0.85, 0.2], // High confidence case
    vec![0.3, 0.2, 0.4, 0.35, 0.9], // Low confidence case
    // ... more samples
];
let labels = vec![1.0, 0.0, /* ... */]; // 1.0 = lightweight, 0.0 = powerful

let mut dataset = TrainingDataset::new(features, labels)?;

// 2. Normalize features
let (means, stds) = dataset.normalize()?;

// 3. Create model
let model_config = FastGRNNConfig {
    input_dim: 5,
    hidden_dim: 16,
    output_dim: 1,
    nu: 0.8,
    zeta: 1.2,
    rank: Some(8),
};
let mut model = FastGRNN::new(model_config.clone())?;

// 4. Configure training
let training_config = TrainingConfig {
    learning_rate: 0.01,
    batch_size: 32,
    epochs: 50,
    validation_split: 0.2,
    early_stopping_patience: Some(5),
    ..Default::default()
};

// 5. Train
let mut trainer = Trainer::new(&model_config, training_config);
let metrics = trainer.train(&mut model, &dataset)?;

// 6. Save model
model.save("models/fastgrnn.safetensors")?;

Run the Example

cd crates/ruvector-tiny-dancer-core
cargo run --example train-model

Training Configuration

Hyperparameters

pub struct TrainingConfig {
    /// Learning rate (default: 0.001)
    pub learning_rate: f32,

    /// Batch size (default: 32)
    pub batch_size: usize,

    /// Number of epochs (default: 100)
    pub epochs: usize,

    /// Validation split ratio (default: 0.2)
    pub validation_split: f32,

    /// Early stopping patience (default: Some(10))
    pub early_stopping_patience: Option<usize>,

    /// Learning rate decay factor (default: 0.5)
    pub lr_decay: f32,

    /// Learning rate decay step in epochs (default: 20)
    pub lr_decay_step: usize,

    /// Gradient clipping threshold (default: 5.0)
    pub grad_clip: f32,

    /// Adam beta1 parameter (default: 0.9)
    pub adam_beta1: f32,

    /// Adam beta2 parameter (default: 0.999)
    pub adam_beta2: f32,

    /// Adam epsilon (default: 1e-8)
    pub adam_epsilon: f32,

    /// L2 regularization strength (default: 1e-5)
    pub l2_reg: f32,
}
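
The adam_* fields plug into the standard Adam update rule. A per-parameter sketch for reference (illustrative, not the crate's internal optimizer):

// One Adam update for a single parameter.
// m and v are the running first/second moment estimates; t is the 1-based step count.
fn adam_update(
    w: &mut f32, grad: f32,
    m: &mut f32, v: &mut f32,
    t: i32,
    lr: f32, beta1: f32, beta2: f32, eps: f32,
) {
    *m = beta1 * *m + (1.0 - beta1) * grad;
    *v = beta2 * *v + (1.0 - beta2) * grad * grad;
    let m_hat = *m / (1.0 - beta1.powi(t)); // bias correction
    let v_hat = *v / (1.0 - beta2.powi(t));
    *w -= lr * m_hat / (v_hat.sqrt() + eps);
}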

Small Datasets (< 1,000 samples)

TrainingConfig {
    learning_rate: 0.01,
    batch_size: 16,
    epochs: 100,
    validation_split: 0.2,
    early_stopping_patience: Some(10),
    lr_decay: 0.8,
    lr_decay_step: 20,
    l2_reg: 1e-4,
    ..Default::default()
}

Medium Datasets (1,000 - 10,000 samples)

TrainingConfig {
    learning_rate: 0.005,
    batch_size: 32,
    epochs: 50,
    validation_split: 0.15,
    early_stopping_patience: Some(5),
    lr_decay: 0.7,
    lr_decay_step: 10,
    l2_reg: 1e-5,
    ..Default::default()
}

Large Datasets (> 10,000 samples)

TrainingConfig {
    learning_rate: 0.001,
    batch_size: 64,
    epochs: 30,
    validation_split: 0.1,
    early_stopping_patience: Some(3),
    lr_decay: 0.5,
    lr_decay_step: 5,
    l2_reg: 1e-6,
    ..Default::default()
}

Data Preparation

Feature Engineering

For routing decisions, typical features include:

pub struct RoutingFeatures {
    /// Semantic similarity between query and candidate (0.0 to 1.0)
    pub similarity: f32,

    /// Recency score - how recently was this candidate accessed (0.0 to 1.0)
    pub recency: f32,

    /// Popularity score - how often is this candidate used (0.0 to 1.0)
    pub popularity: f32,

    /// Historical success rate for this candidate (0.0 to 1.0)
    pub success_rate: f32,

    /// Query complexity estimate (0.0 to 1.0)
    pub complexity: f32,
}

impl RoutingFeatures {
    fn to_vector(&self) -> Vec<f32> {
        vec![
            self.similarity,
            self.recency,
            self.popularity,
            self.success_rate,
            self.complexity,
        ]
    }
}

Data Collection

// Collect training data from production logs
fn collect_training_data(logs: &[RoutingLog]) -> (Vec<Vec<f32>>, Vec<f32>) {
    let mut features = Vec::new();
    let mut labels = Vec::new();

    for log in logs {
        // Extract features
        let feature_vec = vec![
            log.similarity_score,
            log.recency_score,
            log.popularity_score,
            log.success_rate,
            log.complexity_score,
        ];

        // Label based on actual outcome
        // 1.0 if lightweight model was sufficient
        // 0.0 if powerful model was needed
        let label = if log.lightweight_successful { 1.0 } else { 0.0 };

        features.push(feature_vec);
        labels.push(label);
    }

    (features, labels)
}

Data Normalization

Always normalize your features before training:

let mut dataset = TrainingDataset::new(features, labels)?;
let (means, stds) = dataset.normalize()?;

// Save normalization parameters for inference
save_normalization_params("models/normalization.json", &means, &stds)?;

During inference, apply the same normalization:

fn normalize_features(features: &mut [f32], means: &[f32], stds: &[f32]) {
    for (i, feat) in features.iter_mut().enumerate() {
        *feat = (*feat - means[i]) / stds[i];
    }
}
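
The save_normalization_params helper referenced above is not shown in the snippet. One possible implementation, assuming serde and serde_json are available as dependencies (the struct name and file format are illustrative):

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct NormalizationParams {
    means: Vec<f32>,
    stds: Vec<f32>,
}

// Persist z-score parameters so inference applies the exact same transform.
fn save_normalization_params(path: &str, means: &[f32], stds: &[f32]) -> std::io::Result<()> {
    let params = NormalizationParams { means: means.to_vec(), stds: stds.to_vec() };
    let json = serde_json::to_string_pretty(&params).expect("params serialize to JSON");
    std::fs::write(path, json)
}

fn load_normalization_params(path: &str) -> std::io::Result<NormalizationParams> {
    let json = std::fs::read_to_string(path)?;
    serde_json::from_str(&json)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
}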

Training Loop

Basic Training

let mut trainer = Trainer::new(&model_config, training_config);
let metrics = trainer.train(&mut model, &dataset)?;

// Print final results
if let Some(last) = metrics.last() {
    println!("Final validation accuracy: {:.2}%", last.val_accuracy * 100.0);
}

Custom Training Loop

For more control, implement your own training loop:

use ruvector_tiny_dancer_core::training::BatchIterator;

for epoch in 0..config.epochs {
    let mut epoch_loss = 0.0;
    let mut n_batches = 0;

    // Training phase
    let batch_iter = BatchIterator::new(&train_dataset, config.batch_size, true);
    for (features, labels, _) in batch_iter {
        // Forward pass
        let predictions: Vec<f32> = features
            .iter()
            .map(|f| model.forward(f, None).unwrap())
            .collect();

        // Compute loss
        let batch_loss: f32 = predictions
            .iter()
            .zip(&labels)
            .map(|(&pred, &target)| binary_cross_entropy(pred, target))
            .sum::<f32>() / predictions.len() as f32;

        epoch_loss += batch_loss;
        n_batches += 1;

        // Backward pass (simplified - real implementation needs BPTT)
        // ...
    }

    println!("Epoch {}: loss = {:.4}", epoch, epoch_loss / n_batches as f32);
}
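
The binary_cross_entropy helper used above is not defined in the snippet. A straightforward version, clamping predictions away from 0 and 1 so the logarithms stay finite:

// Binary cross-entropy for a single prediction/target pair.
fn binary_cross_entropy(pred: f32, target: f32) -> f32 {
    let p = pred.clamp(1e-7, 1.0 - 1e-7); // avoid ln(0) when the model saturates
    -(target * p.ln() + (1.0 - target) * (1.0 - p).ln())
}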

Knowledge Distillation

Knowledge distillation allows a smaller "student" model to learn from a larger "teacher" model.

Setup

use ruvector_tiny_dancer_core::training::{
    generate_teacher_predictions,
    temperature_softmax,
};

// 1. Create/load teacher model (larger, pre-trained)
let teacher_config = FastGRNNConfig {
    input_dim: 5,
    hidden_dim: 32,  // Larger than student
    output_dim: 1,
    ..Default::default()
};
let teacher = FastGRNN::load("models/teacher.safetensors")?;

// 2. Generate soft targets
let temperature = 3.0;  // Higher = softer probabilities
let soft_targets = generate_teacher_predictions(
    &teacher,
    &dataset.features,
    temperature
)?;

// 3. Add soft targets to dataset
let dataset = dataset.with_soft_targets(soft_targets)?;

// 4. Enable distillation in training config
let training_config = TrainingConfig {
    enable_distillation: true,
    distillation_temperature: temperature,
    distillation_alpha: 0.7,  // 70% soft targets, 30% hard targets
    ..Default::default()
};

Distillation Loss

The total loss combines hard and soft targets:

L_total = α × L_soft + (1 - α) × L_hard

where:
- L_soft = BCE(σ(student_logit / T), σ(teacher_logit / T))   [temperature-softened probabilities]
- L_hard = BCE(σ(student_logit), true_label)
- α = distillation_alpha (typically 0.5 to 0.9)
- T = temperature (typically 2.0 to 5.0)
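
Put together for a single sample, assuming student_logit and teacher_logit are the pre-sigmoid outputs and binary_cross_entropy is the helper from the training-loop section, the combined loss looks roughly like this:

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// Combined distillation loss; alpha and temperature correspond to
// distillation_alpha and distillation_temperature in TrainingConfig.
fn distillation_loss(
    student_logit: f32,
    teacher_logit: f32,
    true_label: f32,
    alpha: f32,
    temperature: f32,
) -> f32 {
    let l_soft = binary_cross_entropy(
        sigmoid(student_logit / temperature),
        sigmoid(teacher_logit / temperature),
    );
    let l_hard = binary_cross_entropy(sigmoid(student_logit), true_label);
    alpha * l_soft + (1.0 - alpha) * l_hard
}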

Benefits

  • Faster Inference: Student model is smaller and faster
  • Better Accuracy: Student learns from teacher's knowledge
  • Compression: 2-4x smaller models with minimal accuracy loss
  • Transfer Learning: Transfer knowledge across architectures

Advanced Features

Learning Rate Scheduling

Exponential decay schedule:

TrainingConfig {
    learning_rate: 0.01,      // Initial LR
    lr_decay: 0.8,            // Multiply by 0.8 every lr_decay_step epochs
    lr_decay_step: 10,        // Decay every 10 epochs
    ..Default::default()
}

// Schedule:
// Epochs 0-9:   LR = 0.01
// Epochs 10-19: LR = 0.008
// Epochs 20-29: LR = 0.0064
// Epochs 30-39: LR = 0.00512
// ...
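
The schedule above is plain step decay; as a function:

// Step decay: multiply the initial LR by lr_decay once per lr_decay_step epochs.
fn scheduled_lr(initial_lr: f32, lr_decay: f32, lr_decay_step: usize, epoch: usize) -> f32 {
    initial_lr * lr_decay.powi((epoch / lr_decay_step) as i32)
}

// scheduled_lr(0.01, 0.8, 10, 25) == 0.01 * 0.8^2 == 0.0064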

Early Stopping

Prevent overfitting by stopping when validation loss stops improving:

TrainingConfig {
    early_stopping_patience: Some(5),  // Stop after 5 epochs without improvement
    ..Default::default()
}
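
Internally this amounts to a patience counter over the validation loss; a minimal sketch of the logic:

// Tracks the best validation loss and signals a stop after `patience`
// consecutive epochs without improvement.
struct EarlyStopping {
    patience: usize,
    best_val_loss: f32,
    epochs_without_improvement: usize,
}

impl EarlyStopping {
    fn new(patience: usize) -> Self {
        Self { patience, best_val_loss: f32::INFINITY, epochs_without_improvement: 0 }
    }

    fn should_stop(&mut self, val_loss: f32) -> bool {
        if val_loss < self.best_val_loss {
            self.best_val_loss = val_loss;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.patience
    }
}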

Gradient Clipping

Prevent exploding gradients in RNNs:

TrainingConfig {
    grad_clip: 5.0,  // Clip gradients to [-5.0, 5.0]
    ..Default::default()
}
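
The clip range maps to a simple element-wise clamp (the trainer may clip by global norm instead; treat this as the value-clipping variant):

// Clamp each gradient component to [-grad_clip, grad_clip].
fn clip_gradients(grads: &mut [f32], grad_clip: f32) {
    for g in grads.iter_mut() {
        *g = g.clamp(-grad_clip, grad_clip);
    }
}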

Regularization

L2 weight decay to prevent overfitting:

TrainingConfig {
    l2_reg: 1e-5,  // Add L2 penalty to loss
    ..Default::default()
}
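
In gradient form, the L2 penalty adds l2_reg × weight to each parameter's gradient before the optimizer step:

// Add the L2 (weight decay) term to each parameter's gradient.
fn apply_l2(grads: &mut [f32], weights: &[f32], l2_reg: f32) {
    for (g, &w) in grads.iter_mut().zip(weights) {
        *g += l2_reg * w;
    }
}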

Production Deployment

Training Pipeline

  1. Data Collection

    // Collect production logs
    let logs = collect_routing_logs_from_db(db_path)?;
    let (features, labels) = extract_features_and_labels(&logs);
    
  2. Data Validation

    // Check data quality
    assert!(features.len() >= 1000, "Need at least 1000 samples");
    let positives = labels.iter().filter(|&&l| l > 0.5).count();
    assert!(positives > 100 && labels.len() - positives > 100,
            "Need enough samples of both classes");
    
  3. Training

    let mut dataset = TrainingDataset::new(features, labels)?;
    let (means, stds) = dataset.normalize()?;
    
    let mut trainer = Trainer::new(&model_config, training_config);
    let metrics = trainer.train(&mut model, &dataset)?;
    
  4. Validation

    // Test on holdout set (evaluate_model is sketched after this list)
    let (_, test_dataset) = dataset.split(0.2)?;
    let (test_loss, test_accuracy) = evaluate_model(&model, &test_dataset)?;
    
    assert!(test_accuracy > 0.85, "Model accuracy too low");
    
  5. Save Artifacts

    // Save model
    model.save("models/fastgrnn_v1.safetensors")?;
    
    // Save normalization params
    save_normalization("models/normalization_v1.json", &means, &stds)?;
    
    // Save metrics
    trainer.save_metrics("models/metrics_v1.json")?;
    
  6. Optimization

    // Quantize for production
    model.quantize()?;
    
    // Optional: Prune weights
    model.prune(0.3)?;  // 30% sparsity
    

Continual Learning

Update the model with new data:

// Load existing model
let mut model = FastGRNN::load("models/current.safetensors")?;

// Collect new data
let new_logs = collect_recent_logs(since_timestamp)?;
let (new_features, new_labels) = extract_features_and_labels(&new_logs);

// Create dataset
let new_dataset = TrainingDataset::new(new_features, new_labels)?;

// Fine-tune with lower learning rate
let training_config = TrainingConfig {
    learning_rate: 0.0001,  // Lower LR for fine-tuning
    epochs: 10,
    ..Default::default()
};

let mut trainer = Trainer::new(model.config(), training_config);
trainer.train(&mut model, &new_dataset)?;

// Save updated model
model.save("models/current_v2.safetensors")?;

Model Versioning

use chrono::Utc;

pub struct ModelVersion {
    pub version: String,
    pub timestamp: i64,
    pub model_path: String,
    pub metrics_path: String,
    pub normalization_path: String,
    pub test_accuracy: f32,
    pub model_size_bytes: usize,
}

impl ModelVersion {
    pub fn create_new(model: &FastGRNN, metrics: &[TrainingMetrics]) -> Self {
        let timestamp = Utc::now().timestamp();
        let version = format!("v{}", timestamp);

        Self {
            version: version.clone(),
            timestamp,
            model_path: format!("models/fastgrnn_{}.safetensors", version),
            metrics_path: format!("models/metrics_{}.json", version),
            normalization_path: format!("models/norm_{}.json", version),
            test_accuracy: metrics.last().unwrap().val_accuracy,
            model_size_bytes: model.size_bytes(),
        }
    }
}

Performance Benchmarks

Training Speed

| Dataset Size | Batch Size | Epoch Time | Total Time (50 epochs) |
|--------------|------------|------------|------------------------|
| 1,000        | 32         | 0.2s       | 10s                    |
| 10,000       | 64         | 1.5s       | 75s                    |
| 100,000      | 128        | 12s        | 600s (10 min)          |

Model Size

| Configuration      | Parameters | FP32 Size | INT8 Size | Compression |
|--------------------|------------|-----------|-----------|-------------|
| Tiny (8 hidden)    | ~250       | 1 KB      | 256 B     | 4x          |
| Small (16 hidden)  | ~850       | 3.4 KB    | 850 B     | 4x          |
| Medium (32 hidden) | ~3,200     | 12.8 KB   | 3.2 KB    | 4x          |

Inference Speed

After training and quantization:

  • Inference time: < 100 μs per sample
  • Batch inference (32 samples): < 1 ms
  • Memory footprint: < 5 KB

Troubleshooting

Common Issues

1. Loss Not Decreasing

Symptoms: Training loss stays high or increases

Solutions:

  • Reduce learning rate (try 0.001 or lower)
  • Increase batch size
  • Check data normalization
  • Verify labels are correct (0.0 or 1.0)
  • Add more training data

2. Overfitting

Symptoms: Training accuracy high, validation accuracy low

Solutions:

  • Increase L2 regularization (try 1e-4)
  • Reduce model size (fewer hidden units)
  • Use early stopping
  • Add more training data
  • Increase validation split

3. Slow Convergence

Symptoms: Training takes too many epochs

Solutions:

  • Increase learning rate (try 0.01 or 0.1)
  • Use knowledge distillation
  • Better feature engineering
  • Use larger batch sizes

4. Gradient Explosion

Symptoms: Loss becomes NaN, training crashes

Solutions:

  • Enable gradient clipping (grad_clip: 1.0 or 5.0)
  • Reduce learning rate
  • Check for invalid data (NaN, Inf values) in features and labels (see the sketch below)
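
A quick scan for that last point, assuming features are plain Vec<Vec<f32>>:

// Return the indices of samples containing NaN or infinite values.
fn find_invalid_samples(features: &[Vec<f32>]) -> Vec<usize> {
    features
        .iter()
        .enumerate()
        .filter(|(_, f)| f.iter().any(|v| !v.is_finite()))
        .map(|(i, _)| i)
        .collect()
}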

Next Steps

  1. Run the example: cargo run --example train-model
  2. Collect your own data: Integrate with production logs
  3. Experiment with hyperparameters: Find optimal settings
  4. Deploy to production: Integrate with the Router
  5. Monitor performance: Track accuracy and latency
  6. Iterate: Collect more data and retrain regularly

References