//! Training utilities for GNN models. //! //! Provides training loop utilities, optimizers, and loss functions. use crate::error::{GnnError, Result}; use crate::search::cosine_similarity; use ndarray::Array2; /// Optimizer types #[derive(Debug, Clone)] pub enum OptimizerType { /// Stochastic Gradient Descent Sgd { /// Learning rate learning_rate: f32, /// Momentum coefficient (0.0 = no momentum, 0.9 = standard) momentum: f32, }, /// Adam optimizer Adam { /// Learning rate learning_rate: f32, /// Beta1 parameter (exponential decay rate for first moment) beta1: f32, /// Beta2 parameter (exponential decay rate for second moment) beta2: f32, /// Epsilon for numerical stability epsilon: f32, }, } /// Optimizer state storage #[derive(Debug)] enum OptimizerState { /// SGD with momentum state Sgd { /// Momentum buffer (velocity) velocity: Option>, }, /// Adam optimizer state Adam { /// First moment estimate (mean of gradients) m: Option>, /// Second moment estimate (uncentered variance of gradients) v: Option>, /// Timestep counter t: usize, }, } /// Optimizer for parameter updates pub struct Optimizer { optimizer_type: OptimizerType, state: OptimizerState, } impl Optimizer { /// Create a new optimizer pub fn new(optimizer_type: OptimizerType) -> Self { let state = match &optimizer_type { OptimizerType::Sgd { .. } => OptimizerState::Sgd { velocity: None }, OptimizerType::Adam { .. } => OptimizerState::Adam { m: None, v: None, t: 0, }, }; Self { optimizer_type, state, } } /// Perform optimization step /// /// Updates parameters in-place based on gradients using the configured optimizer. 
/// /// # Arguments /// * `params` - Parameters to update (modified in-place) /// * `grads` - Gradients for the parameters /// /// # Returns /// * `Ok(())` on success /// * `Err(GnnError)` if shapes don't match or other errors occur pub fn step(&mut self, params: &mut Array2, grads: &Array2) -> Result<()> { // Validate shapes match if params.shape() != grads.shape() { return Err(GnnError::dimension_mismatch( format!("{:?}", params.shape()), format!("{:?}", grads.shape()), )); } match (&self.optimizer_type, &mut self.state) { ( OptimizerType::Sgd { learning_rate, momentum, }, OptimizerState::Sgd { velocity }, ) => Self::sgd_step_with_momentum(params, grads, *learning_rate, *momentum, velocity), ( OptimizerType::Adam { learning_rate, beta1, beta2, epsilon, }, OptimizerState::Adam { m, v, t }, ) => Self::adam_step( params, grads, *learning_rate, *beta1, *beta2, *epsilon, m, v, t, ), _ => return Err(GnnError::invalid_input("Optimizer type and state mismatch")), } } /// SGD optimization step with momentum /// /// Implements: v_t = momentum * v_{t-1} + learning_rate * grad /// params = params - v_t fn sgd_step_with_momentum( params: &mut Array2, grads: &Array2, learning_rate: f32, momentum: f32, velocity: &mut Option>, ) -> Result<()> { if momentum == 0.0 { // Simple SGD without momentum *params -= &(grads * learning_rate); } else { // SGD with momentum if velocity.is_none() { // Initialize velocity buffer *velocity = Some(Array2::zeros(params.dim())); } if let Some(v) = velocity { // Update velocity: v = momentum * v + learning_rate * grad let new_velocity = v.mapv(|x| x * momentum) + grads * learning_rate; *v = new_velocity; // Update parameters: params = params - v *params -= &*v; } } Ok(()) } /// Adam optimization step /// /// Implements the Adam algorithm: /// 1. m_t = beta1 * m_{t-1} + (1 - beta1) * g_t /// 2. v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2 /// 3. m_hat = m_t / (1 - beta1^t) /// 4. v_hat = v_t / (1 - beta2^t) /// 5. 
params = params - lr * m_hat / (sqrt(v_hat) + epsilon) #[allow(clippy::too_many_arguments)] fn adam_step( params: &mut Array2, grads: &Array2, learning_rate: f32, beta1: f32, beta2: f32, epsilon: f32, m: &mut Option>, v: &mut Option>, t: &mut usize, ) -> Result<()> { // Initialize moment buffers if needed if m.is_none() { *m = Some(Array2::zeros(params.dim())); } if v.is_none() { *v = Some(Array2::zeros(params.dim())); } // Increment timestep *t += 1; let timestep = *t as f32; if let (Some(m_buf), Some(v_buf)) = (m, v) { // Update biased first moment estimate // m_t = beta1 * m_{t-1} + (1 - beta1) * g_t let new_m = m_buf.mapv(|x| x * beta1) + grads * (1.0 - beta1); *m_buf = new_m; // Update biased second raw moment estimate // v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2 let grads_squared = grads.mapv(|x| x * x); let new_v = v_buf.mapv(|x| x * beta2) + grads_squared * (1.0 - beta2); *v_buf = new_v; // Compute bias-corrected first moment estimate // m_hat = m_t / (1 - beta1^t) let bias_correction1 = 1.0 - beta1.powi(*t as i32); let m_hat = m_buf.mapv(|x| x / bias_correction1); // Compute bias-corrected second raw moment estimate // v_hat = v_t / (1 - beta2^t) let bias_correction2 = 1.0 - beta2.powi(*t as i32); let v_hat = v_buf.mapv(|x| x / bias_correction2); // Update parameters // params = params - lr * m_hat / (sqrt(v_hat) + epsilon) let update = m_hat .iter() .zip(v_hat.iter()) .map(|(&m_val, &v_val)| learning_rate * m_val / (v_val.sqrt() + epsilon)); for (param, upd) in params.iter_mut().zip(update) { *param -= upd; } } Ok(()) } } /// Loss function types #[derive(Debug, Clone, Copy)] pub enum LossType { /// Mean Squared Error Mse, /// Cross Entropy CrossEntropy, /// Binary Cross Entropy BinaryCrossEntropy, } /// Loss function implementations for neural network training. /// /// Provides forward (loss computation) and backward (gradient computation) passes /// for common loss functions used in GNN training. 
///
/// # Numerical Stability
///
/// All loss functions use epsilon clamping and gradient clipping to prevent
/// numerical instability with extreme prediction values (near 0 or 1).
pub struct Loss;

