// NOTE(review): the lines below are file-listing chrome (path / size / line
// count) captured from a code browser, not Rust source. Kept as comments so
// the file remains valid Rust.
// Files: wifi-densepose/vendor/ruvector/examples/ruvLLM/src/router.rs
// 905 lines, 28 KiB, Rust
//! FastGRNN Router for intelligent resource allocation
//!
//! Implements a FastGRNN (Fast, Accurate, Stable, and Tiny GRU) based router
//! that learns to select optimal model size, context size, and generation
//! parameters based on query characteristics.
use crate::config::RouterConfig;
use crate::error::{Error, Result, RouterError};
use crate::types::{ModelSize, RouterSample, RoutingDecision, CONTEXT_BINS};
use ndarray::{Array1, Array2, Axis};
use parking_lot::RwLock;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
/// FastGRNN Router for dynamic resource allocation
pub struct FastGRNNRouter {
    /// Recurrent FastGRNN cell (sparsified W matrices, low-rank U factors)
    cell: FastGRNNCell,
    /// Decision heads: model choice, context bin, temperature, top-p, confidence
    output_heads: OutputHeads,
    /// LayerNorm applied to the raw feature vector before the cell
    input_norm: LayerNorm,
    /// Dimensions and thresholds used for validation and decision decoding
    config: RouterConfig,
    /// Runtime counters (forward passes, training steps, loss, model histogram)
    stats: RouterStats,
}
/// Router statistics for monitoring
#[derive(Debug, Default)]
pub struct RouterStats {
    /// Total forward passes (incremented once per `forward` call)
    pub forward_count: AtomicU64,
    /// Total training steps (one per `train_batch` call)
    pub training_steps: AtomicU64,
    /// Cumulative loss: sum of per-batch *total* (un-averaged) losses
    pub cumulative_loss: RwLock<f64>,
    /// Model selection histogram, indexed by `ModelSize::to_index()`
    pub model_counts: [AtomicU64; 4],
}
impl RouterStats {
    /// Record one routing decision: bump the global forward counter and
    /// the histogram bucket for the chosen model size.
    pub fn record_forward(&self, model: ModelSize) {
        self.forward_count.fetch_add(1, Ordering::Relaxed);
        self.model_counts[model.to_index()].fetch_add(1, Ordering::Relaxed);
    }

    /// Fraction of forward passes routed to each model size.
    ///
    /// Before any traffic has been recorded the distribution is reported
    /// as uniform (0.25 each) rather than dividing by zero.
    pub fn get_model_distribution(&self) -> [f64; 4] {
        let total = self.forward_count.load(Ordering::Relaxed) as f64;
        if total == 0.0 {
            return [0.25; 4];
        }
        let mut dist = [0.0f64; 4];
        for (slot, count) in dist.iter_mut().zip(self.model_counts.iter()) {
            *slot = count.load(Ordering::Relaxed) as f64 / total;
        }
        dist
    }
}
/// FastGRNN cell implementation with sparse and low-rank matrices
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FastGRNNCell {
    /// Input-to-update-gate weights, (hidden_dim x input_dim); stored
    /// pre-multiplied by `w_z_mask`, so pruned entries are zero.
    w_z: Array2<f32>,
    /// Low-rank factors of the recurrent update-gate matrix:
    /// U_z = u_z_a (hidden_dim x rank) @ u_z_b (rank x hidden_dim)
    u_z_a: Array2<f32>,
    u_z_b: Array2<f32>,
    /// Update gate bias (hidden_dim)
    b_z: Array1<f32>,
    /// Input-to-candidate-hidden weights, (hidden_dim x input_dim), masked
    w_h: Array2<f32>,
    /// Low-rank factors of the recurrent candidate matrix: U_h = A_h @ B_h
    u_h_a: Array2<f32>,
    u_h_b: Array2<f32>,
    /// Candidate-hidden bias (hidden_dim)
    b_h: Array1<f32>,
    /// FastGRNN zeta scalar — scales (1 - z) in the output gate blend
    zeta: f32,
    /// FastGRNN nu scalar — additive floor in the output gate blend
    nu: f32,
    /// 0/1 sparsity masks for the W matrices (kept so sparsity can be
    /// re-applied after weight updates)
    w_z_mask: Array2<f32>,
    w_h_mask: Array2<f32>,
}
/// Output heads for routing decisions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputHeads {
    /// Model-size selection logits: (4 x hidden_dim) — one row per ModelSize
    w_model: Array2<f32>,
    b_model: Array1<f32>,
    /// Context-bin selection logits: (5 x hidden_dim) — one row per CONTEXT_BINS entry
    w_context: Array2<f32>,
    b_context: Array1<f32>,
    /// Temperature head (scalar): decoded as sigmoid(out) * 2 => (0, 2)
    w_temp: Array1<f32>,
    b_temp: f32,
    /// Top-p head (scalar): decoded as sigmoid(out) => (0, 1)
    w_top_p: Array1<f32>,
    b_top_p: f32,
    /// Confidence head (scalar): decoded as sigmoid(out) => (0, 1)
    w_conf: Array1<f32>,
    b_conf: f32,
}
/// Layer normalization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNorm {
    /// Learned scale (initialized to ones)
    gamma: Array1<f32>,
    /// Learned shift (initialized to zeros)
    beta: Array1<f32>,
    /// Variance floor for numerical stability
    eps: f32,
}
/// Adam optimizer state
#[derive(Debug, Clone)]
pub struct AdamState {
    /// First moment estimates, one array per parameter group
    m: Vec<Array1<f32>>,
    /// Second moment estimates, one array per parameter group
    v: Vec<Array1<f32>>,
    /// Time step (number of `step` calls), used for bias correction
    t: usize,
    /// Learning rate
    lr: f32,
    /// Exponential decay for the first moment (0.9)
    beta1: f32,
    /// Exponential decay for the second moment (0.999)
    beta2: f32,
    /// Denominator epsilon (1e-8)
    eps: f32,
    // NOTE(review): this optimizer is not referenced anywhere else in this
    // file — train_batch uses plain SGD with a comment that Adam "can
    // replace" it. Presumably kept for future integration; confirm.
}
impl AdamState {
    /// Fresh optimizer state: zeroed first/second moments for each
    /// parameter group, standard Adam hyperparameters.
    pub fn new(param_shapes: &[usize], lr: f32) -> Self {
        let zeros = |shapes: &[usize]| -> Vec<Array1<f32>> {
            shapes.iter().map(|&s| Array1::zeros(s)).collect()
        };
        Self {
            m: zeros(param_shapes),
            v: zeros(param_shapes),
            t: 0,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
        }
    }

