// File: wifi-densepose/vendor/ruvector/crates/ruvector-gnn/src/ewc.rs

//! Elastic Weight Consolidation (EWC) for preventing catastrophic forgetting in GNNs
//!
//! EWC adds a regularization term that penalizes changes to important weights,
//! where importance is measured by the Fisher information matrix diagonal.
//!
//! The EWC loss term is: L_EWC = λ/2 * Σ F_i * (θ_i - θ*_i)²
//! where:
//! - λ is the regularization strength
//! - F_i is the Fisher information for weight i
//! - θ_i is the current weight
//! - θ*_i is the anchor weight from the previous task
use std::f32;
/// Elastic Weight Consolidation implementation
///
/// Prevents catastrophic forgetting by penalizing changes to important weights
/// learned from previous tasks.
#[derive(Debug, Clone)]
pub struct ElasticWeightConsolidation {
    /// Fisher information diagonal (importance of each weight).
    /// Higher values indicate more important weights.
    fisher_diag: Vec<f32>,
    /// Anchor weights (optimal weights from the previous task).
    /// These are the weights we want to stay close to.
    anchor_weights: Vec<f32>,
    /// Regularization strength (λ).
    /// Controls how strongly we penalize deviations from anchor weights.
    lambda: f32,
    /// Whether EWC is active.
    /// EWC is only active after `consolidate` has been called.
    active: bool,
}

impl ElasticWeightConsolidation {
    /// Create a new EWC instance with the specified regularization strength.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength (typically 10-10000)
    ///
    /// # Returns
    /// A new inactive EWC instance.
    ///
    /// # Panics
    /// Panics if `lambda` is negative.
    pub fn new(lambda: f32) -> Self {
        assert!(lambda >= 0.0, "Lambda must be non-negative");
        Self {
            fisher_diag: Vec::new(),
            anchor_weights: Vec::new(),
            lambda,
            active: false,
        }
    }

    /// Compute the Fisher information diagonal from gradients.
    ///
    /// The Fisher information measures the importance of each weight.
    /// It is approximated as the mean squared gradient over samples:
    /// F_i ≈ (1/N) * Σ (∂L/∂θ_i)²
    ///
    /// Calling this again replaces (does not accumulate onto) any previously
    /// computed Fisher information.
    ///
    /// # Arguments
    /// * `gradients` - Slice of gradient vectors, one per sample
    /// * `sample_count` - Number of samples N (for normalization; treated as 1 if 0)
    ///
    /// # Panics
    /// Panics if the gradient vectors do not all have the same length.
    pub fn compute_fisher(&mut self, gradients: &[&[f32]], sample_count: usize) {
        if gradients.is_empty() {
            return;
        }
        let num_weights = gradients[0].len();
        // Fisher information is computed fresh from the current gradients,
        // so reset the diagonal to zero before accumulating.
        self.fisher_diag = vec![0.0; num_weights];
        for grad in gradients {
            assert_eq!(
                grad.len(),
                num_weights,
                "All gradient vectors must have the same length"
            );
            // Accumulate squared gradients; zip avoids per-element bounds checks.
            for (f, &g) in self.fisher_diag.iter_mut().zip(grad.iter()) {
                *f += g * g;
            }
        }
        // Normalize by sample count; max(1.0) guards against division by zero.
        let normalization = 1.0 / (sample_count as f32).max(1.0);
        for f in &mut self.fisher_diag {
            *f *= normalization;
        }
    }

    /// Save the current weights as the anchor and activate EWC.
    ///
    /// This should be called after training on a task, before moving to the
    /// next task. It marks the current weights as important and activates the
    /// EWC penalty.
    ///
    /// # Arguments
    /// * `weights` - Current model weights to save as the anchor
    ///
    /// # Panics
    /// Panics if Fisher information has not been computed, or if `weights`
    /// has a different length than the Fisher diagonal.
    pub fn consolidate(&mut self, weights: &[f32]) {
        assert!(
            !self.fisher_diag.is_empty(),
            "Must compute Fisher information before consolidating"
        );
        assert_eq!(
            weights.len(),
            self.fisher_diag.len(),
            "Weight count must match Fisher information size"
        );
        self.anchor_weights = weights.to_vec();
        self.active = true;
    }