impl Loss {
    /// Small epsilon value for numerical stability in logarithms and divisions.
    const EPS: f32 = 1e-7;

    /// Maximum absolute gradient value to prevent explosion.
    const MAX_GRAD: f32 = 1e6;

    /// Compute the loss value between predictions and targets.
    ///
    /// # Arguments
    /// * `loss_type` - The type of loss function to use
    /// * `predictions` - Model predictions as a 2D array
    /// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
    ///
    /// # Returns
    /// * `Ok(f32)` - The computed scalar loss value
    /// * `Err(GnnError)` - If shapes don't match or computation fails
    ///
    /// # Example
    /// ```
    /// use ndarray::Array2;
    /// use ruvector_gnn::training::{Loss, LossType};
    ///
    /// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
    /// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    /// let loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap();
    /// assert!(loss >= 0.0);
    /// ```
    pub fn compute(
        loss_type: LossType,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<f32> {
        // Validate shapes match
        if predictions.shape() != targets.shape() {
            return Err(GnnError::dimension_mismatch(
                format!("{:?}", predictions.shape()),
                format!("{:?}", targets.shape()),
            ));
        }
        if predictions.is_empty() {
            return Err(GnnError::invalid_input(
                "Cannot compute loss on empty arrays",
            ));
        }
        match loss_type {
            LossType::Mse => Self::mse_forward(predictions, targets),
            LossType::CrossEntropy => Self::cross_entropy_forward(predictions, targets),
            LossType::BinaryCrossEntropy => Self::bce_forward(predictions, targets),
        }
    }

    /// Compute the gradient of the loss with respect to predictions.
    ///
    /// # Arguments
    /// * `loss_type` - The type of loss function to use
    /// * `predictions` - Model predictions as a 2D array
    /// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
    ///
    /// # Returns
    /// * `Ok(Array2<f32>)` - Gradient array with same shape as predictions
    /// * `Err(GnnError)` - If shapes don't match or computation fails
    ///
    /// # Example
    /// ```
    /// use ndarray::Array2;
    /// use ruvector_gnn::training::{Loss, LossType};
    ///
    /// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
    /// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    /// let grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap();
    /// assert_eq!(grad.shape(), predictions.shape());
    /// ```
    pub fn gradient(
        loss_type: LossType,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<Array2<f32>> {
        // Validate shapes match
        if predictions.shape() != targets.shape() {
            return Err(GnnError::dimension_mismatch(
                format!("{:?}", predictions.shape()),
                format!("{:?}", targets.shape()),
            ));
        }
        if predictions.is_empty() {
            return Err(GnnError::invalid_input(
                "Cannot compute gradient on empty arrays",
            ));
        }
        match loss_type {
            LossType::Mse => Self::mse_backward(predictions, targets),
            LossType::CrossEntropy => Self::cross_entropy_backward(predictions, targets),
            LossType::BinaryCrossEntropy => Self::bce_backward(predictions, targets),
        }
    }

    /// Mean Squared Error: MSE = mean((predictions - targets)^2)
    fn mse_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        let diff = predictions - targets;
        let squared = diff.mapv(|x| x * x);
        // `mean` is None only for empty arrays, which `compute` rejects upstream.
        Ok(squared.mean().unwrap_or(0.0))
    }

    /// MSE gradient: d(MSE)/d(pred) = 2 * (predictions - targets) / n
    fn mse_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
        let n = predictions.len() as f32;
        let diff = predictions - targets;
        Ok(diff.mapv(|x| 2.0 * x / n))
    }

