// git-subtree-dir: vendor/ruvector
// git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
#![allow(
    clippy::all,
    unused_imports,
    unused_variables,
    dead_code,
    unused_mut,
    unused_assignments,
    non_camel_case_types,
    clippy::approx_constant,
    unexpected_cfgs,
    unused_must_use,
    unused_parens
)]
|
|
//! Integration tests for LoRA (Low-Rank Adaptation)
//!
//! Tests MicroLoRA adaptation, forward pass, gradient accumulation,
//! EWC state management, and serialization.

use ruvllm::{
    error::Result,
    lora::{AdaptFeedback, LoraAdapter, MicroLoRA, MicroLoraConfig, TargetModule},
};
use std::collections::HashMap;
|
|
|
|
/// Create a test MicroLoRA configuration
|
|
fn create_test_config(dim: usize) -> MicroLoraConfig {
|
|
MicroLoraConfig {
|
|
rank: 2,
|
|
alpha: 4.0,
|
|
dropout: 0.0,
|
|
target_modules: vec![TargetModule::QProj, TargetModule::VProj],
|
|
in_features: dim,
|
|
out_features: dim,
|
|
use_bias: false,
|
|
standard_init: true,
|
|
gradient_checkpointing: false,
|
|
}
|
|
}
|
|
|
|
/// Create test input data
|
|
fn create_test_input(dim: usize) -> Vec<f32> {
|
|
(0..dim).map(|i| (i as f32) * 0.01).collect()
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_creation() {
|
|
let config = create_test_config(256);
|
|
let lora = MicroLoRA::new(config);
|
|
|
|
assert_eq!(lora.config().rank, 2);
|
|
assert_eq!(lora.config().alpha, 4.0);
|
|
assert!(lora.is_enabled());
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_forward() {
|
|
let config = create_test_config(64);
|
|
let lora = MicroLoRA::new(config);
|
|
|
|
let input = create_test_input(64);
|
|
|
|
// Forward pass for Q projection
|
|
let output = lora.forward(&input, &TargetModule::QProj);
|
|
|
|
assert_eq!(output.len(), 64);
|
|
assert!(output.iter().all(|&v| v.is_finite()));
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_adapt_changes_output() {
|
|
let config = MicroLoraConfig {
|
|
rank: 2,
|
|
alpha: 4.0,
|
|
dropout: 0.0,
|
|
target_modules: vec![TargetModule::QProj],
|
|
in_features: 64,
|
|
out_features: 64,
|
|
use_bias: false,
|
|
standard_init: true,
|
|
gradient_checkpointing: false,
|
|
};
|
|
|
|
let lora = MicroLoRA::new(config);
|
|
let input = create_test_input(64);
|
|
|
|
// Forward pass before adaptation
|
|
let output_before = lora.forward(&input, &TargetModule::QProj);
|
|
|
|
// Apply adaptation with feedback
|
|
let feedback = AdaptFeedback::from_quality(0.8);
|
|
lora.adapt(&input, feedback).unwrap();
|
|
|
|
// Apply accumulated updates
|
|
lora.apply_updates(0.01);
|
|
|
|
// Forward pass after adaptation
|
|
let output_after = lora.forward(&input, &TargetModule::QProj);
|
|
|
|
assert_eq!(output_before.len(), output_after.len());
|
|
|
|
// Output should change after adaptation
|
|
let changed = output_before
|
|
.iter()
|
|
.zip(output_after.iter())
|
|
.any(|(a, b)| (a - b).abs() > 1e-10);
|
|
let all_near_zero = output_before.iter().all(|&v| v.abs() < 1e-6);
|
|
|
|
assert!(
|
|
changed || all_near_zero,
|
|
"Adaptation should change output or both should be zero"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_lora_forward_dimensions() {
|
|
let input_dim = 128;
|
|
let output_dim = 128;
|
|
|
|
let config = MicroLoraConfig {
|
|
rank: 2,
|
|
alpha: 4.0,
|
|
dropout: 0.0,
|
|
target_modules: vec![TargetModule::QProj],
|
|
in_features: input_dim,
|
|
out_features: output_dim,
|
|
use_bias: false,
|
|
standard_init: true,
|
|
gradient_checkpointing: false,
|
|
};
|
|
|
|
let lora = MicroLoRA::new(config);
|
|
let input = create_test_input(input_dim);
|
|
let output = lora.forward(&input, &TargetModule::QProj);
|
|
|
|
assert_eq!(output.len(), output_dim);
|
|
assert!(output.iter().all(|&v| v.is_finite()));
|
|
}
|
|
|
|
#[test]
|
|
fn test_lora_adapter_creation() {
|
|
let adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
|
|
|
assert_eq!(adapter.rank(), 2);
|
|
assert_eq!(adapter.param_count(), 64 * 2 + 2 * 64); // A matrix + B matrix
|
|
}
|
|
|
|
#[test]
|
|
fn test_lora_adapter_forward() {
|
|
let adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
|
let input = ndarray::Array1::from_vec(create_test_input(64));
|
|
|
|
let output = adapter.forward(&input);
|
|
|
|
assert_eq!(output.len(), 64);
|
|
assert!(output.iter().all(|&v| v.is_finite()));
|
|
|
|
// With zero-initialized B, output should be zero
|
|
let sum: f32 = output.iter().sum();
|
|
assert!(
|
|
sum.abs() < 1e-6,
|
|
"Initial forward should be ~0, got {}",
|
|
sum
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_lora_adapter_gradient_accumulation() {
|
|
let mut adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
|
let input = ndarray::Array1::from_elem(64, 0.1);
|
|
let grad_output = ndarray::Array1::from_elem(64, 0.1);
|
|
|
|
// Accumulate gradient
|
|
adapter.accumulate_gradient(&input, &grad_output, 0.8);
|
|
assert_eq!(adapter.pending_updates(), 1);
|
|
|
|
// Apply gradients
|
|
adapter.apply_gradients(0.01);
|
|
assert_eq!(adapter.pending_updates(), 0);
|
|
|
|
// After update, forward should produce non-zero output
|
|
let output = adapter.forward(&input);
|
|
let sum: f32 = output.iter().map(|x| x.abs()).sum();
|
|
assert!(sum > 0.0, "After update, output should be non-zero");
|
|
}
|
|
|
|
// Note: EwcState is not exported from the lora module, so EWC-specific
// tests are implemented in the unit tests within micro_lora.rs
|
|
|
|
#[test]
|
|
fn test_adapt_feedback_creation() {
|
|
let feedback = AdaptFeedback::from_quality(0.85);
|
|
|
|
assert_eq!(feedback.quality, 0.85);
|
|
assert_eq!(feedback.reward, Some(0.85));
|
|
assert!(feedback.gradient_estimate.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_adapt_feedback_with_gradient() {
|
|
let gradient = vec![0.1; 64];
|
|
let feedback = AdaptFeedback::with_gradient(0.9, gradient.clone());
|
|
|
|
assert_eq!(feedback.quality, 0.9);
|
|
assert_eq!(feedback.gradient_estimate.len(), 64);
|
|
}
|
|
|
|
#[test]
|
|
fn test_adapt_feedback_for_module() {
|
|
let feedback = AdaptFeedback::from_quality(0.8).for_module(TargetModule::QProj);
|
|
|
|
assert_eq!(feedback.source_module, Some(TargetModule::QProj));
|
|
}
|
|
|
|
#[test]
|
|
fn test_adapt_feedback_with_session() {
|
|
let feedback = AdaptFeedback::from_quality(0.8).with_session("session-123".to_string());
|
|
|
|
assert_eq!(feedback.session_id, Some("session-123".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn test_multiple_adaptations() {
|
|
let config = create_test_config(64);
|
|
let lora = MicroLoRA::new(config);
|
|
let input = create_test_input(64);
|
|
|
|
// Multiple adaptation cycles
|
|
for i in 0..5 {
|
|
let quality = 0.5 + (i as f32 * 0.1);
|
|
let feedback = AdaptFeedback::from_quality(quality);
|
|
lora.adapt(&input, feedback).unwrap();
|
|
}
|
|
|
|
assert_eq!(lora.adaptation_count(), 5);
|
|
|
|
// Apply updates
|
|
lora.apply_updates(0.01);
|
|
|
|
// Verify output is valid
|
|
let output = lora.forward(&input, &TargetModule::QProj);
|
|
assert_eq!(output.len(), 64);
|
|
assert!(output.iter().all(|&v| v.is_finite()));
|
|
}
|
|
|
|
#[test]
|
|
fn test_lora_with_different_ranks() {
|
|
let ranks = [1, 2];
|
|
let input = create_test_input(64);
|
|
|
|
for rank in ranks {
|
|
let config = MicroLoraConfig {
|
|
rank,
|
|
alpha: rank as f32 * 2.0,
|
|
dropout: 0.0,
|
|
target_modules: vec![TargetModule::QProj],
|
|
in_features: 64,
|
|
out_features: 64,
|
|
use_bias: false,
|
|
standard_init: true,
|
|
gradient_checkpointing: false,
|
|
};
|
|
|
|
let lora = MicroLoRA::new(config);
|
|
let output = lora.forward(&input, &TargetModule::QProj);
|
|
|
|
assert_eq!(
|
|
output.len(),
|
|
64,
|
|
"Rank {} should produce correct output size",
|
|
rank
|
|
);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_target_module_variants() {
|
|
let modules = vec![
|
|
TargetModule::QProj,
|
|
TargetModule::KProj,
|
|
TargetModule::VProj,
|
|
TargetModule::OProj,
|
|
TargetModule::GateProj,
|
|
TargetModule::UpProj,
|
|
TargetModule::DownProj,
|
|
TargetModule::Embed,
|
|
TargetModule::LmHead,
|
|
];
|
|
|
|
for module in &modules {
|
|
let name = module.as_str();
|
|
assert!(!name.is_empty());
|
|
}
|
|
|
|
assert_eq!(TargetModule::QProj.as_str(), "q_proj");
|
|
assert_eq!(TargetModule::VProj.as_str(), "v_proj");
|
|
}
|
|
|
|
#[test]
|
|
fn test_target_module_defaults() {
|
|
let defaults = TargetModule::defaults();
|
|
assert_eq!(defaults.len(), 2);
|
|
assert!(defaults.contains(&TargetModule::QProj));
|
|
assert!(defaults.contains(&TargetModule::VProj));
|
|
}
|
|
|
|
#[test]
|
|
fn test_target_module_attention() {
|
|
let attention = TargetModule::attention();
|
|
assert_eq!(attention.len(), 4);
|
|
assert!(attention.contains(&TargetModule::QProj));
|
|
assert!(attention.contains(&TargetModule::KProj));
|
|
assert!(attention.contains(&TargetModule::VProj));
|
|
assert!(attention.contains(&TargetModule::OProj));
|
|
}
|
|
|
|
#[test]
|
|
fn test_target_module_mlp() {
|
|
let mlp = TargetModule::mlp();
|
|
assert_eq!(mlp.len(), 3);
|
|
assert!(mlp.contains(&TargetModule::GateProj));
|
|
assert!(mlp.contains(&TargetModule::UpProj));
|
|
assert!(mlp.contains(&TargetModule::DownProj));
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_config_memory() {
|
|
let config = MicroLoraConfig {
|
|
rank: 2,
|
|
alpha: 4.0,
|
|
dropout: 0.0,
|
|
target_modules: vec![TargetModule::QProj, TargetModule::VProj],
|
|
in_features: 768,
|
|
out_features: 768,
|
|
use_bias: false,
|
|
standard_init: true,
|
|
gradient_checkpointing: false,
|
|
};
|
|
|
|
let memory = config.memory_bytes();
|
|
// 2 modules * (768 * 2 + 2 * 768) * 4 bytes
|
|
assert!(memory < 1024 * 1024, "Memory should be < 1MB for MicroLoRA");
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_enable_disable() {
|
|
let config = create_test_config(64);
|
|
let mut lora = MicroLoRA::new(config);
|
|
let input = create_test_input(64);
|
|
|
|
assert!(lora.is_enabled());
|
|
|
|
// Disable
|
|
lora.set_enabled(false);
|
|
assert!(!lora.is_enabled());
|
|
|
|
// Forward when disabled should return zeros
|
|
let output = lora.forward(&input, &TargetModule::QProj);
|
|
assert!(output.iter().all(|&v| v == 0.0));
|
|
|
|
// Re-enable
|
|
lora.set_enabled(true);
|
|
assert!(lora.is_enabled());
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_reset() {
|
|
let config = create_test_config(64);
|
|
let lora = MicroLoRA::new(config);
|
|
let input = create_test_input(64);
|
|
|
|
// Perform some adaptations
|
|
for _ in 0..5 {
|
|
let feedback = AdaptFeedback::from_quality(0.8);
|
|
lora.adapt(&input, feedback).unwrap();
|
|
}
|
|
|
|
assert!(lora.adaptation_count() > 0);
|
|
|
|
// Reset
|
|
lora.reset();
|
|
|
|
assert_eq!(lora.adaptation_count(), 0);
|
|
assert_eq!(lora.forward_count(), 0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_memory_usage() {
|
|
let config = create_test_config(64);
|
|
let lora = MicroLoRA::new(config);
|
|
|
|
let memory = lora.memory_bytes();
|
|
let params = lora.param_count();
|
|
|
|
assert!(memory > 0);
|
|
assert!(params > 0);
|
|
assert_eq!(memory, params * std::mem::size_of::<f32>());
|
|
}
|
|
|
|
#[test]
|
|
fn test_lora_adapter_simd_forward() {
|
|
let adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
|
let input = create_test_input(64);
|
|
let mut output = vec![0.0f32; 64];
|
|
|
|
adapter.forward_simd(&input, &mut output);
|
|
|
|
// Compare with regular forward
|
|
let input_array = ndarray::Array1::from_vec(input.clone());
|
|
let expected = adapter.forward(&input_array);
|
|
|
|
for (o, e) in output.iter().zip(expected.iter()) {
|
|
assert!(
|
|
(o - e).abs() < 1e-5,
|
|
"SIMD forward mismatch: {} vs {}",
|
|
o,
|
|
e
|
|
);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_with_custom_dimensions() {
|
|
let config = MicroLoraConfig {
|
|
rank: 2,
|
|
alpha: 4.0,
|
|
dropout: 0.0,
|
|
target_modules: vec![TargetModule::QProj, TargetModule::VProj],
|
|
in_features: 256, // Default dimensions
|
|
out_features: 256,
|
|
use_bias: false,
|
|
standard_init: true,
|
|
gradient_checkpointing: false,
|
|
};
|
|
|
|
// Create with custom dimensions per module
|
|
let mut dimensions = HashMap::new();
|
|
dimensions.insert(TargetModule::QProj, (128, 128));
|
|
dimensions.insert(TargetModule::VProj, (128, 128));
|
|
|
|
let lora = MicroLoRA::with_dimensions(config, dimensions);
|
|
|
|
let input = create_test_input(128);
|
|
let output = lora.forward(&input, &TargetModule::QProj);
|
|
|
|
assert_eq!(output.len(), 128);
|
|
}
|
|
|
|
#[test]
|
|
fn test_micro_lora_save_load() {
|
|
let config = create_test_config(64);
|
|
let lora = MicroLoRA::new(config);
|
|
let input = create_test_input(64);
|
|
|
|
// Apply some adaptation
|
|
let feedback = AdaptFeedback::from_quality(0.85);
|
|
lora.adapt(&input, feedback).unwrap();
|
|
lora.apply_updates(0.01);
|
|
|
|
// Export state
|
|
let state = lora.export_state();
|
|
|
|
assert_eq!(state.config.rank, 2);
|
|
assert!(!state.adapters.is_empty());
|
|
|
|
// Restore from state
|
|
let lora_restored = MicroLoRA::from_state(state).unwrap();
|
|
|
|
// Both should produce same output
|
|
let output_original = lora.forward(&input, &TargetModule::QProj);
|
|
let output_restored = lora_restored.forward(&input, &TargetModule::QProj);
|
|
|
|
for (a, b) in output_original.iter().zip(output_restored.iter()) {
|
|
assert!(
|
|
(a - b).abs() < 1e-5,
|
|
"Restored model should match: {} vs {}",
|
|
a,
|
|
b
|
|
);
|
|
}
|
|
}
|
|
|
|
// Note: test_lora_apply_updates_with_ewc removed as EwcState is not exported
|
|
|
|
#[test]
|
|
fn test_lora_adapter_reset() {
|
|
let mut adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
|
let input = ndarray::Array1::from_elem(64, 0.1);
|
|
let grad_output = ndarray::Array1::from_elem(64, 0.1);
|
|
|
|
// Accumulate some gradients and apply
|
|
adapter.accumulate_gradient(&input, &grad_output, 0.8);
|
|
adapter.apply_gradients(0.01);
|
|
|
|
// Reset
|
|
adapter.reset();
|
|
|
|
assert_eq!(adapter.pending_updates(), 0);
|
|
|
|
// B matrix should be reset to zero
|
|
let output = adapter.forward(&input);
|
|
let sum: f32 = output.iter().sum();
|
|
assert!(sum.abs() < 1e-6, "After reset, output should be ~0");
|
|
}
|
|
|
|
#[test]
|
|
fn test_config_for_hidden_dim() {
|
|
let config = MicroLoraConfig::for_hidden_dim(512);
|
|
|
|
assert_eq!(config.in_features, 512);
|
|
assert_eq!(config.out_features, 512);
|
|
assert_eq!(config.rank, 2); // Default rank
|
|
}
|
|
|
|
#[test]
|
|
fn test_config_builder_methods() {
|
|
let config = MicroLoraConfig::for_hidden_dim(256)
|
|
.with_rank(1)
|
|
.with_alpha(8.0)
|
|
.with_targets(vec![
|
|
TargetModule::QProj,
|
|
TargetModule::KProj,
|
|
TargetModule::VProj,
|
|
]);
|
|
|
|
assert_eq!(config.rank, 1);
|
|
assert_eq!(config.alpha, 8.0);
|
|
assert_eq!(config.target_modules.len(), 3);
|
|
}
|