Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
236
crates/ruvector-dag/tests/integration/sona_tests.rs
Normal file
236
crates/ruvector-dag/tests/integration/sona_tests.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
//! SONA learning integration tests
|
||||
|
||||
use ruvector_dag::dag::{OperatorNode, OperatorType, QueryDag};
|
||||
use ruvector_dag::sona::*;
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_adaptation() {
|
||||
let mut lora = MicroLoRA::new(MicroLoRAConfig::default(), 256);
|
||||
|
||||
let input = ndarray::Array1::from_vec(vec![0.1; 256]);
|
||||
let output1 = lora.forward(&input);
|
||||
|
||||
// Adapt
|
||||
let gradient = ndarray::Array1::from_vec(vec![0.01; 256]);
|
||||
lora.adapt(&gradient, 0.1);
|
||||
|
||||
let output2 = lora.forward(&input);
|
||||
|
||||
// Output should change after adaptation
|
||||
let diff: f32 = output1
|
||||
.iter()
|
||||
.zip(output2.iter())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.sum();
|
||||
|
||||
assert!(diff > 0.0, "Output should change after adaptation");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trajectory_buffer() {
|
||||
let buffer = DagTrajectoryBuffer::new(10);
|
||||
|
||||
// Push trajectories
|
||||
for i in 0..15 {
|
||||
buffer.push(DagTrajectory::new(
|
||||
i as u64,
|
||||
vec![0.1; 256],
|
||||
"topological".to_string(),
|
||||
100.0,
|
||||
150.0,
|
||||
));
|
||||
}
|
||||
|
||||
// Buffer should not exceed capacity
|
||||
assert!(buffer.len() <= 10);
|
||||
|
||||
// Drain should return all
|
||||
let drained = buffer.drain();
|
||||
assert!(!drained.is_empty());
|
||||
assert!(buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reasoning_bank_clustering() {
|
||||
let mut bank = DagReasoningBank::new(ReasoningBankConfig {
|
||||
num_clusters: 5,
|
||||
pattern_dim: 256,
|
||||
max_patterns: 100,
|
||||
similarity_threshold: 0.5,
|
||||
});
|
||||
|
||||
// Store patterns
|
||||
for i in 0..50 {
|
||||
let pattern: Vec<f32> = (0..256)
|
||||
.map(|j| ((i * 256 + j) as f32 / 1000.0).sin())
|
||||
.collect();
|
||||
bank.store_pattern(pattern, 0.8);
|
||||
}
|
||||
|
||||
assert_eq!(bank.pattern_count(), 50);
|
||||
|
||||
// Cluster
|
||||
bank.recompute_clusters();
|
||||
|
||||
// Query similar
|
||||
let query: Vec<f32> = (0..256).map(|j| (j as f32 / 1000.0).sin()).collect();
|
||||
let results = bank.query_similar(&query, 5);
|
||||
|
||||
assert!(results.len() <= 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ewc_prevents_forgetting() {
|
||||
let mut ewc = EwcPlusPlus::new(EwcConfig::default());
|
||||
|
||||
// Initial parameters
|
||||
let params1 = ndarray::Array1::from_vec(vec![1.0; 256]);
|
||||
let fisher1 = ndarray::Array1::from_vec(vec![0.1; 256]);
|
||||
|
||||
ewc.consolidate(¶ms1, &fisher1);
|
||||
|
||||
// Penalty should be 0 for original params
|
||||
let penalty0 = ewc.penalty(¶ms1);
|
||||
assert!(penalty0 < 0.001);
|
||||
|
||||
// Penalty should increase for deviated params
|
||||
let params2 = ndarray::Array1::from_vec(vec![2.0; 256]);
|
||||
let penalty1 = ewc.penalty(¶ms2);
|
||||
|
||||
assert!(penalty1 > penalty0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trajectory_buffer_ordering() {
|
||||
let buffer = DagTrajectoryBuffer::new(100);
|
||||
|
||||
// Push trajectories with different timestamps
|
||||
for i in 0..10 {
|
||||
buffer.push(DagTrajectory::new(
|
||||
i as u64,
|
||||
vec![0.1; 256],
|
||||
"test".to_string(),
|
||||
100.0,
|
||||
150.0,
|
||||
));
|
||||
}
|
||||
|
||||
let trajectories = buffer.drain();
|
||||
|
||||
// Should maintain insertion order
|
||||
for (idx, traj) in trajectories.iter().enumerate() {
|
||||
assert_eq!(traj.query_hash, idx as u64);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_rank_adaptation() {
|
||||
let config = MicroLoRAConfig {
|
||||
rank: 8,
|
||||
alpha: 16.0,
|
||||
dropout: 0.1,
|
||||
};
|
||||
|
||||
let lora = MicroLoRA::new(config, 256);
|
||||
let input = ndarray::Array1::from_vec(vec![0.5; 256]);
|
||||
let output = lora.forward(&input);
|
||||
|
||||
assert_eq!(output.len(), 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reasoning_bank_similarity_threshold() {
|
||||
let config = ReasoningBankConfig {
|
||||
num_clusters: 3,
|
||||
pattern_dim: 64,
|
||||
max_patterns: 50,
|
||||
similarity_threshold: 0.9, // High threshold
|
||||
};
|
||||
|
||||
let mut bank = DagReasoningBank::new(config);
|
||||
|
||||
// Store identical patterns
|
||||
let pattern = vec![1.0; 64];
|
||||
for _ in 0..10 {
|
||||
bank.store_pattern(pattern.clone(), 0.8);
|
||||
}
|
||||
|
||||
// Query should return similar patterns
|
||||
let results = bank.query_similar(&pattern, 5);
|
||||
assert!(!results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ewc_consolidation_updates() {
|
||||
let mut ewc = EwcPlusPlus::new(EwcConfig {
|
||||
lambda: 1000.0,
|
||||
decay: 0.9,
|
||||
online: true,
|
||||
});
|
||||
|
||||
let params1 = ndarray::Array1::from_vec(vec![1.0; 256]);
|
||||
let fisher1 = ndarray::Array1::from_vec(vec![0.5; 256]);
|
||||
|
||||
ewc.consolidate(¶ms1, &fisher1);
|
||||
|
||||
// Second consolidation
|
||||
let params2 = ndarray::Array1::from_vec(vec![1.5; 256]);
|
||||
let fisher2 = ndarray::Array1::from_vec(vec![0.3; 256]);
|
||||
|
||||
ewc.consolidate(¶ms2, &fisher2);
|
||||
|
||||
// Penalty should consider both consolidations
|
||||
let params3 = ndarray::Array1::from_vec(vec![2.0; 256]);
|
||||
let penalty = ewc.penalty(¶ms3);
|
||||
|
||||
assert!(penalty > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trajectory_buffer_capacity() {
|
||||
let buffer = DagTrajectoryBuffer::new(5);
|
||||
|
||||
for i in 0..10 {
|
||||
buffer.push(DagTrajectory::new(
|
||||
i as u64,
|
||||
vec![0.1; 256],
|
||||
"test".to_string(),
|
||||
100.0,
|
||||
150.0,
|
||||
));
|
||||
}
|
||||
|
||||
// Should only keep last 5
|
||||
assert_eq!(buffer.len(), 5);
|
||||
|
||||
let trajectories = buffer.drain();
|
||||
assert_eq!(trajectories.len(), 5);
|
||||
|
||||
// Should have IDs 5-9 (most recent)
|
||||
let ids: Vec<u64> = trajectories.iter().map(|t| t.query_hash).collect();
|
||||
assert!(ids.contains(&5));
|
||||
assert!(ids.contains(&9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reasoning_bank_cluster_count() {
|
||||
let config = ReasoningBankConfig {
|
||||
num_clusters: 4,
|
||||
pattern_dim: 128,
|
||||
max_patterns: 100,
|
||||
similarity_threshold: 0.5,
|
||||
};
|
||||
|
||||
let mut bank = DagReasoningBank::new(config);
|
||||
|
||||
// Store diverse patterns
|
||||
for i in 0..20 {
|
||||
let pattern: Vec<f32> = (0..128).map(|j| ((i + j) as f32 / 10.0).sin()).collect();
|
||||
bank.store_pattern(pattern, 0.7);
|
||||
}
|
||||
|
||||
bank.recompute_clusters();
|
||||
|
||||
// Should have created clusters
|
||||
assert!(bank.cluster_count() <= 4);
|
||||
}
|
||||
Reference in New Issue
Block a user