500 lines
15 KiB
Rust
500 lines
15 KiB
Rust
//! EWC++ (Enhanced Elastic Weight Consolidation) for SONA
|
|
//!
|
|
//! Prevents catastrophic forgetting with:
|
|
//! - Online Fisher information estimation
|
|
//! - Multi-task memory with circular buffer
|
|
//! - Automatic task boundary detection
|
|
//! - Adaptive lambda scheduling
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::VecDeque;
|
|
|
|
/// EWC++ configuration.
///
/// Sizes the per-parameter state of [`EwcPlusPlus`] and tunes its
/// regularization strength (`lambda`), its online Fisher estimate and its
/// task-boundary detector.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcConfig {
    /// Number of parameters. `EwcPlusPlus` treats gradient/weight slices of any
    /// other length as a no-op (`update_fisher` returns early,
    /// `apply_constraints` returns the gradients unchanged,
    /// `regularization_loss` returns `0.0`).
    pub param_count: usize,
    /// Maximum number of finished tasks kept in the task memory; when it is
    /// full, starting a new task evicts the oldest entry.
    pub max_tasks: usize,
    /// Starting value of `lambda`, and the base it is rescaled from as tasks
    /// accumulate (`initial_lambda * (1 + 0.1 * task_count)`, then clamped).
    pub initial_lambda: f32,
    /// Lower bound for `lambda`. Must not exceed `max_lambda`: the bounds are
    /// applied with `f32::clamp`, which panics if `min > max` or either is NaN.
    pub min_lambda: f32,
    /// Upper bound for `lambda`.
    pub max_lambda: f32,
    /// Decay of the exponential moving average used for the online Fisher
    /// diagonal: `F = decay * F + (1 - decay) * g^2`.
    pub fisher_ema_decay: f32,
    /// Mean absolute z-score (gradient vs. running per-parameter statistics)
    /// above which a gradient is reported as a task boundary.
    pub boundary_threshold: f32,
    /// Number of most recent gradient vectors retained in the gradient history.
    /// NOTE(review): the boundary detector uses running mean/variance rather
    /// than this history; confirm whether the history is consumed elsewhere.
    pub gradient_history_size: usize,
}
|
|
|
|
impl Default for EwcConfig {
|
|
fn default() -> Self {
|
|
// OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
|
|
// - Lambda 2000 optimal for catastrophic forgetting prevention
|
|
// - Higher max_lambda (15000) for aggressive protection when needed
|
|
Self {
|
|
param_count: 1000,
|
|
max_tasks: 10,
|
|
initial_lambda: 2000.0, // OPTIMIZED: Better forgetting prevention
|
|
min_lambda: 100.0,
|
|
max_lambda: 15000.0, // OPTIMIZED: Higher ceiling for multi-task
|
|
fisher_ema_decay: 0.999,
|
|
boundary_threshold: 2.0,
|
|
gradient_history_size: 100,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Fisher information and reference weights captured for one finished task
/// (or for several tasks merged by `EwcPlusPlus::consolidate_all_tasks`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskFisher {
    /// Task ID: the `EwcPlusPlus` task counter at the moment the task was
    /// saved, or `0` for a consolidated entry.
    pub task_id: usize,
    /// Diagonal of the Fisher information matrix, one entry per parameter.
    pub fisher: Vec<f32>,
    /// Reference ("optimal") weights this task is anchored to in the
    /// regularization loss; a snapshot of the reference weights at save time.
    pub optimal_weights: Vec<f32>,
    /// Multiplier applied to `fisher` when constraining gradients and when
    /// computing the loss (`1.0` for a freshly saved task, the summed
    /// importance for a consolidated entry).
    pub importance: f32,
}
|
|
|
|
/// EWC++ implementation: online Fisher estimation for the current task, a
/// bounded memory of finished tasks, gradient-statistics-based task-boundary
/// detection and a task-count-driven `lambda` schedule.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EwcPlusPlus {
    /// Configuration
    config: EwcConfig,
    /// Online (EMA) estimate of the Fisher diagonal for the current task.
    current_fisher: Vec<f32>,
    /// Reference weights for the current task, set via `set_optimal_weights`
    /// and snapshotted into the task memory when a new task starts.
    current_weights: Vec<f32>,
    /// Finished tasks, oldest first, bounded by `config.max_tasks`.
    task_memory: VecDeque<TaskFisher>,
    /// Current task ID (incremented by every `start_new_task`).
    current_task_id: usize,
    /// Current regularization strength.
    lambda: f32,
    /// Recent raw gradient vectors, bounded by `config.gradient_history_size`.
    /// NOTE(review): no visible method reads this; it is only appended to and
    /// cleared.
    gradient_history: VecDeque<Vec<f32>>,
    /// Running per-parameter gradient mean (Welford's algorithm).
    gradient_mean: Vec<f32>,
    /// Running per-parameter sum of squared deviations (Welford's M2), seeded
    /// with `1.0`; divided by `samples_seen` to obtain a variance estimate.
    gradient_var: Vec<f32>,
    /// Number of gradient samples accepted for the current task.
    samples_seen: u64,
}
|
|
|
|
impl EwcPlusPlus {
|
|
/// Create new EWC++
|
|
pub fn new(config: EwcConfig) -> Self {
|
|
let param_count = config.param_count;
|
|
let initial_lambda = config.initial_lambda;
|
|
|
|
Self {
|
|
config: config.clone(),
|
|
current_fisher: vec![0.0; param_count],
|
|
current_weights: vec![0.0; param_count],
|
|
task_memory: VecDeque::with_capacity(config.max_tasks),
|
|
current_task_id: 0,
|
|
lambda: initial_lambda,
|
|
gradient_history: VecDeque::with_capacity(config.gradient_history_size),
|
|
gradient_mean: vec![0.0; param_count],
|
|
gradient_var: vec![1.0; param_count],
|
|
samples_seen: 0,
|
|
}
|
|
}
|
|
|
|
/// Update Fisher information online using EMA
|
|
pub fn update_fisher(&mut self, gradients: &[f32]) {
|
|
if gradients.len() != self.config.param_count {
|
|
return;
|
|
}
|
|
|
|
let decay = self.config.fisher_ema_decay;
|
|
|
|
// Online Fisher update: F_t = decay * F_{t-1} + (1 - decay) * g^2
|
|
for (i, &g) in gradients.iter().enumerate() {
|
|
self.current_fisher[i] = decay * self.current_fisher[i] + (1.0 - decay) * g * g;
|
|
}
|
|
|
|
// Update gradient statistics for boundary detection
|
|
self.update_gradient_stats(gradients);
|
|
self.samples_seen += 1;
|
|
}
|
|
|
|
/// Update gradient statistics for boundary detection
|
|
fn update_gradient_stats(&mut self, gradients: &[f32]) {
|
|
// Store in history
|
|
if self.gradient_history.len() >= self.config.gradient_history_size {
|
|
self.gradient_history.pop_front();
|
|
}
|
|
self.gradient_history.push_back(gradients.to_vec());
|
|
|
|
// Update running mean and variance (Welford's algorithm)
|
|
let n = self.samples_seen as f32 + 1.0;
|
|
|
|
for (i, &g) in gradients.iter().enumerate() {
|
|
let delta = g - self.gradient_mean[i];
|
|
self.gradient_mean[i] += delta / n;
|
|
let delta2 = g - self.gradient_mean[i];
|
|
self.gradient_var[i] += delta * delta2;
|
|
}
|
|
}
|
|
|
|
/// Detect task boundary using distribution shift
|
|
pub fn detect_task_boundary(&self, gradients: &[f32]) -> bool {
|
|
if self.samples_seen < 50 || gradients.len() != self.config.param_count {
|
|
return false;
|
|
}
|
|
|
|
// Compute z-score of current gradients vs running stats
|
|
let mut z_score_sum = 0.0f32;
|
|
let mut count = 0;
|
|
|
|
for (i, &g) in gradients.iter().enumerate() {
|
|
let var = self.gradient_var[i] / self.samples_seen as f32;
|
|
if var > 1e-8 {
|
|
let std = var.sqrt();
|
|
let z = (g - self.gradient_mean[i]).abs() / std;
|
|
z_score_sum += z;
|
|
count += 1;
|
|
}
|
|
}
|
|
|
|
if count == 0 {
|
|
return false;
|
|
}
|
|
|
|
let avg_z = z_score_sum / count as f32;
|
|
avg_z > self.config.boundary_threshold
|
|
}
|
|
|
|
/// Start new task - saves current Fisher to memory
|
|
pub fn start_new_task(&mut self) {
|
|
// Save current task's Fisher
|
|
let task_fisher = TaskFisher {
|
|
task_id: self.current_task_id,
|
|
fisher: self.current_fisher.clone(),
|
|
optimal_weights: self.current_weights.clone(),
|
|
importance: 1.0,
|
|
};
|
|
|
|
// Add to circular buffer
|
|
if self.task_memory.len() >= self.config.max_tasks {
|
|
self.task_memory.pop_front();
|
|
}
|
|
self.task_memory.push_back(task_fisher);
|
|
|
|
// Reset for new task
|
|
self.current_task_id += 1;
|
|
self.current_fisher.fill(0.0);
|
|
self.gradient_history.clear();
|
|
self.gradient_mean.fill(0.0);
|
|
self.gradient_var.fill(1.0);
|
|
self.samples_seen = 0;
|
|
|
|
// Adapt lambda based on task count
|
|
self.adapt_lambda();
|
|
}
|
|
|
|
/// Adapt lambda based on accumulated tasks
|
|
fn adapt_lambda(&mut self) {
|
|
let task_count = self.task_memory.len();
|
|
if task_count == 0 {
|
|
return;
|
|
}
|
|
|
|
// Increase lambda as more tasks accumulate (more to protect)
|
|
let scale = 1.0 + 0.1 * task_count as f32;
|
|
self.lambda = (self.config.initial_lambda * scale)
|
|
.clamp(self.config.min_lambda, self.config.max_lambda);
|
|
}
|
|
|
|
/// Apply EWC++ constraints to gradients
|
|
pub fn apply_constraints(&self, gradients: &[f32]) -> Vec<f32> {
|
|
if gradients.len() != self.config.param_count {
|
|
return gradients.to_vec();
|
|
}
|
|
|
|
let mut constrained = gradients.to_vec();
|
|
|
|
// Apply constraint from each remembered task
|
|
for task in &self.task_memory {
|
|
for (i, g) in constrained.iter_mut().enumerate() {
|
|
// Penalty: lambda * F_i * (w_i - w*_i)
|
|
// Gradient of penalty: lambda * F_i
|
|
// Project gradient to preserve important weights
|
|
let importance = task.fisher[i] * task.importance;
|
|
if importance > 1e-8 {
|
|
let penalty_grad = self.lambda * importance;
|
|
// Reduce gradient magnitude for important parameters
|
|
*g *= 1.0 / (1.0 + penalty_grad);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Also apply current task's Fisher (online)
|
|
for (i, g) in constrained.iter_mut().enumerate() {
|
|
if self.current_fisher[i] > 1e-8 {
|
|
let penalty_grad = self.lambda * self.current_fisher[i] * 0.1; // Lower weight for current
|
|
*g *= 1.0 / (1.0 + penalty_grad);
|
|
}
|
|
}
|
|
|
|
constrained
|
|
}
|
|
|
|
/// Compute EWC regularization loss
|
|
pub fn regularization_loss(&self, current_weights: &[f32]) -> f32 {
|
|
if current_weights.len() != self.config.param_count {
|
|
return 0.0;
|
|
}
|
|
|
|
let mut loss = 0.0f32;
|
|
|
|
for task in &self.task_memory {
|
|
for ((&cw, &ow), &fi) in current_weights
|
|
.iter()
|
|
.zip(task.optimal_weights.iter())
|
|
.zip(task.fisher.iter())
|
|
.take(self.config.param_count)
|
|
{
|
|
let diff = cw - ow;
|
|
loss += fi * diff * diff * task.importance;
|
|
}
|
|
}
|
|
|
|
self.lambda * loss / 2.0
|
|
}
|
|
|
|
/// Update optimal weights reference
|
|
pub fn set_optimal_weights(&mut self, weights: &[f32]) {
|
|
if weights.len() == self.config.param_count {
|
|
self.current_weights.copy_from_slice(weights);
|
|
}
|
|
}
|
|
|
|
/// Consolidate all tasks (merge Fisher information)
|
|
pub fn consolidate_all_tasks(&mut self) {
|
|
if self.task_memory.is_empty() {
|
|
return;
|
|
}
|
|
|
|
// Compute weighted average of Fisher matrices
|
|
let mut consolidated_fisher = vec![0.0f32; self.config.param_count];
|
|
let mut total_importance = 0.0f32;
|
|
|
|
for task in &self.task_memory {
|
|
for (i, &f) in task.fisher.iter().enumerate() {
|
|
consolidated_fisher[i] += f * task.importance;
|
|
}
|
|
total_importance += task.importance;
|
|
}
|
|
|
|
if total_importance > 0.0 {
|
|
for f in &mut consolidated_fisher {
|
|
*f /= total_importance;
|
|
}
|
|
}
|
|
|
|
// Store as single consolidated task
|
|
let consolidated = TaskFisher {
|
|
task_id: 0,
|
|
fisher: consolidated_fisher,
|
|
optimal_weights: self.current_weights.clone(),
|
|
importance: total_importance,
|
|
};
|
|
|
|
self.task_memory.clear();
|
|
self.task_memory.push_back(consolidated);
|
|
}
|
|
|
|
/// Get current lambda
|
|
pub fn lambda(&self) -> f32 {
|
|
self.lambda
|
|
}
|
|
|
|
/// Set lambda manually
|
|
pub fn set_lambda(&mut self, lambda: f32) {
|
|
self.lambda = lambda.clamp(self.config.min_lambda, self.config.max_lambda);
|
|
}
|
|
|
|
/// Get task count
|
|
pub fn task_count(&self) -> usize {
|
|
self.task_memory.len()
|
|
}
|
|
|
|
/// Get current task ID
|
|
pub fn current_task_id(&self) -> usize {
|
|
self.current_task_id
|
|
}
|
|
|
|
/// Get samples seen for current task
|
|
pub fn samples_seen(&self) -> u64 {
|
|
self.samples_seen
|
|
}
|
|
|
|
/// Get parameter importance scores
|
|
pub fn importance_scores(&self) -> Vec<f32> {
|
|
let mut scores = self.current_fisher.clone();
|
|
|
|
for task in &self.task_memory {
|
|
for (i, &f) in task.fisher.iter().enumerate() {
|
|
scores[i] += f * task.importance;
|
|
}
|
|
}
|
|
|
|
scores
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ewc_creation() {
        let config = EwcConfig {
            param_count: 100,
            ..Default::default()
        };
        let ewc = EwcPlusPlus::new(config);

        assert_eq!(ewc.task_count(), 0);
        assert_eq!(ewc.current_task_id(), 0);
        assert_eq!(ewc.samples_seen(), 0);
    }

    #[test]
    fn test_fisher_update() {
        let config = EwcConfig {
            param_count: 10,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        ewc.update_fisher(&[0.5; 10]);

        assert_eq!(ewc.samples_seen(), 1);
        assert!(ewc.current_fisher.iter().all(|&f| f > 0.0));

        // Wrong-length gradients are ignored entirely.
        ewc.update_fisher(&[0.5; 3]);
        assert_eq!(ewc.samples_seen(), 1);
    }

    #[test]
    fn test_task_boundary() {
        let config = EwcConfig {
            param_count: 10,
            gradient_history_size: 10,
            boundary_threshold: 2.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Train on consistent gradients.
        for _ in 0..60 {
            ewc.update_fisher(&[0.1; 10]);
        }

        // A gradient matching the running mean has z = 0: no boundary.
        assert!(!ewc.detect_task_boundary(&[0.1; 10]));

        // With constant training gradients the only spread comes from the 1.0
        // seed of the squared-deviation accumulator (std = sqrt(1/60) ~ 0.13),
        // so 10.0 sits ~77 standard deviations from the mean.
        assert!(ewc.detect_task_boundary(&[10.0; 10]));

        // Wrong length never reports a boundary.
        assert!(!ewc.detect_task_boundary(&[10.0; 3]));
    }

    #[test]
    fn test_constraint_application() {
        let config = EwcConfig {
            param_count: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Build up some Fisher information.
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        let gradients = [1.0f32; 5];
        let constrained = ewc.apply_constraints(&gradients);

        // Every component is damped but keeps its sign.
        assert!(constrained.iter().all(|&g| g > 0.0 && g < 1.0));

        let orig_mag: f32 = gradients.iter().map(|x| x.abs()).sum();
        let const_mag: f32 = constrained.iter().map(|x| x.abs()).sum();
        assert!(const_mag < orig_mag);

        // Wrong-length gradients pass through unchanged.
        assert_eq!(ewc.apply_constraints(&[1.0; 3]), vec![1.0; 3]);
    }

    #[test]
    fn test_regularization_loss() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 100.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Set up optimal weights and Fisher.
        ewc.set_optimal_weights(&[0.0; 5]);
        for _ in 0..10 {
            ewc.update_fisher(&[1.0; 5]);
        }
        ewc.start_new_task();

        // Loss is exactly zero at the anchor...
        let at_optimal = ewc.regularization_loss(&[0.0; 5]);
        assert_eq!(at_optimal, 0.0);

        // ...and positive when deviated.
        let deviated = ewc.regularization_loss(&[1.0; 5]);
        assert!(deviated > at_optimal);
    }

    #[test]
    fn test_task_consolidation() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 5,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        // Create multiple tasks.
        for _ in 0..3 {
            for _ in 0..10 {
                ewc.update_fisher(&[1.0; 5]);
            }
            ewc.start_new_task();
        }

        assert_eq!(ewc.task_count(), 3);

        ewc.consolidate_all_tasks();
        assert_eq!(ewc.task_count(), 1);
    }

    #[test]
    fn test_task_memory_is_bounded() {
        let config = EwcConfig {
            param_count: 5,
            max_tasks: 2,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        for _ in 0..3 {
            ewc.start_new_task();
        }

        assert_eq!(ewc.task_count(), 2);
        assert_eq!(ewc.current_task_id(), 3);
    }

    #[test]
    fn test_lambda_adaptation() {
        let config = EwcConfig {
            param_count: 5,
            initial_lambda: 1000.0,
            ..Default::default()
        };
        let mut ewc = EwcPlusPlus::new(config);

        let initial_lambda = ewc.lambda();

        for _ in 0..5 {
            ewc.start_new_task();
        }

        // 1000 * (1 + 0.1 * 5) = 1500, well inside the default bounds.
        assert!(ewc.lambda() > initial_lambda);

        // Manual values are clamped to [min_lambda, max_lambda].
        ewc.set_lambda(1.0e9);
        assert_eq!(ewc.lambda(), 15000.0);
        ewc.set_lambda(0.0);
        assert_eq!(ewc.lambda(), 100.0);
    }
}
|