    /// One Adam update over parallel slices of parameter groups and
    /// their gradients (`params[i]` is updated using `grads[i]`).
    pub fn step(&mut self, params: &mut [Array1<f32>], grads: &[Array1<f32>]) {
        self.t += 1;
        // Bias-correction denominators for the decayed moment estimates.
        let bc1 = 1.0 - self.beta1.powi(self.t as i32);
        let bc2 = 1.0 - self.beta2.powi(self.t as i32);
        for (i, (param, grad)) in params.iter_mut().zip(grads.iter()).enumerate() {
            // Decay-and-accumulate both moment estimates.
            self.m[i] = &self.m[i] * self.beta1 + grad * (1.0 - self.beta1);
            self.v[i] = &self.v[i] * self.beta2 + &(grad * grad) * (1.0 - self.beta2);
            let m_hat = &self.m[i] / bc1;
            let v_hat = &self.v[i] / bc2;
            // param -= lr * m_hat / (sqrt(v_hat) + eps)
            *param = param.clone() - &(&m_hat / &(v_hat.mapv(f32::sqrt) + self.eps)) * self.lr;
        }
    }
}
impl FastGRNNRouter {
/// Create a new router with random initialization
pub fn new(config: &RouterConfig) -> Result<Self> {
let cell = FastGRNNCell::new(
config.input_dim,
config.hidden_dim,
config.sparsity,
config.rank,
);
let output_heads = OutputHeads::new(config.hidden_dim);
let input_norm = LayerNorm::new(config.input_dim);
Ok(Self {
cell,
output_heads,
input_norm,
config: config.clone(),
stats: RouterStats::default(),
})
}
    /// Load router weights from a file previously written by `save_weights`.
    ///
    /// Deserializes the `(cell, output_heads, input_norm)` triple with
    /// bincode; `config` is attached as-is and statistics start fresh.
    ///
    /// # Errors
    /// I/O errors reading `path`, or `Error::Serialization` if the payload
    /// cannot be decoded.
    ///
    /// NOTE(review): `config` dimensions are not validated against the
    /// deserialized weight shapes — a mismatched config would only surface
    /// later as a shape error inside `forward`. Confirm callers persist the
    /// matching config alongside the weights.
    pub fn load(path: impl AsRef<Path>, config: &RouterConfig) -> Result<Self> {
        let data = std::fs::read(path.as_ref())?;
        let (cell, output_heads, input_norm): (FastGRNNCell, OutputHeads, LayerNorm) =
            bincode::serde::decode_from_slice(&data, bincode::config::standard())
                .map_err(|e| Error::Serialization(e.to_string()))?
                .0;
        Ok(Self {
            cell,
            output_heads,
            input_norm,
            config: config.clone(),
            stats: RouterStats::default(),
        })
    }
    /// Save router weights to `path`.
    ///
    /// Serializes `(cell, output_heads, input_norm)` with bincode's
    /// standard configuration; config and stats are intentionally not
    /// persisted (only learned parameters are written).
    ///
    /// # Errors
    /// `Error::Serialization` if encoding fails, or any I/O error from the
    /// final `fs::write`.
    pub fn save_weights(&self, path: impl AsRef<Path>) -> Result<()> {
        let data = bincode::serde::encode_to_vec(
            (&self.cell, &self.output_heads, &self.input_norm),
            bincode::config::standard(),
        )
        .map_err(|e| Error::Serialization(e.to_string()))?;
        std::fs::write(path, data)?;
        Ok(())
    }
    /// Forward pass through router.
    ///
    /// Runs one FastGRNN step from `hidden` over the layer-normalized
    /// feature vector, then decodes a full `RoutingDecision` from the
    /// output heads. The decision includes the updated hidden state so the
    /// caller can thread it into the next call.
    ///
    /// # Errors
    /// `RouterError::InvalidFeatures` when `features.len()` differs from
    /// `config.input_dim`.
    ///
    /// NOTE(review): `hidden` length is *not* validated here; a length
    /// other than `config.hidden_dim` would fail inside the ndarray dot
    /// products in the cell — confirm all callers pass `hidden_dim`
    /// elements.
    pub fn forward(&self, features: &[f32], hidden: &[f32]) -> Result<RoutingDecision> {
        // Validate input dimensions
        if features.len() != self.config.input_dim {
            return Err(RouterError::InvalidFeatures {
                expected: self.config.input_dim,
                actual: features.len(),
            }
            .into());
        }
        let x = Array1::from_vec(features.to_vec());
        let h = Array1::from_vec(hidden.to_vec());
        // Normalize input
        let x_norm = self.input_norm.forward(&x);
        // FastGRNN cell
        let h_new = self.cell.forward(&x_norm, &h);
        // Output heads (raw pre-activation values)
        let model_logits = self.output_heads.model_forward(&h_new);
        let context_logits = self.output_heads.context_forward(&h_new);
        let temp_raw = self.output_heads.temp_forward(&h_new);
        let top_p_raw = self.output_heads.top_p_forward(&h_new);
        let conf_raw = self.output_heads.confidence_forward(&h_new);
        // Activations: temperature is mapped to (0, 2), top-p and
        // confidence to (0, 1).
        let model_probs = softmax_array(&model_logits);
        let context_probs = softmax_array(&context_logits);
        let temperature = sigmoid(temp_raw) * 2.0;
        let top_p = sigmoid(top_p_raw);
        let confidence = sigmoid(conf_raw);
        // Decode decisions: trust the argmax heads only when the confidence
        // head clears the configured threshold.
        let (model, context_size) = if confidence >= self.config.confidence_threshold {
            let model_idx = argmax_array(&model_probs);
            let context_idx = argmax_array(&context_probs);
            (ModelSize::from_index(model_idx), CONTEXT_BINS[context_idx])
        } else {
            // Safe defaults when confidence is low
            (ModelSize::B1_2, 2048)
        };
        // Record statistics
        self.stats.record_forward(model);
        Ok(RoutingDecision {
            model,
            context_size,
            temperature,
            top_p,
            confidence,
            model_probs: [
                model_probs[0],
                model_probs[1],
                model_probs[2],
                model_probs[3],
            ],
            new_hidden: h_new.to_vec(),
            features: features.to_vec(),
        })
    }
    /// Train the router on a batch of samples.
    ///
    /// One SGD step: each sample is forwarded from a zero hidden state,
    /// output-head gradients are accumulated and batch-averaged, an
    /// optional EWC penalty gradient is added, and the update is applied.
    ///
    /// * `learning_rate` — SGD step size.
    /// * `ewc_lambda` — EWC penalty strength; only used when *both*
    ///   `fisher_info` and `optimal_weights` are provided.
    ///
    /// Returns batch-averaged loss plus argmax accuracies for the model
    /// and context heads.
    pub fn train_batch(
        &mut self,
        samples: &[RouterSample],
        learning_rate: f32,
        ewc_lambda: f32,
        fisher_info: Option<&[f32]>,
        optimal_weights: Option<&[f32]>,
    ) -> TrainingMetrics {
        if samples.is_empty() {
            return TrainingMetrics::default();
        }
        let batch_size = samples.len() as f32;
        let mut total_loss = 0.0;
        let mut model_correct = 0;
        let mut context_correct = 0;
        // Accumulate gradients over batch
        let mut grad_accum = self.zero_gradients();
        for sample in samples {
            // Each sample is treated as a single step from a zero state.
            let hidden = vec![0.0f32; self.config.hidden_dim];
            let x = Array1::from_vec(sample.features.clone());
            let h = Array1::from_vec(hidden);
            // Forward pass
            let x_norm = self.input_norm.forward(&x);
            let h_new = self.cell.forward(&x_norm, &h);
            let model_logits = self.output_heads.model_forward(&h_new);
            let context_logits = self.output_heads.context_forward(&h_new);
            let temp_pred = self.output_heads.temp_forward(&h_new);
            let top_p_pred = self.output_heads.top_p_forward(&h_new);
            let model_probs = softmax_array(&model_logits);
            let context_probs = softmax_array(&context_logits);
            // Cross-entropy terms. `.ln().max(-10.0)` clamps the log-prob at
            // -10 *before* negation, so each CE term is capped at 10 and
            // ln(0) cannot produce infinity.
            let model_loss = -model_probs[sample.label_model].ln().max(-10.0);
            let context_loss = -context_probs[sample.label_context].ln().max(-10.0);
            // MSE on the decoded (post-activation) scalar heads, down-weighted.
            let temp_loss = (sigmoid(temp_pred) * 2.0 - sample.label_temperature).powi(2);
            let top_p_loss = (sigmoid(top_p_pred) - sample.label_top_p).powi(2);
            let sample_loss = model_loss + context_loss + 0.1 * temp_loss + 0.1 * top_p_loss;
            total_loss += sample_loss;
            // Check accuracy
            if argmax_array(&model_probs) == sample.label_model {
                model_correct += 1;
            }
            if argmax_array(&context_probs) == sample.label_context {
                context_correct += 1;
            }
            // Compute gradients (simplified - using finite differences for demo)
            self.accumulate_gradients(
                &mut grad_accum,
                sample,
                &h_new,
                &model_probs,
                &context_probs,
            );
        }
        // Average gradients
        for g in &mut grad_accum {
            *g /= batch_size;
        }
        // Add EWC regularization gradient if provided
        if let (Some(fisher), Some(optimal)) = (fisher_info, optimal_weights) {
            self.add_ewc_gradient(&mut grad_accum, fisher, optimal, ewc_lambda);
        }
        // Apply gradients with simple SGD (can be replaced with Adam)
        self.apply_gradients(&grad_accum, learning_rate);
        self.stats.training_steps.fetch_add(1, Ordering::Relaxed);
        // Stats accumulate the un-averaged batch loss sum.
        *self.stats.cumulative_loss.write() += total_loss as f64;
        TrainingMetrics {
            total_loss: total_loss / batch_size,
            model_accuracy: model_correct as f32 / batch_size,
            context_accuracy: context_correct as f32 / batch_size,
            samples_processed: samples.len(),
        }
    }
fn zero_gradients(&self) -> Vec<f32> {
vec![0.0; self.parameter_count()]
}
fn parameter_count(&self) -> usize {
let cell_params = self.cell.w_z.len()
+ self.cell.w_h.len()
+ self.cell.u_z_a.len()
+ self.cell.u_z_b.len()
+ self.cell.u_h_a.len()
+ self.cell.u_h_b.len()
+ self.cell.b_z.len()
+ self.cell.b_h.len();
let head_params = self.output_heads.w_model.len()
+ self.output_heads.w_context.len()
+ self.output_heads.w_temp.len()
+ self.output_heads.w_top_p.len()
+ self.output_heads.w_conf.len()
+ self.output_heads.b_model.len()
+ self.output_heads.b_context.len()
+ 3; // temp, top_p, conf biases
cell_params + head_params
}
fn accumulate_gradients(
&self,
grads: &mut [f32],
sample: &RouterSample,
h_new: &Array1<f32>,
model_probs: &Array1<f32>,
context_probs: &Array1<f32>,
) {
// Simplified gradient computation
// In production, use autograd or manual backprop
// Model head gradients (cross-entropy)
let mut model_grad = model_probs.clone();
model_grad[sample.label_model] -= 1.0;
// Context head gradients
let mut context_grad = context_probs.clone();
context_grad[sample.label_context] -= 1.0;
// Accumulate into flat gradient buffer
let offset = 0;
for (i, &g) in model_grad.iter().enumerate() {
for (j, &h) in h_new.iter().enumerate() {
let idx = offset + i * self.config.hidden_dim + j;
if idx < grads.len() {
grads[idx] += g * h;
}
}
}
}
fn add_ewc_gradient(&self, grads: &mut [f32], fisher: &[f32], optimal: &[f32], lambda: f32) {
let params = self.get_flat_params();
for (i, ((g, &f), &w_opt)) in grads
.iter_mut()
.zip(fisher.iter())
.zip(optimal.iter())
.enumerate()
{
if i < params.len() {
*g += lambda * f * (params[i] - w_opt);
}
}
}
fn apply_gradients(&mut self, grads: &[f32], lr: f32) {
// Apply gradients to output heads (simplified)
let mut offset = 0;
let model_size = self.output_heads.w_model.len();
for (i, w) in self.output_heads.w_model.iter_mut().enumerate() {
if offset + i < grads.len() {
*w -= lr * grads[offset + i];
}
}
offset += model_size;
let context_size = self.output_heads.w_context.len();
for (i, w) in self.output_heads.w_context.iter_mut().enumerate() {
if offset + i < grads.len() {
*w -= lr * grads[offset + i];
}
}
}
fn get_flat_params(&self) -> Vec<f32> {
let mut params = Vec::new();
params.extend(self.output_heads.w_model.iter().cloned());
params.extend(self.output_heads.w_context.iter().cloned());
params.extend(self.output_heads.w_temp.iter().cloned());
params.extend(self.output_heads.w_top_p.iter().cloned());
params.extend(self.output_heads.w_conf.iter().cloned());
params
}
/// Compute Fisher information diagonal for EWC
pub fn compute_fisher(&self, samples: &[RouterSample]) -> Vec<f32> {
let param_count = self.parameter_count();
let mut fisher = vec![0.0f32; param_count];
for sample in samples {
let hidden = vec![0.0f32; self.config.hidden_dim];
if let Ok(decision) = self.forward(&sample.features, &hidden) {
// Approximate Fisher with squared gradients
// In production, compute actual log-likelihood gradients
for i in 0..fisher.len().min(sample.features.len()) {
fisher[i] += sample.features[i].powi(2) * decision.confidence;
}
}
}
// Normalize
let n = samples.len() as f32;
for f in &mut fisher {
*f /= n;
}
fisher
}
    /// Borrow the router's runtime statistics (forward/training counters,
    /// cumulative loss, model-selection histogram).
    pub fn stats(&self) -> &RouterStats {
        &self.stats
    }
    /// Get the current output-head weights as a flat vector, in the same
    /// order `add_ewc_gradient` expects — used as the `optimal_weights`
    /// anchor for EWC.
    pub fn get_weights(&self) -> Vec<f32> {
        self.get_flat_params()
    }
    /// Reset the router to freshly initialized (random) weights.
    ///
    /// Re-randomizes the cell and output heads only. The input LayerNorm
    /// and the accumulated `stats` are left untouched — NOTE(review):
    /// confirm whether stats (and gamma/beta, if ever trained) should also
    /// be cleared here.
    pub fn reset(&mut self) {
        self.cell = FastGRNNCell::new(
            self.config.input_dim,
            self.config.hidden_dim,
            self.config.sparsity,
            self.config.rank,
        );
        self.output_heads = OutputHeads::new(self.config.hidden_dim);
    }
}
impl FastGRNNCell {
    /// Initialize a cell with Xavier/Glorot-style random weights, random
    /// 0/1 sparsity masks, and zero biases.
    ///
    /// * `sparsity` — probability that a W entry is pruned: each mask
    ///   entry is 1 with probability `1 - sparsity` (kept), else 0.
    /// * `rank` — inner dimension of the low-rank U = A @ B factors.
    fn new(input_dim: usize, hidden_dim: usize, sparsity: f32, rank: usize) -> Self {
        use rand::Rng;
        use rand_distr::Normal;
        let mut rng = rand::thread_rng();
        // Std-devs follow Xavier init: sqrt(2 / (fan_in + fan_out)).
        let std_w = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
        let std_u = (2.0 / (hidden_dim + hidden_dim) as f32).sqrt();
        let normal_w = Normal::new(0.0, std_w).unwrap();
        let normal_u = Normal::new(0.0, std_u).unwrap();
        // Initialize W matrices
        let w_z = Array2::from_shape_fn((hidden_dim, input_dim), |_| rng.sample(normal_w));
        let w_h = Array2::from_shape_fn((hidden_dim, input_dim), |_| rng.sample(normal_w));
        // Create sparsity masks (1.0 keeps the weight, 0.0 prunes it)
        let w_z_mask = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
            if rng.gen::<f32>() > sparsity {
                1.0
            } else {
                0.0
            }
        });
        let w_h_mask = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
            if rng.gen::<f32>() > sparsity {
                1.0
            } else {
                0.0
            }
        });
        // Initialize low-rank U matrices
        let u_z_a = Array2::from_shape_fn((hidden_dim, rank), |_| rng.sample(normal_u));
        let u_z_b = Array2::from_shape_fn((rank, hidden_dim), |_| rng.sample(normal_u));
        let u_h_a = Array2::from_shape_fn((hidden_dim, rank), |_| rng.sample(normal_u));
        let u_h_b = Array2::from_shape_fn((rank, hidden_dim), |_| rng.sample(normal_u));
        Self {
            // W matrices are stored pre-masked (element-wise product).
            w_z: &w_z * &w_z_mask,
            w_h: &w_h * &w_h_mask,
            u_z_a,
            u_z_b,
            u_h_a,
            u_h_b,
            b_z: Array1::zeros(hidden_dim),
            b_h: Array1::zeros(hidden_dim),
            // FastGRNN gate-blend scalars (see `forward`).
            zeta: 1.0,
            nu: 0.5,
            w_z_mask,
            w_h_mask,
        }
    }
    /// One FastGRNN step: given input `x` (input_dim) and previous hidden
    /// state `h` (hidden_dim), return the new hidden state.
    ///
    /// Implements the FastGRNN update:
    ///   z       = sigmoid(W_z x + U_z h + b_z)
    ///   h_tilde = tanh(W_h x + U_h h + b_h)
    ///   h_new   = (zeta * (1 - z) + nu) ⊙ h_tilde + z ⊙ h
    /// where U_z = A_z B_z and U_h = A_h B_h are low-rank products, applied
    /// as two cheap matrix-vector multiplies (B first, then A).
    fn forward(&self, x: &Array1<f32>, h: &Array1<f32>) -> Array1<f32> {
        // z = sigmoid(W_z @ x + U_z @ h + b_z)
        // where U_z = A_z @ B_z (low-rank)
        let w_z_x = self.w_z.dot(x);
        let u_z_h = self.u_z_a.dot(&self.u_z_b.dot(h));
        let z_pre = &w_z_x + &u_z_h + &self.b_z;
        let z = z_pre.mapv(sigmoid);
        // h_tilde = tanh(W_h @ x + U_h @ h + b_h)
        let w_h_x = self.w_h.dot(x);
        let u_h_h = self.u_h_a.dot(&self.u_h_b.dot(h));
        let h_tilde_pre = &w_h_x + &u_h_h + &self.b_h;
        let h_tilde = h_tilde_pre.mapv(|v| v.tanh());
        // h_new = (zeta * (1 - z) + nu) * h_tilde + z * h
        let gate = z.mapv(|zi| self.zeta * (1.0 - zi) + self.nu);
        &gate * &h_tilde + &z * h
    }
}
impl LayerNorm {
    /// Identity-initialized normalization: gamma = 1, beta = 0, eps = 1e-5.
    fn new(dim: usize) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps: 1e-5,
        }
    }

    /// Normalize `x` to zero mean / unit variance over its single axis,
    /// then apply the learned affine transform gamma * x_hat + beta.
    fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        let mu = x.mean().unwrap_or(0.0);
        let variance = x.mapv(|v| (v - mu).powi(2)).mean().unwrap_or(0.0);
        // eps keeps the divisor strictly positive for constant inputs.
        let std_dev = (variance + self.eps).sqrt();
        let centered = x.mapv(|v| (v - mu) / std_dev);
        &self.gamma * &centered + &self.beta
    }
}
impl OutputHeads {
    /// He-style random initialization (std = sqrt(2 / hidden_dim)) for all
    /// weight vectors/matrices; biases start at zero.
    fn new(hidden_dim: usize) -> Self {
        use rand::Rng;
        use rand_distr::Normal;
        let mut rng = rand::thread_rng();
        let std = (2.0 / hidden_dim as f32).sqrt();
        let normal = Normal::new(0.0, std).unwrap();
        Self {
            // 4 rows: one logit per ModelSize.
            w_model: Array2::from_shape_fn((4, hidden_dim), |_| rng.sample(normal)),
            b_model: Array1::zeros(4),
            // 5 rows: one logit per CONTEXT_BINS entry.
            w_context: Array2::from_shape_fn((5, hidden_dim), |_| rng.sample(normal)),
            b_context: Array1::zeros(5),
            w_temp: Array1::from_shape_fn(hidden_dim, |_| rng.sample(normal)),
            b_temp: 0.0,
            w_top_p: Array1::from_shape_fn(hidden_dim, |_| rng.sample(normal)),
            b_top_p: 0.0,
            w_conf: Array1::from_shape_fn(hidden_dim, |_| rng.sample(normal)),
            b_conf: 0.0,
        }
    }
    /// Raw model-selection logits (length 4); caller applies softmax.
    fn model_forward(&self, h: &Array1<f32>) -> Array1<f32> {
        self.w_model.dot(h) + &self.b_model
    }
    /// Raw context-bin logits (length 5); caller applies softmax.
    fn context_forward(&self, h: &Array1<f32>) -> Array1<f32> {
        self.w_context.dot(h) + &self.b_context
    }
    /// Raw temperature score; caller applies sigmoid * 2.
    fn temp_forward(&self, h: &Array1<f32>) -> f32 {
        self.w_temp.dot(h) + self.b_temp
    }
    /// Raw top-p score; caller applies sigmoid.
    fn top_p_forward(&self, h: &Array1<f32>) -> f32 {
        self.w_top_p.dot(h) + self.b_top_p
    }
    /// Raw confidence score; caller applies sigmoid.
    fn confidence_forward(&self, h: &Array1<f32>) -> f32 {
        self.w_conf.dot(h) + self.b_conf
    }
}
/// Training metrics returned by `FastGRNNRouter::train_batch`.
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Batch-averaged loss — NOTE(review): despite the name, train_batch
    /// stores `total / batch_size` here, not the raw sum.
    pub total_loss: f32,
    /// Fraction of samples whose model-head argmax matched the label
    pub model_accuracy: f32,
    /// Fraction of samples whose context-head argmax matched the label
    pub context_accuracy: f32,
    /// Number of samples in the batch
    pub samples_processed: usize,
}
// Helper functions
/// Numerically stable logistic sigmoid: 1 / (1 + e^(-x)).
///
/// Fix: the previous "fast" branch used
/// `0.5 + 0.5 * x / (1 + |x| + 0.555 * x^2)` for |x| < 4.5, which is not a
/// Padé approximant of the sigmoid at all — it is non-monotonic (peaking
/// near |x| ≈ 1.34), off by up to ~0.35 absolute (e.g. 0.63 vs. the true
/// 0.989 at x = 4.5), and jumped discontinuously at the |x| = 4.5 branch
/// switch. Gate values that wrong distort the FastGRNN update gate, so the
/// exact form is now used for all inputs.
///
/// Inputs are clamped to [-20, 20]; outside that range the result already
/// rounds to 0/1 in f32, and clamping keeps `exp` far from overflow.
#[inline(always)]
fn sigmoid(x: f32) -> f32 {
    let x = x.clamp(-20.0, 20.0);
    1.0 / (1.0 + (-x).exp())
}
/// Numerically stable softmax for the router's small logit vectors.
///
/// Both paths subtract the maximum before exponentiating; arrays of at
/// most 8 elements go through the cheaper `fast_exp` approximation, larger
/// ones use exact `exp`. If every exponential underflows to zero (e.g. all
/// inputs -inf), a uniform distribution is returned instead of dividing by
/// zero.
fn softmax_array(x: &Array1<f32>) -> Array1<f32> {
    let len = x.len();
    let max = x.fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
    let exp = if len <= 8 {
        x.mapv(|v| fast_exp(v - max))
    } else {
        x.mapv(|v| (v - max).exp())
    };
    let sum = exp.sum();
    if sum > 0.0 {
        exp / sum
    } else {
        Array1::from_elem(len, 1.0 / len as f32)
    }
}
/// Fast exp approximation: third-order Taylor series for |x| < 1, exact
/// `exp` otherwise.
///
/// NOTE(review): the original comment labeled this "Schraudolph's method",
/// but no bit manipulation is involved — it is a truncated Taylor
/// expansion. Worst-case relative error is roughly 2% at |x| = 1 (2.6667
/// vs. e ≈ 2.71828), shrinking rapidly toward x = 0.
#[inline(always)]
fn fast_exp(x: f32) -> f32 {
    // Clamp to avoid overflow/underflow in the exact-exp branch.
    let x = x.clamp(-88.0, 88.0);
    // exp(x) ≈ 1 + x + x²/2 + x³/6 for |x| < 1
    if x.abs() < 1.0 {
        let x2 = x * x;
        let x3 = x2 * x;
        1.0 + x + x2 * 0.5 + x3 * 0.16666667
    } else {
        x.exp()
    }
}
/// Index of the maximum element; 0 for an empty array.
///
/// Ties resolve to the lowest index, and NaN entries are skipped by the
/// strict `>` comparison rather than causing a panic.
///
/// Fix: the original's fallback for lengths other than 4 or 5 used
/// `max_by(|a, b| a.partial_cmp(b).unwrap())`, which panics if any element
/// is NaN (e.g. from a degenerate logit). A single strict-`>` scan handles
/// every length, is NaN-safe, matches the old unrolled fast paths'
/// first-max tie-breaking, and for the router's 4/5-element arrays
/// compiles to the same handful of compares.
#[inline]
fn argmax_array(x: &Array1<f32>) -> usize {
    let mut max_val = match x.iter().next() {
        Some(&v) => v,
        None => return 0,
    };
    let mut max_idx = 0usize;
    for (i, &v) in x.iter().enumerate().skip(1) {
        if v > max_val {
            max_val = v;
            max_idx = i;
        }
    }
    max_idx
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this pins RouterConfig::default() to input_dim = 128
    // and hidden_dim = 64 — update if the default config changes.
    #[test]
    fn test_router_creation() {
        let config = RouterConfig::default();
        let router = FastGRNNRouter::new(&config).unwrap();
        assert_eq!(router.config.input_dim, 128);
        assert_eq!(router.config.hidden_dim, 64);
    }

    // Forward pass must yield bounded sampling parameters and a
    // probability simplex over the four model sizes.
    #[test]
    fn test_router_forward() {
        let config = RouterConfig::default();
        let router = FastGRNNRouter::new(&config).unwrap();
        let features = vec![0.5f32; config.input_dim];
        let hidden = vec![0.0f32; config.hidden_dim];
        let decision = router.forward(&features, &hidden).unwrap();
        // Verify outputs are valid
        assert!(decision.temperature >= 0.0 && decision.temperature <= 2.0);
        assert!(decision.top_p >= 0.0 && decision.top_p <= 1.0);
        assert!(decision.confidence >= 0.0 && decision.confidence <= 1.0);
        assert_eq!(decision.new_hidden.len(), config.hidden_dim);
        // Probabilities should sum to ~1
        let prob_sum: f32 = decision.model_probs.iter().sum();
        assert!((prob_sum - 1.0).abs() < 0.01);
    }

    // A training step on synthetic labels should report a positive loss
    // and process every sample.
    #[test]
    fn test_router_training() {
        let config = RouterConfig::default();
        let mut router = FastGRNNRouter::new(&config).unwrap();
        let samples: Vec<RouterSample> = (0..10)
            .map(|i| RouterSample {
                features: vec![0.1 * i as f32; config.input_dim],
                label_model: i % 4,
                label_context: i % 5,
                label_temperature: 0.7,
                label_top_p: 0.9,
                quality: 0.8,
                latency_ms: 100.0,
            })
            .collect();
        let metrics = router.train_batch(&samples, 0.001, 0.0, None, None);
        assert!(metrics.total_loss > 0.0);
        assert!(metrics.samples_processed == 10);
    }

    // LayerNorm output should be (approximately) zero-mean with the
    // identity gamma/beta initialization.
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4);
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let result = norm.forward(&x);
        // Mean should be ~0 after normalization
        let mean = result.mean().unwrap();
        assert!(mean.abs() < 0.01);
    }

    // Softmax must normalize to 1 and preserve the input ordering.
    #[test]
    fn test_softmax() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = softmax_array(&x);
        let sum: f32 = result.sum();
        assert!((sum - 1.0).abs() < 1e-5);
        // Higher input should have higher probability
        assert!(result[2] > result[1]);
        assert!(result[1] > result[0]);
    }

    // Smoke test only: checks the Fisher buffer is allocated, not its values.
    #[test]
    fn test_fisher_computation() {
        let config = RouterConfig::default();
        let router = FastGRNNRouter::new(&config).unwrap();
        let samples: Vec<RouterSample> = (0..5)
            .map(|_| RouterSample {
                features: vec![0.5f32; config.input_dim],
                label_model: 1,
                label_context: 2,
                label_temperature: 0.7,
                label_top_p: 0.9,
                quality: 0.8,
                latency_ms: 100.0,
            })
            .collect();
        let fisher = router.compute_fisher(&samples);
        assert!(!fisher.is_empty());
    }

    // Each forward call must bump the atomic forward counter exactly once.
    #[test]
    fn test_stats_tracking() {
        let config = RouterConfig::default();
        let router = FastGRNNRouter::new(&config).unwrap();
        let features = vec![0.5f32; config.input_dim];
        let hidden = vec![0.0f32; config.hidden_dim];
        for _ in 0..10 {
            let _ = router.forward(&features, &hidden);
        }
        assert_eq!(router.stats.forward_count.load(Ordering::Relaxed), 10);
    }
}