Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
582
vendor/ruvector/crates/ruvector-gnn/src/ewc.rs
vendored
Normal file
582
vendor/ruvector/crates/ruvector-gnn/src/ewc.rs
vendored
Normal file
@@ -0,0 +1,582 @@
|
||||
/// Elastic Weight Consolidation (EWC) for preventing catastrophic forgetting in GNNs
|
||||
///
|
||||
/// EWC adds a regularization term that penalizes changes to important weights,
|
||||
/// where importance is measured by the Fisher information matrix diagonal.
|
||||
///
|
||||
/// The EWC loss term is: L_EWC = λ/2 * Σ F_i * (θ_i - θ*_i)²
|
||||
/// where:
|
||||
/// - λ is the regularization strength
|
||||
/// - F_i is the Fisher information for weight i
|
||||
/// - θ_i is the current weight
|
||||
/// - θ*_i is the anchor weight from the previous task
|
||||
use std::f32;
|
||||
|
||||
/// Elastic Weight Consolidation implementation
///
/// Prevents catastrophic forgetting by penalizing changes to important weights
/// learned from previous tasks.
#[derive(Debug, Clone)]
pub struct ElasticWeightConsolidation {
    /// Fisher information diagonal (importance of each weight).
    /// Higher values indicate more important weights.
    fisher_diag: Vec<f32>,

    /// Anchor weights (optimal weights from the previous task).
    /// These are the weights we want to stay close to.
    anchor_weights: Vec<f32>,

    /// Regularization strength (λ).
    /// Controls how strongly we penalize deviations from anchor weights.
    lambda: f32,

    /// Whether EWC is active.
    /// EWC is only active after consolidation has been called.
    active: bool,
}

impl ElasticWeightConsolidation {
    /// Create a new EWC instance with specified regularization strength.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength (typically 10-10000)
    ///
    /// # Returns
    /// A new inactive EWC instance
    ///
    /// # Panics
    /// Panics if `lambda` is negative.
    pub fn new(lambda: f32) -> Self {
        assert!(lambda >= 0.0, "Lambda must be non-negative");

        Self {
            fisher_diag: Vec::new(),
            anchor_weights: Vec::new(),
            lambda,
            active: false,
        }
    }

    /// Compute Fisher information diagonal from gradients.
    ///
    /// The Fisher information measures the importance of each weight.
    /// It's approximated as the mean squared gradient over samples:
    /// F_i ≈ (1/N) * Σ (∂L/∂θ_i)²
    ///
    /// Any previously computed Fisher values are discarded: the diagonal is
    /// recomputed fresh from the given gradients on every call.
    ///
    /// # Arguments
    /// * `gradients` - Slice of gradient vectors, one per sample
    /// * `sample_count` - Number of samples N used for normalization
    ///   (clamped to at least 1 to avoid division by zero)
    ///
    /// # Panics
    /// Panics if the gradient vectors do not all have the same length.
    pub fn compute_fisher(&mut self, gradients: &[&[f32]], sample_count: usize) {
        if gradients.is_empty() {
            return;
        }

        let num_weights = gradients[0].len();

        // Always reset the Fisher diagonal to zero before computing:
        // Fisher information is computed fresh from the current gradients.
        self.fisher_diag = vec![0.0; num_weights];

        // Accumulate squared gradients. The zip avoids per-element bounds
        // checks that an indexed loop would incur.
        for grad in gradients {
            assert_eq!(
                grad.len(),
                num_weights,
                "All gradient vectors must have the same length"
            );

            for (f, &g) in self.fisher_diag.iter_mut().zip(grad.iter()) {
                *f += g * g;
            }
        }

        // Normalize by the sample count (at least 1 so 0 samples is safe).
        let normalization = 1.0 / (sample_count as f32).max(1.0);
        for f in &mut self.fisher_diag {
            *f *= normalization;
        }
    }

    /// Save current weights as anchor and activate EWC.
    ///
    /// This should be called after training on a task, before moving to the
    /// next task. It marks the current weights as important and activates the
    /// EWC penalty.
    ///
    /// # Arguments
    /// * `weights` - Current model weights to save as anchor
    ///
    /// # Panics
    /// Panics if Fisher information has not been computed yet, or if
    /// `weights` has a different length than the Fisher diagonal.
    pub fn consolidate(&mut self, weights: &[f32]) {
        assert!(
            !self.fisher_diag.is_empty(),
            "Must compute Fisher information before consolidating"
        );
        assert_eq!(
            weights.len(),
            self.fisher_diag.len(),
            "Weight count must match Fisher information size"
        );

        self.anchor_weights = weights.to_vec();
        self.active = true;
    }