    /// Cross Entropy: CE = -mean(sum(targets * log(predictions), axis=1))
    ///
    /// Used for multi-class classification where targets are one-hot encoded
    /// and predictions are softmax probabilities.
    fn cross_entropy_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        // Clamp to EPS before the log to avoid ln(0) = -inf.
        let log_pred = predictions.mapv(|x| (x.max(Self::EPS)).ln());
        let elementwise = targets * &log_pred;
        let loss = -elementwise.sum() / predictions.nrows() as f32;
        Ok(loss)
    }

    /// Cross Entropy gradient: d(CE)/d(pred) = -targets / predictions / n
    ///
    /// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion.
    fn cross_entropy_backward(
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<Array2<f32>> {
        let n = predictions.nrows() as f32;
        // Clamp predictions to avoid division by zero
        let safe_pred = predictions.mapv(|x| x.max(Self::EPS));
        let grad = targets / &safe_pred;
        // Apply gradient clipping
        Ok(grad.mapv(|x| (-x / n).clamp(-Self::MAX_GRAD, Self::MAX_GRAD)))
    }

    /// Binary Cross Entropy: BCE = -mean(targets * log(pred) + (1 - targets) * log(1 - pred))
    ///
    /// Used for binary classification or multi-label classification.
    fn bce_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        let n = predictions.len() as f32;
        let loss: f32 = predictions
            .iter()
            .zip(targets.iter())
            .map(|(&p, &t)| {
                // Clamp predictions to (eps, 1-eps) for numerical stability
                let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
                -(t * p_safe.ln() + (1.0 - t) * (1.0 - p_safe).ln())
            })
            .sum();
        Ok(loss / n)
    }

    /// BCE gradient: d(BCE)/d(pred) = (-targets/pred + (1-targets)/(1-pred)) / n
    ///
    /// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion.
    fn bce_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
        let n = predictions.len() as f32;
        let grad_vec: Vec<f32> = predictions
            .iter()
            .zip(targets.iter())
            .map(|(&p, &t)| {
                // Clamp predictions for numerical stability
                let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
                let grad = (-t / p_safe + (1.0 - t) / (1.0 - p_safe)) / n;
                // Clip gradient to prevent explosion
                grad.clamp(-Self::MAX_GRAD, Self::MAX_GRAD)
            })
            .collect();
        Array2::from_shape_vec(predictions.dim(), grad_vec)
            .map_err(|e| GnnError::training(format!("Failed to reshape gradient: {}", e)))
    }
}

/// Training configuration for supervised training loops.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of epochs
    pub epochs: usize,
    /// Batch size
    pub batch_size: usize,
    /// Learning rate
    pub learning_rate: f32,
    /// Loss type
    pub loss_type: LossType,
    /// Optimizer type
    pub optimizer_type: OptimizerType,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            epochs: 100,
            batch_size: 32,
            learning_rate: 0.001,
            loss_type: LossType::Mse,
            // Standard Adam hyperparameters.
            optimizer_type: OptimizerType::Adam {
                learning_rate: 0.001,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        }
    }
}

/// Configuration for contrastive learning training
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Batch size for training
    pub batch_size: usize,
    /// Number of negative samples per positive
    pub n_negatives: usize,
    /// Temperature parameter for contrastive loss
    pub temperature: f32,
    /// Learning rate for optimization
    pub learning_rate: f32,
    /// Number of updates before flushing to storage
    pub flush_threshold: usize,
}

impl Default for TrainConfig {
    fn default() -> Self {
        Self {
            batch_size: 256,
            n_negatives: 64,
            temperature: 0.07,
            learning_rate: 0.001,
            flush_threshold: 1000,
        }
    }
}

/// Configuration for online learning
#[derive(Debug, Clone)]
pub struct OnlineConfig {
    /// Number of local optimization steps
    pub local_steps: usize,
    /// Whether to propagate updates to neighbors
    pub propagate_updates: bool,
}

impl Default for OnlineConfig {
    fn default() ->
Self {
        Self {
            local_steps: 5,
            propagate_updates: true,
        }
    }
}