    /// Compute the EWC penalty term.
    ///
    /// Returns: λ/2 * Σ F_i * (θ_i - θ*_i)²
    ///
    /// This penalty is added to the loss function to discourage changes to
    /// important weights.
    ///
    /// # Arguments
    /// * `weights` - Current model weights
    ///
    /// # Returns
    /// The EWC penalty value (0.0 if not active).
    ///
    /// # Panics
    /// Panics if active and `weights` has a different length than the anchor.
    pub fn penalty(&self, weights: &[f32]) -> f32 {
        if !self.active {
            return 0.0;
        }
        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );
        // Σ F_i * (θ_i - θ*_i)², accumulated without indexed access.
        let weighted_sq_dev: f32 = weights
            .iter()
            .zip(&self.anchor_weights)
            .zip(&self.fisher_diag)
            .map(|((&w, &a), &f)| {
                let diff = w - a;
                f * diff * diff
            })
            .sum();
        // Multiply by λ/2.
        weighted_sq_dev * self.lambda * 0.5
    }

    /// Compute the EWC gradient.
    ///
    /// Returns: λ * F_i * (θ_i - θ*_i) for each weight i
    ///
    /// This gradient is added to the model gradients during training to push
    /// weights back toward their anchor values.
    ///
    /// # Arguments
    /// * `weights` - Current model weights
    ///
    /// # Returns
    /// Gradient vector (all zeros if not active).
    ///
    /// # Panics
    /// Panics if active and `weights` has a different length than the anchor.
    pub fn gradient(&self, weights: &[f32]) -> Vec<f32> {
        if !self.active {
            return vec![0.0; weights.len()];
        }
        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );
        weights
            .iter()
            .zip(&self.anchor_weights)
            .zip(&self.fisher_diag)
            .map(|((&w, &a), &f)| self.lambda * f * (w - a))
            .collect()
    }

    /// Check whether EWC is active.
    ///
    /// # Returns
    /// `true` if `consolidate` has been called, `false` otherwise.
    pub fn is_active(&self) -> bool {
        self.active
    }

    /// Get the regularization strength.
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Update the regularization strength.
    ///
    /// # Panics
    /// Panics if `lambda` is negative.
    pub fn set_lambda(&mut self, lambda: f32) {
        assert!(lambda >= 0.0, "Lambda must be non-negative");
        self.lambda = lambda;
    }

    /// Get the Fisher information diagonal.
    pub fn fisher_diag(&self) -> &[f32] {
        &self.fisher_diag
    }

    /// Get the anchor weights.
    pub fn anchor_weights(&self) -> &[f32] {
        &self.anchor_weights
    }

    /// Reset EWC to the inactive state.
    ///
    /// Clears the Fisher information and anchor weights; `lambda` is preserved.
    pub fn reset(&mut self) {
        self.fisher_diag.clear();
        self.anchor_weights.clear();
        self.active = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let ewc = ElasticWeightConsolidation::new(1000.0);
        assert_eq!(ewc.lambda(), 1000.0);
        assert!(!ewc.is_active());
        assert!(ewc.fisher_diag().is_empty());
        assert!(ewc.anchor_weights().is_empty());
    }

    #[test]
    #[should_panic(expected = "Lambda must be non-negative")]
    fn test_new_negative_lambda() {
        ElasticWeightConsolidation::new(-1.0);
    }

    #[test]
    fn test_compute_fisher_single_sample() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Single gradient: [1.0, 2.0, 3.0]
        let grad1 = vec![1.0, 2.0, 3.0];
        let gradients = vec![grad1.as_slice()];
        ewc.compute_fisher(&gradients, 1);
        // Fisher should be squared gradients
        assert_eq!(ewc.fisher_diag(), &[1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_compute_fisher_multiple_samples() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Two gradients
        let grad1 = vec![1.0, 2.0, 3.0];
        let grad2 = vec![2.0, 1.0, 1.0];
        let gradients = vec![grad1.as_slice(), grad2.as_slice()];
        ewc.compute_fisher(&gradients, 2);
        // Fisher should be mean of squared gradients
        // Position 0: (1² + 2²) / 2 = 2.5
        // Position 1: (2² + 1²) / 2 = 2.5
        // Position 2: (3² + 1²) / 2 = 5.0
        let expected = vec![2.5, 2.5, 5.0];
        assert_eq!(ewc.fisher_diag().len(), expected.len());
        for (actual, exp) in ewc.fisher_diag().iter().zip(expected.iter()) {
            assert!((actual - exp).abs() < 1e-6);
        }
    }

    #[test]
    fn test_compute_fisher_resets() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // First computation
        let grad1 = vec![1.0, 2.0];
        ewc.compute_fisher(&[grad1.as_slice()], 1);
        assert_eq!(ewc.fisher_diag(), &[1.0, 4.0]);
        // A second computation does NOT accumulate on top of the first:
        // compute_fisher resets the diagonal to zero, then accumulates
        // 2² = 4 and 1² = 1, normalized by 1/1.
        let grad2 = vec![2.0, 1.0];
        ewc.compute_fisher(&[grad2.as_slice()], 1);
        // Fisher is reset and recomputed with the new gradients
        assert_eq!(ewc.fisher_diag(), &[4.0, 1.0]);
    }

    #[test]
    #[should_panic(expected = "All gradient vectors must have the same length")]
    fn test_compute_fisher_mismatched_sizes() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad1 = vec![1.0, 2.0];
        let grad2 = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad1.as_slice(), grad2.as_slice()], 2);
    }

    #[test]
    fn test_consolidate() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Setup Fisher information
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        // Consolidate weights
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        assert!(ewc.is_active());
        assert_eq!(ewc.anchor_weights(), &weights);
    }

    #[test]
    #[should_panic(expected = "Must compute Fisher information before consolidating")]
    fn test_consolidate_without_fisher() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let weights = vec![1.0, 2.0];
        ewc.consolidate(&weights);
    }

    #[test]
    #[should_panic(expected = "Weight count must match Fisher information size")]
    fn test_consolidate_size_mismatch() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad = vec![1.0, 2.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![1.0, 2.0, 3.0]; // Wrong size
        ewc.consolidate(&weights);
    }

    #[test]
    fn test_penalty_inactive() {
        let ewc = ElasticWeightConsolidation::new(100.0);
        let weights = vec![1.0, 2.0, 3.0];
        assert_eq!(ewc.penalty(&weights), 0.0);
    }

    #[test]
    fn test_penalty_no_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Setup
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        // Penalty should be 0 when weights match anchor
        assert_eq!(ewc.penalty(&weights), 0.0);
    }

    #[test]
    fn test_penalty_with_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Fisher diagonal: [1.0, 4.0, 9.0]
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        // Anchor weights: [0.0, 0.0, 0.0]
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        // Current weights: [1.0, 1.0, 1.0]
        let weights = vec![1.0, 1.0, 1.0];
        // Penalty = λ/2 * Σ F_i * (w_i - w*_i)²
        //         = 100/2 * (1.0 * 1² + 4.0 * 1² + 9.0 * 1²)
        //         = 50 * 14 = 700
        let penalty = ewc.penalty(&weights);
        assert!((penalty - 700.0).abs() < 1e-4);
    }

    #[test]
    fn test_penalty_increases_with_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad = vec![1.0, 1.0, 1.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        // Small deviation
        let weights1 = vec![0.1, 0.1, 0.1];
        let penalty1 = ewc.penalty(&weights1);
        // Larger deviation
        let weights2 = vec![0.5, 0.5, 0.5];
        let penalty2 = ewc.penalty(&weights2);
        // Penalty should increase
        assert!(penalty2 > penalty1);
        // Penalty should scale quadratically: (0.5/0.1)² = 25
        assert!((penalty2 / penalty1 - 25.0).abs() < 1e-4);
    }

    #[test]
    fn test_gradient_inactive() {
        let ewc = ElasticWeightConsolidation::new(100.0);
        let weights = vec![1.0, 2.0, 3.0];
        let grad = ewc.gradient(&weights);
        assert_eq!(grad, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_gradient_no_deviation() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        // Gradient should be 0 when weights match anchor
        let grad = ewc.gradient(&weights);
        assert_eq!(grad, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_gradient_points_toward_anchor() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Fisher diagonal: [1.0, 4.0, 9.0]
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        // Anchor at origin
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        // Weights moved positive
        let weights = vec![1.0, 1.0, 1.0];
        // Gradient = λ * F_i * (w_i - w*_i)
        //          = 100 * [1.0, 4.0, 9.0] * [1.0, 1.0, 1.0]
        //          = [100, 400, 900]
        let grad = ewc.gradient(&weights);
        assert_eq!(grad.len(), 3);
        assert!((grad[0] - 100.0).abs() < 1e-4);
        assert!((grad[1] - 400.0).abs() < 1e-4);
        assert!((grad[2] - 900.0).abs() < 1e-4);
        // Weights moved negative
        let weights = vec![-1.0, -1.0, -1.0];
        let grad = ewc.gradient(&weights);
        // Gradient should point opposite direction (toward anchor)
        assert!(grad[0] < 0.0);
        assert!(grad[1] < 0.0);
        assert!(grad[2] < 0.0);
        assert!((grad[0] + 100.0).abs() < 1e-4);
        assert!((grad[1] + 400.0).abs() < 1e-4);
        assert!((grad[2] + 900.0).abs() < 1e-4);
    }

    #[test]
    fn test_gradient_magnitude_scales_with_fisher() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Fisher with varying importance
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let anchor = vec![0.0, 0.0, 0.0];
        ewc.consolidate(&anchor);
        let weights = vec![1.0, 1.0, 1.0];
        let grad = ewc.gradient(&weights);
        // Gradient magnitude should increase with Fisher importance
        assert!(grad[0].abs() < grad[1].abs());
        assert!(grad[1].abs() < grad[2].abs());
    }

    #[test]
    fn test_lambda_scaling() {
        let mut ewc1 = ElasticWeightConsolidation::new(100.0);
        let mut ewc2 = ElasticWeightConsolidation::new(200.0);
        // Same setup for both
        let grad = vec![1.0, 1.0, 1.0];
        ewc1.compute_fisher(&[grad.as_slice()], 1);
        ewc2.compute_fisher(&[grad.as_slice()], 1);
        let anchor = vec![0.0, 0.0, 0.0];
        ewc1.consolidate(&anchor);
        ewc2.consolidate(&anchor);
        let weights = vec![1.0, 1.0, 1.0];
        // Penalty and gradient should scale with lambda
        let penalty1 = ewc1.penalty(&weights);
        let penalty2 = ewc2.penalty(&weights);
        assert!((penalty2 / penalty1 - 2.0).abs() < 1e-4);
        let grad1 = ewc1.gradient(&weights);
        let grad2 = ewc2.gradient(&weights);
        assert!((grad2[0] / grad1[0] - 2.0).abs() < 1e-4);
    }

    #[test]
    fn test_set_lambda() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        assert_eq!(ewc.lambda(), 100.0);
        ewc.set_lambda(500.0);
        assert_eq!(ewc.lambda(), 500.0);
    }

    #[test]
    #[should_panic(expected = "Lambda must be non-negative")]
    fn test_set_lambda_negative() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        ewc.set_lambda(-10.0);
    }

    #[test]
    fn test_reset() {
        let mut ewc = ElasticWeightConsolidation::new(100.0);
        // Setup active EWC
        let grad = vec![1.0, 2.0, 3.0];
        ewc.compute_fisher(&[grad.as_slice()], 1);
        let weights = vec![0.5, 1.0, 1.5];
        ewc.consolidate(&weights);
        assert!(ewc.is_active());
        // Reset
        ewc.reset();
        assert!(!ewc.is_active());
        assert!(ewc.fisher_diag().is_empty());
        assert!(ewc.anchor_weights().is_empty());
        assert_eq!(ewc.lambda(), 100.0); // Lambda preserved
    }

    #[test]
    fn test_sequential_task_learning() {
        // Simulate learning two tasks sequentially
        let mut ewc = ElasticWeightConsolidation::new(1000.0);
        // Task 1: Learn weights [1.0, 2.0, 3.0]
        let task1_grad = vec![2.0, 1.0, 3.0];
        ewc.compute_fisher(&[task1_grad.as_slice()], 1);
        let task1_weights = vec![1.0, 2.0, 3.0];
        ewc.consolidate(&task1_weights);
        // Task 2: Try to learn very different weights
        let task2_weights = vec![5.0, 6.0, 7.0];
        // EWC penalty should be significant
        let penalty = ewc.penalty(&task2_weights);
        assert!(penalty > 10000.0); // Large penalty for large deviation
        // Gradient should point back toward task 1 weights
        let grad = ewc.gradient(&task2_weights);
        assert!(grad[0] > 0.0); // Push toward lower value
        assert!(grad[1] > 0.0);
        assert!(grad[2] > 0.0);
    }
}
}