    /// Compute EWC penalty term.
    ///
    /// Returns: λ/2 * Σ F_i * (θ_i - θ*_i)²
    ///
    /// This penalty is added to the loss function to discourage changes
    /// to important weights.
    ///
    /// # Arguments
    /// * `weights` - Current model weights
    ///
    /// # Returns
    /// The EWC penalty value (0.0 if not active)
    ///
    /// # Panics
    /// Panics if `weights` has a different length than the anchor weights.
    pub fn penalty(&self, weights: &[f32]) -> f32 {
        if !self.active {
            return 0.0;
        }

        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );

        // Σ F_i * (θ_i - θ*_i)², computed with a bounds-check-free zip.
        // (fisher_diag has the same length as anchor_weights by the
        // invariant established in `consolidate`.)
        let weighted_sq_dev: f32 = weights
            .iter()
            .zip(&self.anchor_weights)
            .zip(&self.fisher_diag)
            .map(|((&w, &anchor), &fisher)| {
                let diff = w - anchor;
                fisher * diff * diff
            })
            .sum();

        // Multiply by λ/2
        weighted_sq_dev * self.lambda * 0.5
    }

    /// Compute EWC gradient.
    ///
    /// Returns: λ * F_i * (θ_i - θ*_i) for each weight i
    ///
    /// This gradient is added to the model gradients during training
    /// to push weights back toward their anchor values.
    ///
    /// # Arguments
    /// * `weights` - Current model weights
    ///
    /// # Returns
    /// Gradient vector (all zeros if not active)
    ///
    /// # Panics
    /// Panics if `weights` has a different length than the anchor weights.
    pub fn gradient(&self, weights: &[f32]) -> Vec<f32> {
        if !self.active {
            return vec![0.0; weights.len()];
        }

        assert_eq!(
            weights.len(),
            self.anchor_weights.len(),
            "Weight count must match anchor weights"
        );

        weights
            .iter()
            .zip(&self.anchor_weights)
            .zip(&self.fisher_diag)
            .map(|((&w, &anchor), &fisher)| self.lambda * fisher * (w - anchor))
            .collect()
    }

    /// Check if EWC is active.
    ///
    /// # Returns
    /// true if consolidate() has been called, false otherwise
    pub fn is_active(&self) -> bool {
        self.active
    }

    /// Get the regularization strength.
    pub fn lambda(&self) -> f32 {
        self.lambda
    }

    /// Update the regularization strength.
    ///
    /// # Panics
    /// Panics if `lambda` is negative.
    pub fn set_lambda(&mut self, lambda: f32) {
        assert!(lambda >= 0.0, "Lambda must be non-negative");
        self.lambda = lambda;
    }

    /// Get the Fisher information diagonal.
    pub fn fisher_diag(&self) -> &[f32] {
        &self.fisher_diag
    }

    /// Get the anchor weights.
    pub fn anchor_weights(&self) -> &[f32] {
        &self.anchor_weights
    }

    /// Reset EWC to inactive state.
    ///
    /// Clears Fisher information and anchor weights; `lambda` is preserved.
    pub fn reset(&mut self) {
        self.fisher_diag.clear();
        self.anchor_weights.clear();
        self.active = false;
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Approximate float comparison with the tolerance used throughout.
    fn close(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_new() {
        let cons = ElasticWeightConsolidation::new(1000.0);
        assert_eq!(cons.lambda(), 1000.0);
        assert!(!cons.is_active());
        assert!(cons.fisher_diag().is_empty());
        assert!(cons.anchor_weights().is_empty());
    }

    #[test]
    #[should_panic(expected = "Lambda must be non-negative")]
    fn test_new_negative_lambda() {
        ElasticWeightConsolidation::new(-1.0);
    }

    #[test]
    fn test_compute_fisher_single_sample() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        // One gradient sample: Fisher is simply the squared gradient.
        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);

        assert_eq!(cons.fisher_diag(), &[1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_compute_fisher_multiple_samples() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        let a = [1.0f32, 2.0, 3.0];
        let b = [2.0f32, 1.0, 1.0];
        cons.compute_fisher(&[a.as_slice(), b.as_slice()], 2);

        // Mean squared gradient per position:
        //   (1² + 2²)/2 = 2.5, (2² + 1²)/2 = 2.5, (3² + 1²)/2 = 5.0
        let expected = [2.5f32, 2.5, 5.0];
        assert_eq!(cons.fisher_diag().len(), expected.len());
        assert!(cons
            .fisher_diag()
            .iter()
            .zip(expected.iter())
            .all(|(got, want)| close(*got, *want, 1e-6)));
    }

    #[test]
    fn test_compute_fisher_accumulates() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        // First call establishes the diagonal.
        let first = [1.0f32, 2.0];
        cons.compute_fisher(&[first.as_slice()], 1);
        assert_eq!(cons.fisher_diag(), &[1.0, 4.0]);

        // A second call does NOT accumulate on top of the first: the
        // diagonal is zeroed and recomputed from the new gradients only
        // (2² = 4, 1² = 1, normalized by 1).
        let second = [2.0f32, 1.0];
        cons.compute_fisher(&[second.as_slice()], 1);
        assert_eq!(cons.fisher_diag(), &[4.0, 1.0]);
    }

    #[test]
    #[should_panic(expected = "All gradient vectors must have the same length")]
    fn test_compute_fisher_mismatched_sizes() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        let short = [1.0f32, 2.0];
        let long = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[short.as_slice(), long.as_slice()], 2);
    }

    #[test]
    fn test_consolidate() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        // Fisher information must exist before consolidation.
        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);

        let w = vec![0.5f32, 1.0, 1.5];
        cons.consolidate(&w);

        assert!(cons.is_active());
        assert_eq!(cons.anchor_weights(), &w);
    }

    #[test]
    #[should_panic(expected = "Must compute Fisher information before consolidating")]
    fn test_consolidate_without_fisher() {
        let mut cons = ElasticWeightConsolidation::new(100.0);
        cons.consolidate(&[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "Weight count must match Fisher information size")]
    fn test_consolidate_size_mismatch() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        let g = [1.0f32, 2.0];
        cons.compute_fisher(&[g.as_slice()], 1);

        // Three weights against a two-element Fisher diagonal.
        cons.consolidate(&[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_penalty_inactive() {
        let cons = ElasticWeightConsolidation::new(100.0);
        assert_eq!(cons.penalty(&[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn test_penalty_no_deviation() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);

        let w = [0.5f32, 1.0, 1.5];
        cons.consolidate(&w);

        // Weights identical to the anchor incur no penalty.
        assert_eq!(cons.penalty(&w), 0.0);
    }

    #[test]
    fn test_penalty_with_deviation() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        // Fisher diagonal becomes [1.0, 4.0, 9.0].
        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);

        // Anchor at the origin, current weights all 1.0.
        cons.consolidate(&[0.0, 0.0, 0.0]);

        // λ/2 * Σ F_i * diff² = 50 * (1 + 4 + 9) = 700
        let p = cons.penalty(&[1.0, 1.0, 1.0]);
        assert!(close(p, 700.0, 1e-4));
    }

    #[test]
    fn test_penalty_increases_with_deviation() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        let g = [1.0f32, 1.0, 1.0];
        cons.compute_fisher(&[g.as_slice()], 1);
        cons.consolidate(&[0.0, 0.0, 0.0]);

        let small = cons.penalty(&[0.1, 0.1, 0.1]);
        let large = cons.penalty(&[0.5, 0.5, 0.5]);

        // Larger deviation → larger penalty, scaling quadratically:
        // (0.5/0.1)² = 25.
        assert!(large > small);
        assert!(close(large / small, 25.0, 1e-4));
    }

    #[test]
    fn test_gradient_inactive() {
        let cons = ElasticWeightConsolidation::new(100.0);
        assert_eq!(cons.gradient(&[1.0, 2.0, 3.0]), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_gradient_no_deviation() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);

        let w = [0.5f32, 1.0, 1.5];
        cons.consolidate(&w);

        // At the anchor the EWC gradient vanishes.
        assert_eq!(cons.gradient(&w), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_gradient_points_toward_anchor() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        // Fisher diagonal becomes [1.0, 4.0, 9.0]; anchor at origin.
        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);
        cons.consolidate(&[0.0, 0.0, 0.0]);

        // Positive deviation: gradient = λ * F_i * diff = [100, 400, 900].
        let pos = cons.gradient(&[1.0, 1.0, 1.0]);
        assert_eq!(pos.len(), 3);
        assert!(close(pos[0], 100.0, 1e-4));
        assert!(close(pos[1], 400.0, 1e-4));
        assert!(close(pos[2], 900.0, 1e-4));

        // Negative deviation: gradient flips sign (still toward anchor).
        let neg = cons.gradient(&[-1.0, -1.0, -1.0]);
        assert!(neg.iter().all(|&v| v < 0.0));
        assert!(close(neg[0], -100.0, 1e-4));
        assert!(close(neg[1], -400.0, 1e-4));
        assert!(close(neg[2], -900.0, 1e-4));
    }

    #[test]
    fn test_gradient_magnitude_scales_with_fisher() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        // Varying importance across positions.
        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);
        cons.consolidate(&[0.0, 0.0, 0.0]);

        let out = cons.gradient(&[1.0, 1.0, 1.0]);

        // Higher Fisher importance → larger gradient magnitude.
        assert!(out[0].abs() < out[1].abs());
        assert!(out[1].abs() < out[2].abs());
    }

    #[test]
    fn test_lambda_scaling() {
        let mut weak = ElasticWeightConsolidation::new(100.0);
        let mut strong = ElasticWeightConsolidation::new(200.0);

        // Identical setup for both instances.
        let g = [1.0f32, 1.0, 1.0];
        weak.compute_fisher(&[g.as_slice()], 1);
        strong.compute_fisher(&[g.as_slice()], 1);
        weak.consolidate(&[0.0, 0.0, 0.0]);
        strong.consolidate(&[0.0, 0.0, 0.0]);

        let w = [1.0f32, 1.0, 1.0];

        // Doubling λ doubles both penalty and gradient.
        assert!(close(strong.penalty(&w) / weak.penalty(&w), 2.0, 1e-4));
        assert!(close(strong.gradient(&w)[0] / weak.gradient(&w)[0], 2.0, 1e-4));
    }

    #[test]
    fn test_set_lambda() {
        let mut cons = ElasticWeightConsolidation::new(100.0);
        assert_eq!(cons.lambda(), 100.0);

        cons.set_lambda(500.0);
        assert_eq!(cons.lambda(), 500.0);
    }

    #[test]
    #[should_panic(expected = "Lambda must be non-negative")]
    fn test_set_lambda_negative() {
        let mut cons = ElasticWeightConsolidation::new(100.0);
        cons.set_lambda(-10.0);
    }

    #[test]
    fn test_reset() {
        let mut cons = ElasticWeightConsolidation::new(100.0);

        // Make the instance active first.
        let g = [1.0f32, 2.0, 3.0];
        cons.compute_fisher(&[g.as_slice()], 1);
        cons.consolidate(&[0.5, 1.0, 1.5]);
        assert!(cons.is_active());

        cons.reset();

        // State cleared, but λ survives the reset.
        assert!(!cons.is_active());
        assert!(cons.fisher_diag().is_empty());
        assert!(cons.anchor_weights().is_empty());
        assert_eq!(cons.lambda(), 100.0);
    }

    #[test]
    fn test_sequential_task_learning() {
        // Simulate learning two tasks one after the other.
        let mut cons = ElasticWeightConsolidation::new(1000.0);

        // Task 1: learn weights [1, 2, 3] and consolidate.
        let task1_grad = [2.0f32, 1.0, 3.0];
        cons.compute_fisher(&[task1_grad.as_slice()], 1);
        cons.consolidate(&[1.0, 2.0, 3.0]);

        // Task 2: attempt very different weights.
        let task2 = [5.0f32, 6.0, 7.0];

        // The large deviation should incur a large penalty...
        assert!(cons.penalty(&task2) > 10000.0);

        // ...and the gradient should push every weight back toward task 1.
        let pull = cons.gradient(&task2);
        assert!(pull.iter().all(|&v| v > 0.0));
    }
}
|
||||
Reference in New Issue
Block a user