/// Compute InfoNCE contrastive loss
///
/// InfoNCE (Information Noise-Contrastive Estimation) loss is used for contrastive learning.
/// It maximizes agreement between anchor and positive samples while minimizing agreement
/// with negative samples.
///
/// # Arguments
/// * `anchor` - The anchor embedding vector
/// * `positives` - Positive example embeddings (similar to anchor)
/// * `negatives` - Negative example embeddings (dissimilar to anchor)
/// * `temperature` - Temperature scaling parameter (lower = sharper distinctions)
///
/// # Returns
/// * The computed loss value (lower is better)
///
/// # Example
/// ```
/// use ruvector_gnn::training::info_nce_loss;
///
/// let anchor = vec![1.0, 0.0, 0.0];
/// let positive = vec![0.9, 0.1, 0.0];
/// let negative1 = vec![0.0, 1.0, 0.0];
/// let negative2 = vec![0.0, 0.0, 1.0];
///
/// let loss = info_nce_loss(
///     &anchor,
///     &[&positive],
///     &[&negative1, &negative2],
///     0.07
/// );
/// assert!(loss > 0.0);
/// ```
pub fn info_nce_loss(
    anchor: &[f32],
    positives: &[&[f32]],
    negatives: &[&[f32]],
    temperature: f32,
) -> f32 {
    // No positives means nothing to contrast against; defined as zero loss.
    if positives.is_empty() {
        return 0.0;
    }

    // Compute similarities with positives (scaled by temperature)
    let pos_sims: Vec<f32> = positives
        .iter()
        .map(|pos| cosine_similarity(anchor, pos) / temperature)
        .collect();

    // Compute similarities with negatives (scaled by temperature)
    let neg_sims: Vec<f32> = negatives
        .iter()
        .map(|neg| cosine_similarity(anchor, neg) / temperature)
        .collect();

    // For each positive, compute the InfoNCE loss using log-sum-exp trick for numerical stability
    let mut total_loss = 0.0;
    for &pos_sim in &pos_sims {
        // Use log-sum-exp trick to avoid overflow:
        // log(exp(pos_sim) / (exp(pos_sim) + sum(exp(neg_sim))))
        //   = pos_sim - log(exp(pos_sim) + sum(exp(neg_sim)))
        //   = pos_sim - log_sum_exp([pos_sim, neg_sims...])

        // Collect all logits for log-sum-exp
        let mut all_logits = vec![pos_sim];
        all_logits.extend(&neg_sims);

        // Compute log-sum-exp with numerical stability: subtract the max logit
        // before exponentiating so no term overflows.
        let max_logit = all_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = max_logit
            + all_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .sum::<f32>()
                .ln();

        // Loss = -log(exp(pos_sim) / sum_exp) = -(pos_sim - log_sum_exp)
        total_loss -= pos_sim - log_sum_exp;
    }

    // Average over positives
    total_loss / positives.len() as f32
}

/// Compute local contrastive loss for graph structures
///
/// This loss encourages node embeddings to be similar to their neighbors
/// and dissimilar to non-neighbors in the graph.
///
/// # Arguments
/// * `node_embedding` - The embedding of the target node
/// * `neighbor_embeddings` - Embeddings of neighbor nodes
/// * `non_neighbor_embeddings` - Embeddings of non-neighbor nodes
/// * `temperature` - Temperature scaling parameter
///
/// # Returns
/// * The computed loss value (lower is better)
///
/// # Example
/// ```
/// use ruvector_gnn::training::local_contrastive_loss;
///
/// let node = vec![1.0, 0.0, 0.0];
/// let neighbor = vec![0.9, 0.1, 0.0];
/// let non_neighbor1 = vec![0.0, 1.0, 0.0];
/// let non_neighbor2 = vec![0.0, 0.0, 1.0];
///
/// let loss = local_contrastive_loss(
///     &node,
///     &[neighbor],
///     &[non_neighbor1, non_neighbor2],
///     0.07
/// );
/// assert!(loss > 0.0);
/// ```
pub fn local_contrastive_loss(
    node_embedding: &[f32],
    neighbor_embeddings: &[Vec<f32>],
    non_neighbor_embeddings: &[Vec<f32>],
    temperature: f32,
) -> f32 {
    // A node with no neighbors has no positives; defined as zero loss.
    if neighbor_embeddings.is_empty() {
        return 0.0;
    }

    // Convert to slices for info_nce_loss
    let positives: Vec<&[f32]> = neighbor_embeddings.iter().map(|v| v.as_slice()).collect();
    let negatives: Vec<&[f32]> = non_neighbor_embeddings
        .iter()
        .map(|v| v.as_slice())
        .collect();

    info_nce_loss(node_embedding, &positives, &negatives, temperature)
}

/// Perform a single SGD (Stochastic Gradient Descent) optimization step
///
/// Updates the embedding in-place by subtracting the scaled gradient.
///
/// # Arguments
/// * `embedding` - The embedding to update (modified in-place)
/// * `grad` - The gradient vector
/// * `learning_rate` - The learning rate (step size)
///
/// # Example
/// ```
/// use ruvector_gnn::training::sgd_step;
///
/// let mut embedding = vec![1.0, 2.0, 3.0];
/// let gradient = vec![0.1, -0.2, 0.3];
/// let learning_rate = 0.01;
///
/// sgd_step(&mut embedding, &gradient, learning_rate);
///
/// // Embedding is now updated: embedding[i] -= learning_rate * grad[i]
/// assert!((embedding[0] - 0.999).abs() < 1e-6);
/// assert!((embedding[1] - 2.002).abs() < 1e-6);
/// assert!((embedding[2] - 2.997).abs() < 1e-6);
/// ```
pub fn sgd_step(embedding: &mut [f32], grad: &[f32], learning_rate: f32) {
    // Mismatched lengths indicate a caller bug, so this panics rather than
    // silently truncating the update.
    assert_eq!(
        embedding.len(),
        grad.len(),
        "Embedding and gradient must have the same length"
    );
    // Element-wise in-place update: embedding[i] -= learning_rate * grad[i]
    for (emb, &g) in embedding.iter_mut().zip(grad.iter()) {
        *emb -= learning_rate * g;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train_config_default() {
        let config = TrainConfig::default();
        assert_eq!(config.batch_size, 256);
        assert_eq!(config.n_negatives, 64);
        assert_eq!(config.temperature, 0.07);
        assert_eq!(config.learning_rate, 0.001);
        assert_eq!(config.flush_threshold, 1000);
    }

    #[test]
    fn test_online_config_default() {
        let config = OnlineConfig::default();
        assert_eq!(config.local_steps, 5);
        assert!(config.propagate_updates);
    }

    #[test]
    fn test_info_nce_loss_basic() {
        // Anchor and positive are similar
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        // Negatives are orthogonal
        let negative1 = vec![0.0, 1.0, 0.0];
        let negative2 = vec![0.0, 0.0, 1.0];

        let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07);

        // Loss should be positive
        assert!(loss > 0.0);
        // Loss should be reasonable (not infinite or NaN)
        assert!(loss.is_finite());
    }

    #[test]
    fn test_info_nce_loss_perfect_match() {
        // Anchor and positive are identical
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![1.0, 0.0, 0.0];
        //
Negatives are very different let negative1 = vec![0.0, 1.0, 0.0]; let negative2 = vec![0.0, 0.0, 1.0]; let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07); // Loss should be lower for perfect match assert!(loss < 1.0); assert!(loss.is_finite()); } #[test] fn test_info_nce_loss_no_positives() { let anchor = vec![1.0, 0.0, 0.0]; let negative1 = vec![0.0, 1.0, 0.0]; let loss = info_nce_loss(&anchor, &[], &[&negative1], 0.07); // Should return 0.0 when no positives assert_eq!(loss, 0.0); } #[test] fn test_info_nce_loss_temperature_effect() { let anchor = vec![1.0, 0.0, 0.0]; let positive = vec![0.9, 0.1, 0.0]; let negative = vec![0.0, 1.0, 0.0]; // Test with reasonable temperature values // Very low temperatures can cause numerical issues, so we use 0.07 (standard) and 1.0 let loss_low_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 0.07); let loss_high_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 1.0); // Both should be positive and finite assert!( loss_low_temp > 0.0 && loss_low_temp.is_finite(), "Low temp loss should be positive and finite, got: {}", loss_low_temp ); assert!( loss_high_temp > 0.0 && loss_high_temp.is_finite(), "High temp loss should be positive and finite, got: {}", loss_high_temp ); // With standard temperature, the loss should be reasonable assert!(loss_low_temp < 10.0, "Loss should not be too large"); assert!(loss_high_temp < 10.0, "Loss should not be too large"); } #[test] fn test_local_contrastive_loss_basic() { let node = vec![1.0, 0.0, 0.0]; let neighbor = vec![0.9, 0.1, 0.0]; let non_neighbor1 = vec![0.0, 1.0, 0.0]; let non_neighbor2 = vec![0.0, 0.0, 1.0]; let loss = local_contrastive_loss(&node, &[neighbor], &[non_neighbor1, non_neighbor2], 0.07); // Loss should be positive and finite assert!(loss > 0.0); assert!(loss.is_finite()); } #[test] fn test_local_contrastive_loss_multiple_neighbors() { let node = vec![1.0, 0.0, 0.0]; let neighbor1 = vec![0.9, 0.1, 0.0]; let neighbor2 = vec![0.95, 
0.05, 0.0]; let non_neighbor = vec![0.0, 1.0, 0.0]; let loss = local_contrastive_loss(&node, &[neighbor1, neighbor2], &[non_neighbor], 0.07); assert!(loss > 0.0); assert!(loss.is_finite()); } #[test] fn test_local_contrastive_loss_no_neighbors() { let node = vec![1.0, 0.0, 0.0]; let non_neighbor = vec![0.0, 1.0, 0.0]; let loss = local_contrastive_loss(&node, &[], &[non_neighbor], 0.07); // Should return 0.0 when no neighbors assert_eq!(loss, 0.0); } #[test] fn test_sgd_step_basic() { let mut embedding = vec![1.0, 2.0, 3.0]; let gradient = vec![0.1, -0.2, 0.3]; let learning_rate = 0.01; sgd_step(&mut embedding, &gradient, learning_rate); // Expected: embedding[i] -= learning_rate * grad[i] assert!((embedding[0] - 0.999).abs() < 1e-6); // 1.0 - 0.01 * 0.1 assert!((embedding[1] - 2.002).abs() < 1e-6); // 2.0 - 0.01 * (-0.2) assert!((embedding[2] - 2.997).abs() < 1e-6); // 3.0 - 0.01 * 0.3 } #[test] fn test_sgd_step_zero_gradient() { let mut embedding = vec![1.0, 2.0, 3.0]; let original = embedding.clone(); let gradient = vec![0.0, 0.0, 0.0]; let learning_rate = 0.01; sgd_step(&mut embedding, &gradient, learning_rate); // Embedding should not change with zero gradient assert_eq!(embedding, original); } #[test] fn test_sgd_step_zero_learning_rate() { let mut embedding = vec![1.0, 2.0, 3.0]; let original = embedding.clone(); let gradient = vec![0.1, 0.2, 0.3]; let learning_rate = 0.0; sgd_step(&mut embedding, &gradient, learning_rate); // Embedding should not change with zero learning rate assert_eq!(embedding, original); } #[test] fn test_sgd_step_large_learning_rate() { let mut embedding = vec![10.0, 20.0, 30.0]; let gradient = vec![1.0, 2.0, 3.0]; let learning_rate = 5.0; sgd_step(&mut embedding, &gradient, learning_rate); // Expected: embedding[i] -= learning_rate * grad[i] assert!((embedding[0] - 5.0).abs() < 1e-5); // 10.0 - 5.0 * 1.0 assert!((embedding[1] - 10.0).abs() < 1e-5); // 20.0 - 5.0 * 2.0 assert!((embedding[2] - 15.0).abs() < 1e-5); // 30.0 - 5.0 * 3.0 } 
#[test] #[should_panic(expected = "Embedding and gradient must have the same length")] fn test_sgd_step_mismatched_lengths() { let mut embedding = vec![1.0, 2.0, 3.0]; let gradient = vec![0.1, 0.2]; // Wrong length sgd_step(&mut embedding, &gradient, 0.01); } #[test] fn test_info_nce_loss_multiple_positives() { let anchor = vec![1.0, 0.0, 0.0]; let positive1 = vec![0.9, 0.1, 0.0]; let positive2 = vec![0.95, 0.05, 0.0]; let negative = vec![0.0, 1.0, 0.0]; let loss = info_nce_loss(&anchor, &[&positive1, &positive2], &[&negative], 0.07); // Loss should be positive and finite assert!(loss > 0.0); assert!(loss.is_finite()); } #[test] fn test_contrastive_loss_gradient_property() { // Test that loss decreases when positive becomes more similar let anchor = vec![1.0, 0.0, 0.0]; let positive_far = vec![0.5, 0.5, 0.0]; let positive_close = vec![0.9, 0.1, 0.0]; let negative = vec![0.0, 1.0, 0.0]; let loss_far = info_nce_loss(&anchor, &[&positive_far], &[&negative], 0.07); let loss_close = info_nce_loss(&anchor, &[&positive_close], &[&negative], 0.07); // Loss should be lower when positive is closer to anchor assert!(loss_close < loss_far); } #[test] fn test_sgd_optimizer_basic() { let optimizer_type = OptimizerType::Sgd { learning_rate: 0.1, momentum: 0.0, }; let mut optimizer = Optimizer::new(optimizer_type); let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap(); let result = optimizer.step(&mut params, &grads); assert!(result.is_ok()); // Expected: params[i] -= learning_rate * grads[i] assert!((params[[0, 0]] - 0.99).abs() < 1e-6); // 1.0 - 0.1 * 0.1 assert!((params[[0, 1]] - 1.98).abs() < 1e-6); // 2.0 - 0.1 * 0.2 assert!((params[[1, 0]] - 2.97).abs() < 1e-6); // 3.0 - 0.1 * 0.3 assert!((params[[1, 1]] - 3.96).abs() < 1e-6); // 4.0 - 0.1 * 0.4 } #[test] fn test_sgd_optimizer_with_momentum() { let optimizer_type = OptimizerType::Sgd { learning_rate: 0.1, momentum: 
0.9,
        };
        let mut optimizer = Optimizer::new(optimizer_type);
        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();

        // First step: velocity starts at zero, so this matches plain SGD.
        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_ok());
        assert!((params[[0, 0]] - 0.99).abs() < 1e-6);

        // Second step: accumulated momentum makes the update larger.
        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_ok());
        assert!(params[[0, 0]] < 0.99);
    }

    /// A single Adam step moves every parameter against its (positive)
    /// gradient and produces only finite values.
    #[test]
    fn test_adam_optimizer_basic() {
        let mut opt = Optimizer::new(OptimizerType::Adam {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        });
        let mut weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let before = weights.clone();

        assert!(opt.step(&mut weights, &grads).is_ok());

        // All gradients are positive, so every entry must have decreased.
        for i in 0..2 {
            for j in 0..2 {
                assert!(weights[[i, j]] < before[[i, j]]);
            }
        }
        // No NaNs or infinities may appear.
        assert!(weights.iter().all(|&x| x.is_finite()));
    }

    /// Repeated Adam steps with a constant positive gradient keep the
    /// parameters finite and strictly decrease all of them.
    #[test]
    fn test_adam_optimizer_multiple_steps() {
        let mut opt = Optimizer::new(OptimizerType::Adam {
            learning_rate: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        });
        let mut weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let start = weights.clone();

        for _ in 0..10 {
            assert!(opt.step(&mut weights, &grads).is_ok());
            assert!(weights.iter().all(|&x| x.is_finite()));
        }

        // With constant positive gradients, every parameter must have moved down.
        for i in 0..2 {
            for j in 0..2 {
                assert!(weights[[i, j]] < start[[i, j]]);
            }
        }
    }

    /// Exercises the bias-correction path: the very first update (where the
    /// correction is strongest) must produce a positive step.
    #[test]
    fn test_adam_bias_correction() {
        let config = OptimizerType::Adam {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let grads = Array2::from_shape_vec((1, 1), vec![0.1]).unwrap();

        // Measure the very first update.
        let mut opt = Optimizer::new(config.clone());
        let mut weights = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        assert!(opt.step(&mut weights, &grads).is_ok());
        let first_update = 1.0 - weights[[0, 0]];

        // Run a fresh optimizer for many steps; the bias-correction effect
        // should diminish as the timestep grows.
        let mut opt = Optimizer::new(config);
        let mut weights = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        for _ in 0..100 {
            let _ = opt.step(&mut weights, &grads);
        }

        assert!(first_update > 0.0);
    }

    /// Mismatched parameter/gradient shapes must yield a DimensionMismatch
    /// error carrying both shapes.
    #[test]
    fn test_optimizer_shape_mismatch() {
        let mut opt = Optimizer::new(OptimizerType::Adam {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        });
        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();

        let result = opt.step(&mut params, &grads);
        assert!(result.is_err());
        match result {
            Err(GnnError::DimensionMismatch { expected, actual }) => {
                assert!(expected.contains("2, 2"));
                assert!(actual.contains("3, 2"));
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    /// Adam should drive f(x, y) = x^2 + y^2 close to its minimum at the origin.
    #[test]
    fn test_adam_convergence() {
        let mut opt = Optimizer::new(OptimizerType::Adam {
            learning_rate: 0.5,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        });
        // Start far from the optimum (0, 0).
        let mut point = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();

        for _ in 0..200 {
            // Analytic gradient of x^2 + y^2 is (2x, 2y).
            let grads =
                Array2::from_shape_vec((1, 2), vec![2.0 * point[[0, 0]], 2.0 * point[[0, 1]]])
                    .unwrap();
            let _ = opt.step(&mut point, &grads);
        }

        assert!(point[[0, 0]].abs() < 0.5);
        assert!(point[[0, 1]].abs() < 0.5);
    }

    /// SGD with momentum should also minimize the same quadratic.
    #[test]
    fn test_sgd_momentum_convergence() {
        let mut opt = Optimizer::new(OptimizerType::Sgd {
            learning_rate: 0.01,
            momentum: 0.9,
        });
        // Start far from the optimum (0, 0).
        let mut point = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();

        for _ in 0..200 {
            // Analytic gradient of x^2 + y^2 is (2x, 2y).
            let grads =
                Array2::from_shape_vec((1, 2), vec![2.0 * point[[0, 0]], 2.0 * point[[0, 1]]])
                    .unwrap();
            let _ = opt.step(&mut point, &grads);
        }

        assert!(point[[0, 0]].abs() < 0.5);
        assert!(point[[0, 1]].abs() < 0.5);
    }

    // ==================== Loss Function Tests ====================

    #[test]
    fn test_mse_loss_zero_when_equal() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let target = pred.clone();

        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
        assert!(
            (loss - 0.0).abs() < 1e-6,
            "MSE should be 0 when pred == target"
        );
    }

    #[test]
    fn test_mse_loss_positive() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let target = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 4.0, 5.0]).unwrap();

        // Each element differs by 1, so squared diff = 1 and the mean is 1.
        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
        assert!((loss - 1.0).abs() < 1e-6, "MSE should be 1.0, got {}", loss);
    }

    #[test]
    fn test_mse_loss_varying_diffs() {
        let pred = Array2::from_shape_vec((1, 4), vec![0.0, 0.0, 0.0, 0.0]).unwrap();
        let target = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        // Squared diffs are 1, 4, 9, 16 — mean = 30 / 4 = 7.5.
        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
        assert!((loss - 7.5).abs() < 1e-6, "MSE should be 7.5, got {}", loss);
    }

    #[test]
    fn test_mse_gradient_shape() {
        let pred = Array2::from_shape_vec((2, 3), vec![0.0; 6]).unwrap();
        let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();

        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_mse_gradient_direction() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.0, 2.0]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();

        // grad = 2 * (pred - target) / n = 2 * (-1, 1) / 2 = (-1, 1).
        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
        assert!(
            grad[[0, 0]] < 0.0,
            "Gradient should be negative when pred < target"
        );
        assert!(
            grad[[0, 1]] > 0.0,
            "Gradient should be positive when pred > target"
        );
    }

    #[test]
    fn test_mse_gradient_zero_when_equal() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let target = pred.clone();

        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
        assert!(
            grad.iter().all(|&x| x.abs() < 1e-6),
            "Gradient should be zero when pred == target"
        );
    }

    #[test]
    fn test_bce_loss_perfect_predictions() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.999, 0.001]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

        // Near-perfect predictions should have a near-zero loss.
        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert!(
            loss < 0.1,
            "BCE should be low for good predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_bce_loss_bad_predictions() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.001, 0.999]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

        // Confidently wrong predictions should be penalized heavily.
        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert!(
            loss > 1.0,
            "BCE should be high for bad predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_bce_loss_numerical_stability() {
        // Exact 0.0 and 1.0 would hit log(0) without clamping.
        let pred = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();

        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert!(
            loss.is_finite(),
            "BCE should be finite even with extreme values"
        );
    }

    #[test]
    fn test_bce_gradient_shape() {
        let pred = Array2::from_shape_vec((3, 2), vec![0.5; 6]).unwrap();
        let target = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap();

        let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_bce_gradient_direction() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.3, 0.7]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

        let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        // target = 1 and pred < 1: gradient pushes pred up (negative sign).
        assert!(
            grad[[0, 0]] < 0.0,
            "Gradient should be negative to increase pred towards 1"
        );
        // target = 0 and pred > 0: gradient pushes pred down (positive sign).
        assert!(
            grad[[0, 1]] > 0.0,
            "Gradient should be positive to decrease pred towards 0"
        );
    }

    #[test]
    fn test_cross_entropy_one_hot() {
        // Softmax-like predictions (rows sum to 1) against one-hot targets.
        let pred = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.8, 0.1]).unwrap();
        let target = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();

        let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
        assert!(
            loss > 0.0 && loss < 1.0,
            "CE should be reasonable for good predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_cross_entropy_wrong_class() {
        let pred = Array2::from_shape_vec((1, 3), vec![0.1, 0.1, 0.8]).unwrap();
        let target = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();

        // Mass on the wrong class should produce a large loss.
        let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
        assert!(
            loss > 1.0,
            "CE should be high for wrong predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_cross_entropy_gradient_shape() {
        let pred = Array2::from_shape_vec((2, 4), vec![0.25; 8]).unwrap();
        let target =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

        let grad = Loss::gradient(LossType::CrossEntropy, &pred, &target).unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_loss_dimension_mismatch_error() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();
        let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();

        assert!(
            Loss::compute(LossType::Mse, &pred, &target).is_err(),
            "Should error on dimension mismatch"
        );
        assert!(
            Loss::gradient(LossType::Mse, &pred, &target).is_err(),
            "Gradient should error on dimension mismatch"
        );
    }

    #[test]
    fn test_loss_empty_array_error() {
        let pred = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        let target = Array2::from_shape_vec((0, 2), vec![]).unwrap();

        assert!(
            Loss::compute(LossType::Mse, &pred, &target).is_err(),
            "Should error on empty arrays"
        );
        assert!(
            Loss::gradient(LossType::Mse, &pred, &target).is_err(),
            "Gradient should error on empty arrays"
        );
    }

    /// Central-difference check: the analytical MSE gradient must match the
    /// numerical gradient to within tolerance.
    #[test]
    fn test_loss_gradient_numerical_check() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.8]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

        let analytical_grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();

        let eps = 1e-5;
        for i in 0..2 {
            // Perturb coordinate i in both directions.
            let mut pred_plus = pred.clone();
            let mut pred_minus = pred.clone();
            pred_plus[[0, i]] += eps;
            pred_minus[[0, i]] -= eps;

            let loss_plus = Loss::compute(LossType::Mse, &pred_plus, &target).unwrap();
            let loss_minus = Loss::compute(LossType::Mse, &pred_minus, &target).unwrap();
            let numerical_grad = (loss_plus - loss_minus) / (2.0 * eps);

            let error = (analytical_grad[[0, i]] - numerical_grad).abs();
            assert!(
                error < 1e-3,
                "Numerical gradient check failed: analytical={}, numerical={}",
                analytical_grad[[0, i]],
                numerical_grad
            );
        }
    }

    /// End-to-end smoke test: Loss gradients fed through the Optimizer must
    /// reduce the MSE over a few plain-SGD steps.
    #[test]
    fn test_training_loop_integration() {
        let mut optimizer = Optimizer::new(OptimizerType::Sgd {
            learning_rate: 0.1,
            momentum: 0.0,
        });
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
        let mut pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();

        let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();

        for _ in 0..10 {
            let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
            optimizer.step(&mut pred, &grad).unwrap();
        }

        let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
        assert!(
            final_loss < initial_loss,
            "Loss should decrease during training"
        );
    }
}