Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
288
crates/ruvllm/tests/adapter_integration.rs
Normal file
288
crates/ruvllm/tests/adapter_integration.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Integration tests for task-specific LoRA adapters
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use ruvllm::lora::{
|
||||
AdaptFeedback, AdapterMerger, AdapterTrainer, AdapterTrainingConfig, HotSwapManager,
|
||||
MergeConfig, MergeStrategy, RuvLtraAdapters, SyntheticDataGenerator, TargetModule,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_adapter_creation_all() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
|
||||
// Test all 5 pre-defined adapters
|
||||
for name in &["coder", "researcher", "security", "architect", "reviewer"] {
|
||||
let lora = adapters.create_lora(name, 256).unwrap();
|
||||
assert!(lora.is_enabled());
|
||||
assert!(lora.param_count() > 0);
|
||||
println!("{}: {} params", name, lora.param_count());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_synthetic_data_generation() {
|
||||
let generator = SyntheticDataGenerator::new(256, 42);
|
||||
|
||||
for task_type in &["coder", "researcher", "security", "architect", "reviewer"] {
|
||||
let dataset = generator.generate(task_type, 100);
|
||||
|
||||
assert_eq!(dataset.feature_dim, 256);
|
||||
assert!(dataset.examples.len() > 0);
|
||||
assert!(dataset.validation.len() > 0);
|
||||
|
||||
// Check quality scores are valid
|
||||
for example in &dataset.examples {
|
||||
assert!(example.quality >= 0.0 && example.quality <= 1.0);
|
||||
}
|
||||
|
||||
let stats = dataset.stats();
|
||||
println!(
|
||||
"{}: train={}, val={}, avg_quality={:.2}",
|
||||
task_type, stats.train_size, stats.val_size, stats.avg_quality
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adapter_training() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora = adapters.create_lora("coder", 256).unwrap();
|
||||
|
||||
let generator = SyntheticDataGenerator::new(256, 42);
|
||||
let dataset = generator.generate("coder", 100);
|
||||
|
||||
let config = AdapterTrainingConfig::quick();
|
||||
let mut trainer = AdapterTrainer::new(config);
|
||||
|
||||
let result = trainer.train(&lora, &dataset).unwrap();
|
||||
|
||||
assert!(result.epochs_completed > 0);
|
||||
assert!(result.total_steps > 0);
|
||||
assert!(result.final_loss >= 0.0);
|
||||
|
||||
println!(
|
||||
"Training result: {} epochs, {} steps, loss={:.4}",
|
||||
result.epochs_completed, result.total_steps, result.final_loss
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adapter_inference() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora = adapters.create_lora("coder", 256).unwrap();
|
||||
|
||||
let input = vec![0.5; 256];
|
||||
let output = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
assert_eq!(output.len(), 256);
|
||||
|
||||
let mean = output.iter().sum::<f32>() / output.len() as f32;
|
||||
println!("Mean output: {:.4}", mean);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_average() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora1 = adapters.create_lora("coder", 256).unwrap();
|
||||
let lora2 = adapters.create_lora("researcher", 256).unwrap();
|
||||
|
||||
let adapters_to_merge = vec![
|
||||
("coder".to_string(), lora1),
|
||||
("researcher".to_string(), lora2),
|
||||
];
|
||||
|
||||
let config = MergeConfig::average();
|
||||
let merger = AdapterMerger::new(config);
|
||||
|
||||
let merged = merger
|
||||
.merge(&adapters_to_merge, &adapters.coder, 256)
|
||||
.unwrap();
|
||||
|
||||
assert!(merged.is_enabled());
|
||||
assert!(merged.param_count() > 0);
|
||||
|
||||
println!("Merged adapter: {} params", merged.param_count());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_weighted() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora1 = adapters.create_lora("coder", 256).unwrap();
|
||||
let lora2 = adapters.create_lora("security", 256).unwrap();
|
||||
|
||||
let adapters_to_merge = vec![
|
||||
("coder".to_string(), lora1),
|
||||
("security".to_string(), lora2),
|
||||
];
|
||||
|
||||
let mut weights = HashMap::new();
|
||||
weights.insert("coder".to_string(), 0.7);
|
||||
weights.insert("security".to_string(), 0.3);
|
||||
|
||||
let config = MergeConfig::weighted(weights);
|
||||
let merger = AdapterMerger::new(config);
|
||||
|
||||
let merged = merger
|
||||
.merge(&adapters_to_merge, &adapters.coder, 256)
|
||||
.unwrap();
|
||||
|
||||
assert!(merged.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_slerp() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora1 = adapters.create_lora("coder", 256).unwrap();
|
||||
let lora2 = adapters.create_lora("reviewer", 256).unwrap();
|
||||
|
||||
let adapters_to_merge = vec![
|
||||
("coder".to_string(), lora1),
|
||||
("reviewer".to_string(), lora2),
|
||||
];
|
||||
|
||||
let config = MergeConfig::slerp(0.5);
|
||||
let merger = AdapterMerger::new(config);
|
||||
|
||||
let merged = merger
|
||||
.merge(&adapters_to_merge, &adapters.coder, 256)
|
||||
.unwrap();
|
||||
|
||||
assert!(merged.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hot_swap() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora1 = adapters.create_lora("coder", 256).unwrap();
|
||||
let lora2 = adapters.create_lora("security", 256).unwrap();
|
||||
|
||||
let mut manager = HotSwapManager::new();
|
||||
|
||||
manager.set_active(lora1);
|
||||
assert!(manager.active().is_some());
|
||||
|
||||
manager.prepare_standby(lora2);
|
||||
manager.swap().unwrap();
|
||||
|
||||
assert!(manager.active().is_some());
|
||||
assert!(!manager.is_swapping());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_per_request_adaptation() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora = adapters.create_lora("coder", 256).unwrap();
|
||||
|
||||
let input = vec![0.5; 256];
|
||||
|
||||
// Baseline
|
||||
let baseline = lora.forward(&input, &TargetModule::QProj);
|
||||
let baseline_mean = baseline.iter().sum::<f32>() / baseline.len() as f32;
|
||||
|
||||
// Adapt
|
||||
let feedback = AdaptFeedback::from_quality(0.9);
|
||||
lora.adapt(&input, feedback).unwrap();
|
||||
lora.apply_updates(0.01);
|
||||
|
||||
// After adaptation
|
||||
let adapted = lora.forward(&input, &TargetModule::QProj);
|
||||
let adapted_mean = adapted.iter().sum::<f32>() / adapted.len() as f32;
|
||||
|
||||
println!(
|
||||
"Baseline mean: {:.4}, Adapted mean: {:.4}",
|
||||
baseline_mean, adapted_mean
|
||||
);
|
||||
|
||||
assert_eq!(lora.adaptation_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_persistence() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let lora = adapters.create_lora("coder", 256).unwrap();
|
||||
|
||||
// Adapt the model
|
||||
let input = vec![0.5; 256];
|
||||
let feedback = AdaptFeedback::from_quality(0.9);
|
||||
lora.adapt(&input, feedback).unwrap();
|
||||
lora.apply_updates(0.01);
|
||||
|
||||
// Save
|
||||
let path = "/tmp/test_adapter.bin";
|
||||
lora.save(path).unwrap();
|
||||
|
||||
// Load
|
||||
let loaded = ruvllm::lora::MicroLoRA::load(path).unwrap();
|
||||
|
||||
assert_eq!(loaded.param_count(), lora.param_count());
|
||||
assert_eq!(loaded.memory_bytes(), lora.memory_bytes());
|
||||
|
||||
println!("Saved and loaded adapter: {} params", loaded.param_count());
|
||||
|
||||
// Cleanup
|
||||
std::fs::remove_file(path).ok();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adapter_memory_footprint() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
|
||||
for name in &["coder", "researcher", "security", "architect", "reviewer"] {
|
||||
let config = adapters.get(name).unwrap();
|
||||
let mem_256 = config.estimate_memory(256);
|
||||
let mem_768 = config.estimate_memory(768);
|
||||
let mem_4096 = config.estimate_memory(4096);
|
||||
|
||||
println!(
|
||||
"{}: 256d={:.1}KB, 768d={:.1}KB, 4096d={:.1}KB",
|
||||
name,
|
||||
mem_256 as f32 / 1024.0,
|
||||
mem_768 as f32 / 1024.0,
|
||||
mem_4096 as f32 / 1024.0
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adapter_composition() {
|
||||
let adapters = RuvLtraAdapters::new();
|
||||
let generator = SyntheticDataGenerator::new(256, 42);
|
||||
|
||||
// Create and train 3 adapters
|
||||
let datasets = generator.generate_all(50);
|
||||
|
||||
let mut trained_adapters = Vec::new();
|
||||
for (name, dataset) in datasets.into_iter().take(3) {
|
||||
let lora = adapters.create_lora(&name, 256).unwrap();
|
||||
let mut trainer = AdapterTrainer::new(AdapterTrainingConfig::quick());
|
||||
trainer.train(&lora, &dataset).unwrap();
|
||||
trained_adapters.push((name, lora));
|
||||
}
|
||||
|
||||
// TIES merge
|
||||
let ties_config = MergeConfig::ties(0.6);
|
||||
let ties_merger = AdapterMerger::new(ties_config);
|
||||
let ties_merged = ties_merger
|
||||
.merge(&trained_adapters, &adapters.coder, 256)
|
||||
.unwrap();
|
||||
|
||||
assert!(ties_merged.is_enabled());
|
||||
|
||||
println!("TIES merged adapter: {} params", ties_merged.param_count());
|
||||
}
|
||||
}
|
||||
571
crates/ruvllm/tests/ane_integration.rs
Normal file
571
crates/ruvllm/tests/ane_integration.rs
Normal file
@@ -0,0 +1,571 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Integration tests for Apple Neural Engine (ANE) / Core ML functionality
|
||||
//!
|
||||
//! These tests verify end-to-end functionality of the ANE/CoreML backend,
|
||||
//! including hybrid pipeline switching, fallback behavior, and memory management.
|
||||
//!
|
||||
//! ## Running Tests
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Run all ANE tests (requires Apple Silicon)
|
||||
//! cargo test --features coreml ane_integration
|
||||
//!
|
||||
//! # Run with hybrid pipeline support
|
||||
//! cargo test --features hybrid-ane ane_integration
|
||||
//!
|
||||
//! # Run on non-Apple Silicon (tests fallback behavior)
|
||||
//! cargo test ane_integration
|
||||
//! ```
|
||||
|
||||
// Import from the crate being tested
|
||||
// Note: CoreMLBackend methods require the coreml feature
|
||||
#[cfg(feature = "coreml")]
|
||||
use ruvllm::backends::CoreMLBackend;
|
||||
use ruvllm::backends::{
|
||||
AneCapabilities, ComputeUnits, GenerateParams, LlmBackend, ModelArchitecture, ModelConfig,
|
||||
Quantization,
|
||||
};
|
||||
use ruvllm::error::{Result, RuvLLMError};
|
||||
|
||||
// ============================================================================
|
||||
// Platform Detection Helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Check if running on Apple Silicon
|
||||
fn is_apple_silicon() -> bool {
|
||||
cfg!(all(target_os = "macos", target_arch = "aarch64"))
|
||||
}
|
||||
|
||||
/// Check if ANE is available
|
||||
fn is_ane_available() -> bool {
|
||||
let caps = AneCapabilities::detect();
|
||||
caps.available
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Core ML Backend Integration Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_ane_capabilities_detection() {
|
||||
let caps = AneCapabilities::detect();
|
||||
|
||||
if is_apple_silicon() {
|
||||
assert!(caps.available, "ANE should be available on Apple Silicon");
|
||||
assert!(caps.tops > 0.0, "TOPS should be positive on Apple Silicon");
|
||||
assert!(
|
||||
caps.max_model_size_mb > 0,
|
||||
"Max model size should be positive"
|
||||
);
|
||||
assert!(
|
||||
!caps.supported_ops.is_empty(),
|
||||
"Should have supported operations"
|
||||
);
|
||||
|
||||
// Verify common operations are supported
|
||||
let expected_ops = ["MatMul", "GELU", "SiLU", "LayerNorm", "Softmax"];
|
||||
for op in &expected_ops {
|
||||
assert!(
|
||||
caps.supported_ops.iter().any(|s| s == *op),
|
||||
"Operation {} should be supported",
|
||||
op
|
||||
);
|
||||
}
|
||||
} else {
|
||||
assert!(
|
||||
!caps.available,
|
||||
"ANE should not be available on non-Apple Silicon"
|
||||
);
|
||||
assert_eq!(caps.tops, 0.0, "TOPS should be 0 when unavailable");
|
||||
assert_eq!(
|
||||
caps.max_model_size_mb, 0,
|
||||
"Max model size should be 0 when unavailable"
|
||||
);
|
||||
assert!(
|
||||
caps.supported_ops.is_empty(),
|
||||
"No operations when unavailable"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_units_selection() {
|
||||
// Test default selection
|
||||
let default = ComputeUnits::default();
|
||||
assert_eq!(default, ComputeUnits::All);
|
||||
|
||||
// Test ANE-focused configuration
|
||||
let ane_focus = ComputeUnits::CpuAndNeuralEngine;
|
||||
assert!(ane_focus.uses_ane());
|
||||
assert!(!ane_focus.uses_gpu());
|
||||
|
||||
// Test GPU-focused configuration
|
||||
let gpu_focus = ComputeUnits::CpuAndGpu;
|
||||
assert!(!gpu_focus.uses_ane());
|
||||
assert!(gpu_focus.uses_gpu());
|
||||
|
||||
// Test all units
|
||||
let all = ComputeUnits::All;
|
||||
assert!(all.uses_ane());
|
||||
assert!(all.uses_gpu());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_suitability_for_ane() {
|
||||
let caps = AneCapabilities::detect();
|
||||
|
||||
if is_apple_silicon() {
|
||||
// Small models should be suitable
|
||||
assert!(caps.is_model_suitable(500), "500MB model should fit");
|
||||
assert!(caps.is_model_suitable(1000), "1GB model should fit");
|
||||
assert!(caps.is_model_suitable(2048), "2GB model should fit");
|
||||
|
||||
// Large models may not fit
|
||||
// (depends on actual device, but 10GB is likely too large)
|
||||
// Skip this assertion as it's hardware-dependent
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Core ML Backend Creation Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "coreml")]
|
||||
fn test_coreml_backend_creation() {
|
||||
if is_apple_silicon() {
|
||||
let result = CoreMLBackend::new();
|
||||
assert!(result.is_ok(), "Should create backend on Apple Silicon");
|
||||
|
||||
let backend = result.unwrap();
|
||||
assert!(!backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_none());
|
||||
} else {
|
||||
let result = CoreMLBackend::new();
|
||||
assert!(result.is_err(), "Should fail on non-Apple Silicon");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "coreml")]
|
||||
fn test_coreml_backend_configuration() {
|
||||
if !is_apple_silicon() {
|
||||
return; // Skip on non-Apple Silicon
|
||||
}
|
||||
|
||||
let backend = CoreMLBackend::new()
|
||||
.unwrap()
|
||||
.with_compute_units(ComputeUnits::CpuAndNeuralEngine);
|
||||
|
||||
let caps = backend.ane_capabilities();
|
||||
assert!(caps.available);
|
||||
assert!(caps.tops > 0.0);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Fallback Behavior Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_fallback_when_coreml_unavailable() {
|
||||
// When coreml feature is not enabled, CoreMLBackend type doesn't exist
|
||||
// so we can only test the AneCapabilities fallback
|
||||
#[cfg(not(feature = "coreml"))]
|
||||
{
|
||||
// Without coreml feature, ANE capabilities should report unavailable
|
||||
let caps = AneCapabilities::detect();
|
||||
// On non-Apple Silicon or without the feature, it should gracefully handle this
|
||||
if !is_apple_silicon() {
|
||||
assert!(
|
||||
!caps.available,
|
||||
"ANE should not be available without coreml feature on non-Apple Silicon"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "coreml")]
|
||||
{
|
||||
if !is_apple_silicon() {
|
||||
let result = CoreMLBackend::new();
|
||||
assert!(result.is_err());
|
||||
|
||||
let err = result.unwrap_err();
|
||||
let err_str = err.to_string();
|
||||
assert!(
|
||||
err_str.contains("not available"),
|
||||
"Should indicate ANE not available"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_graceful_degradation() {
|
||||
// Even when ANE is not available, the AneCapabilities struct should work
|
||||
let caps = AneCapabilities {
|
||||
available: false,
|
||||
tops: 0.0,
|
||||
max_model_size_mb: 0,
|
||||
supported_ops: vec![],
|
||||
};
|
||||
|
||||
// All operations should return false/empty gracefully
|
||||
assert!(!caps.is_model_suitable(100));
|
||||
assert!(!caps.is_model_suitable(0));
|
||||
assert!(!caps.available);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Model Loading Error Handling Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "coreml", target_os = "macos", target_arch = "aarch64"))]
|
||||
fn test_unsupported_model_format_error() {
|
||||
let mut backend = CoreMLBackend::new().unwrap();
|
||||
|
||||
// Try various unsupported formats
|
||||
let unsupported_formats = [
|
||||
"model.safetensors",
|
||||
"model.bin",
|
||||
"model.pt",
|
||||
"model.pth",
|
||||
"model.onnx",
|
||||
];
|
||||
|
||||
for format in &unsupported_formats {
|
||||
let result = backend.load_model(format, ModelConfig::default());
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Should reject unsupported format: {}",
|
||||
format
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "coreml", target_os = "macos", target_arch = "aarch64"))]
|
||||
fn test_nonexistent_model_error() {
|
||||
let mut backend = CoreMLBackend::new().unwrap();
|
||||
|
||||
let result = backend.load_model("/nonexistent/path/model.mlmodel", ModelConfig::default());
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "coreml", target_os = "macos", target_arch = "aarch64"))]
|
||||
fn test_gguf_conversion_error() {
|
||||
let mut backend = CoreMLBackend::new().unwrap();
|
||||
|
||||
// GGUF conversion is not yet implemented
|
||||
let result = backend.load_model("/path/to/model.gguf", ModelConfig::default());
|
||||
assert!(result.is_err());
|
||||
|
||||
let err = result.unwrap_err();
|
||||
let err_str = err.to_string();
|
||||
assert!(
|
||||
err_str.contains("not") || err_str.contains("conversion"),
|
||||
"Error should mention conversion issue: {}",
|
||||
err_str
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Management Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "coreml", target_os = "macos", target_arch = "aarch64"))]
|
||||
fn test_model_unloading() {
|
||||
let mut backend = CoreMLBackend::new().unwrap();
|
||||
|
||||
// Initial state
|
||||
assert!(!backend.is_model_loaded());
|
||||
|
||||
// Unload should be safe even without loaded model
|
||||
backend.unload_model();
|
||||
assert!(!backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(all(feature = "coreml", target_os = "macos", target_arch = "aarch64"))]
|
||||
fn test_multiple_unload_calls() {
|
||||
let mut backend = CoreMLBackend::new().unwrap();
|
||||
|
||||
// Multiple unload calls should be safe
|
||||
for _ in 0..5 {
|
||||
backend.unload_model();
|
||||
assert!(!backend.is_model_loaded());
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hybrid Pipeline Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(feature = "hybrid-ane")]
|
||||
mod hybrid_pipeline_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_hybrid_feature_enabled() {
|
||||
// Verify hybrid-ane feature combines metal-compute and coreml
|
||||
// This test just confirms the feature flag works
|
||||
assert!(true, "Hybrid ANE feature is enabled");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
fn test_hybrid_configuration() {
|
||||
// Test that we can configure for hybrid operation
|
||||
let ane_caps = AneCapabilities::detect();
|
||||
|
||||
if ane_caps.available {
|
||||
// In hybrid mode, we'd route:
|
||||
// - MatMul/FFN to ANE
|
||||
// - Attention to GPU (Metal)
|
||||
assert!(ane_caps.supported_ops.contains(&"MatMul".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Characteristics Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_ane_tops_values() {
|
||||
// Test known TOPS values for various chips
|
||||
struct ChipSpec {
|
||||
name: &'static str,
|
||||
min_tops: f32,
|
||||
max_tops: f32,
|
||||
}
|
||||
|
||||
// Known Apple Silicon TOPS ranges
|
||||
let chip_specs = [
|
||||
ChipSpec {
|
||||
name: "M1",
|
||||
min_tops: 11.0,
|
||||
max_tops: 11.5,
|
||||
},
|
||||
ChipSpec {
|
||||
name: "M1 Pro/Max",
|
||||
min_tops: 11.0,
|
||||
max_tops: 11.5,
|
||||
},
|
||||
ChipSpec {
|
||||
name: "M2",
|
||||
min_tops: 15.0,
|
||||
max_tops: 16.0,
|
||||
},
|
||||
ChipSpec {
|
||||
name: "M3",
|
||||
min_tops: 18.0,
|
||||
max_tops: 18.5,
|
||||
},
|
||||
ChipSpec {
|
||||
name: "M4",
|
||||
min_tops: 35.0,
|
||||
max_tops: 40.0,
|
||||
},
|
||||
];
|
||||
|
||||
if is_apple_silicon() {
|
||||
let caps = AneCapabilities::detect();
|
||||
// Detected TOPS should fall within one of the known ranges
|
||||
let in_known_range = chip_specs
|
||||
.iter()
|
||||
.any(|spec| caps.tops >= spec.min_tops && caps.tops <= spec.max_tops + 5.0);
|
||||
|
||||
// Just verify it's a reasonable positive value
|
||||
assert!(caps.tops > 0.0, "TOPS should be positive");
|
||||
assert!(caps.tops < 100.0, "TOPS should be reasonable (< 100)");
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Type Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_error_messages() {
|
||||
// Test that error messages are informative
|
||||
let caps = AneCapabilities {
|
||||
available: false,
|
||||
tops: 0.0,
|
||||
max_model_size_mb: 0,
|
||||
supported_ops: vec![],
|
||||
};
|
||||
|
||||
// Debug output should be readable
|
||||
let debug = format!("{:?}", caps);
|
||||
assert!(debug.contains("available"));
|
||||
assert!(debug.contains("false"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "coreml")]
|
||||
fn test_error_chain() {
|
||||
if !is_apple_silicon() {
|
||||
let result: Result<CoreMLBackend> = CoreMLBackend::new();
|
||||
let err = result.unwrap_err();
|
||||
|
||||
// Error should be a Config error
|
||||
match &err {
|
||||
RuvLLMError::Config(msg) => {
|
||||
assert!(msg.contains("not available") || msg.contains("feature"));
|
||||
}
|
||||
other => {
|
||||
panic!("Expected Config error, got {:?}", other);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Thread Safety Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_ane_capabilities_thread_safe() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let caps = Arc::new(AneCapabilities::detect());
|
||||
|
||||
let handles: Vec<_> = (0..4)
|
||||
.map(|i| {
|
||||
let caps = Arc::clone(&caps);
|
||||
thread::spawn(move || {
|
||||
// Read operations should be thread-safe
|
||||
let _ = caps.available;
|
||||
let _ = caps.tops;
|
||||
let _ = caps.max_model_size_mb;
|
||||
let _ = caps.is_model_suitable(1000);
|
||||
let _ = format!("{:?}", caps);
|
||||
i
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for handle in handles {
|
||||
handle.join().expect("Thread should complete successfully");
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Benchmark-style Tests (Run with --release)
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
#[ignore] // Run with: cargo test --release -- --ignored
|
||||
fn test_ane_capabilities_detection_performance() {
|
||||
use std::time::Instant;
|
||||
|
||||
let iterations = 1000;
|
||||
let start = Instant::now();
|
||||
|
||||
for _ in 0..iterations {
|
||||
let _ = AneCapabilities::detect();
|
||||
}
|
||||
|
||||
let duration = start.elapsed();
|
||||
let avg_ns = duration.as_nanos() as f64 / iterations as f64;
|
||||
|
||||
println!(
|
||||
"AneCapabilities::detect() average time: {:.2} ns ({:.2} us)",
|
||||
avg_ns,
|
||||
avg_ns / 1000.0
|
||||
);
|
||||
|
||||
// Detection should be fast (< 1ms)
|
||||
assert!(
|
||||
avg_ns < 1_000_000.0,
|
||||
"Detection should be < 1ms, was {} ns",
|
||||
avg_ns
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Documentation Examples Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_readme_example_capabilities() {
|
||||
// Example from module documentation
|
||||
let caps = AneCapabilities::detect();
|
||||
|
||||
if caps.available {
|
||||
println!("ANE available with {} TOPS", caps.tops);
|
||||
println!("Max model size: {} MB", caps.max_model_size_mb);
|
||||
println!("Supported ops: {:?}", caps.supported_ops);
|
||||
} else {
|
||||
println!("ANE not available on this device");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_readme_example_compute_units() {
|
||||
// Example from module documentation
|
||||
let units = ComputeUnits::CpuAndNeuralEngine;
|
||||
|
||||
println!("Compute units: {}", units.description());
|
||||
println!("Uses ANE: {}", units.uses_ane());
|
||||
println!("Uses GPU: {}", units.uses_gpu());
|
||||
|
||||
assert!(units.uses_ane());
|
||||
assert!(!units.uses_gpu());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Property-based Test Helpers
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_model_suitability_monotonic() {
|
||||
// Model suitability should be monotonic: if a larger model fits, smaller ones should too
|
||||
let caps = AneCapabilities {
|
||||
available: true,
|
||||
tops: 38.0,
|
||||
max_model_size_mb: 2048,
|
||||
supported_ops: vec!["MatMul".to_string()],
|
||||
};
|
||||
|
||||
// If 2048 fits, all smaller sizes should fit
|
||||
if caps.is_model_suitable(2048) {
|
||||
for size in [0, 1, 100, 500, 1000, 1500, 2000, 2047] {
|
||||
assert!(
|
||||
caps.is_model_suitable(size),
|
||||
"Size {} should fit if {} fits",
|
||||
size,
|
||||
2048
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// If 2049 doesn't fit, all larger sizes shouldn't fit either
|
||||
if !caps.is_model_suitable(2049) {
|
||||
for size in [2050, 3000, 4096, 10000] {
|
||||
assert!(
|
||||
!caps.is_model_suitable(size),
|
||||
"Size {} should not fit if {} doesn't fit",
|
||||
size,
|
||||
2049
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
683
crates/ruvllm/tests/ane_test_utils.rs
Normal file
683
crates/ruvllm/tests/ane_test_utils.rs
Normal file
@@ -0,0 +1,683 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Test utilities for ANE/Core ML testing
|
||||
//!
|
||||
//! This module provides shared test utilities, fixtures, and helper functions
|
||||
//! for testing Apple Neural Engine and Core ML functionality.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - Random tensor generators with various distributions
|
||||
//! - Comparison utilities with configurable tolerance
|
||||
//! - Small test model generators for quick testing
|
||||
//! - Platform detection helpers
|
||||
//! - Benchmark utilities
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
// ============================================================================
|
||||
// Platform Detection
|
||||
// ============================================================================
|
||||
|
||||
/// Check if running on Apple Silicon
|
||||
pub fn is_apple_silicon() -> bool {
|
||||
cfg!(all(target_os = "macos", target_arch = "aarch64"))
|
||||
}
|
||||
|
||||
/// Check if the coreml feature is enabled
|
||||
pub fn is_coreml_enabled() -> bool {
|
||||
cfg!(feature = "coreml")
|
||||
}
|
||||
|
||||
/// Check if both Apple Silicon and coreml feature are available
|
||||
pub fn is_ane_test_enabled() -> bool {
|
||||
is_apple_silicon() && is_coreml_enabled()
|
||||
}
|
||||
|
||||
/// Skip message for non-Apple Silicon platforms
|
||||
pub fn skip_non_apple_silicon() -> Option<&'static str> {
|
||||
if !is_apple_silicon() {
|
||||
Some("Test skipped: requires Apple Silicon")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Skip message for non-coreml builds
|
||||
pub fn skip_non_coreml() -> Option<&'static str> {
|
||||
if !is_coreml_enabled() {
|
||||
Some("Test skipped: requires coreml feature")
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Random Tensor Generators
|
||||
// ============================================================================
|
||||
|
||||
/// Simple linear congruential generator for reproducible random numbers
|
||||
pub struct SimpleRng {
|
||||
state: u64,
|
||||
}
|
||||
|
||||
impl SimpleRng {
|
||||
/// Create a new RNG with the given seed
|
||||
pub fn new(seed: u64) -> Self {
|
||||
Self { state: seed }
|
||||
}
|
||||
|
||||
/// Generate the next random u64
|
||||
pub fn next_u64(&mut self) -> u64 {
|
||||
// LCG parameters (same as glibc)
|
||||
self.state = self.state.wrapping_mul(1103515245).wrapping_add(12345);
|
||||
self.state
|
||||
}
|
||||
|
||||
/// Generate a random f32 in [0, 1)
|
||||
pub fn next_f32(&mut self) -> f32 {
|
||||
(self.next_u64() as f64 / u64::MAX as f64) as f32
|
||||
}
|
||||
|
||||
/// Generate a random f32 in [min, max)
|
||||
pub fn next_f32_range(&mut self, min: f32, max: f32) -> f32 {
|
||||
min + self.next_f32() * (max - min)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random tensor with uniform distribution
|
||||
pub fn random_tensor_uniform(size: usize, min: f32, max: f32, seed: u64) -> Vec<f32> {
|
||||
let mut rng = SimpleRng::new(seed);
|
||||
(0..size).map(|_| rng.next_f32_range(min, max)).collect()
|
||||
}
|
||||
|
||||
/// Generate a random tensor with approximate normal distribution
|
||||
/// Uses Box-Muller transform for simplicity
|
||||
pub fn random_tensor_normal(size: usize, mean: f32, std: f32, seed: u64) -> Vec<f32> {
|
||||
let mut rng = SimpleRng::new(seed);
|
||||
let mut result = Vec::with_capacity(size);
|
||||
|
||||
while result.len() < size {
|
||||
let u1 = rng.next_f32().max(1e-10); // Avoid log(0)
|
||||
let u2 = rng.next_f32();
|
||||
|
||||
let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
|
||||
let z1 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).sin();
|
||||
|
||||
result.push(mean + z0 * std);
|
||||
if result.len() < size {
|
||||
result.push(mean + z1 * std);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Generate a tensor with sequential values
|
||||
pub fn sequential_tensor(size: usize, start: f32, step: f32) -> Vec<f32> {
|
||||
(0..size).map(|i| start + (i as f32) * step).collect()
|
||||
}
|
||||
|
||||
/// Generate a tensor filled with a constant value
|
||||
pub fn constant_tensor(size: usize, value: f32) -> Vec<f32> {
|
||||
vec![value; size]
|
||||
}
|
||||
|
||||
/// Generate an identity matrix
|
||||
pub fn identity_matrix(size: usize) -> Vec<f32> {
|
||||
let mut result = vec![0.0; size * size];
|
||||
for i in 0..size {
|
||||
result[i * size + i] = 1.0;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Generate a zero matrix
|
||||
pub fn zero_matrix(rows: usize, cols: usize) -> Vec<f32> {
|
||||
vec![0.0; rows * cols]
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Comparison Utilities
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for tensor comparison
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompareConfig {
|
||||
/// Absolute tolerance
|
||||
pub atol: f32,
|
||||
/// Relative tolerance
|
||||
pub rtol: f32,
|
||||
/// Whether to print differences
|
||||
pub verbose: bool,
|
||||
/// Maximum number of differences to report
|
||||
pub max_diffs: usize,
|
||||
}
|
||||
|
||||
impl Default for CompareConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
atol: 1e-5,
|
||||
rtol: 1e-4,
|
||||
verbose: false,
|
||||
max_diffs: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompareConfig {
|
||||
/// Create a loose tolerance config (for ANE vs CPU comparison)
|
||||
pub fn loose() -> Self {
|
||||
Self {
|
||||
atol: 1e-3,
|
||||
rtol: 1e-2,
|
||||
verbose: true,
|
||||
max_diffs: 5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a strict tolerance config
|
||||
pub fn strict() -> Self {
|
||||
Self {
|
||||
atol: 1e-6,
|
||||
rtol: 1e-5,
|
||||
verbose: true,
|
||||
max_diffs: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of tensor comparison
|
||||
#[derive(Debug)]
|
||||
pub struct CompareResult {
|
||||
/// Whether the tensors are approximately equal
|
||||
pub equal: bool,
|
||||
/// Maximum absolute difference
|
||||
pub max_abs_diff: f32,
|
||||
/// Maximum relative difference
|
||||
pub max_rel_diff: f32,
|
||||
/// Index of maximum absolute difference
|
||||
pub max_abs_diff_idx: usize,
|
||||
/// Number of elements that differ
|
||||
pub num_diffs: usize,
|
||||
/// Total number of elements compared
|
||||
pub num_elements: usize,
|
||||
/// List of (index, expected, actual, abs_diff) for differences
|
||||
pub differences: Vec<(usize, f32, f32, f32)>,
|
||||
}
|
||||
|
||||
/// Compare two tensors element-wise with configurable tolerance
|
||||
pub fn compare_tensors(expected: &[f32], actual: &[f32], config: &CompareConfig) -> CompareResult {
|
||||
assert_eq!(expected.len(), actual.len(), "Tensor sizes must match");
|
||||
|
||||
let mut max_abs_diff = 0.0f32;
|
||||
let mut max_rel_diff = 0.0f32;
|
||||
let mut max_abs_diff_idx = 0;
|
||||
let mut differences = Vec::new();
|
||||
|
||||
for (i, (&e, &a)) in expected.iter().zip(actual.iter()).enumerate() {
|
||||
let abs_diff = (e - a).abs();
|
||||
let rel_diff = if e.abs() > 1e-10 {
|
||||
abs_diff / e.abs()
|
||||
} else {
|
||||
abs_diff
|
||||
};
|
||||
|
||||
if abs_diff > max_abs_diff {
|
||||
max_abs_diff = abs_diff;
|
||||
max_abs_diff_idx = i;
|
||||
}
|
||||
if rel_diff > max_rel_diff {
|
||||
max_rel_diff = rel_diff;
|
||||
}
|
||||
|
||||
// Check if this element differs beyond tolerance
|
||||
let within_tol = abs_diff <= config.atol + config.rtol * e.abs();
|
||||
if !within_tol && differences.len() < config.max_diffs {
|
||||
differences.push((i, e, a, abs_diff));
|
||||
}
|
||||
}
|
||||
|
||||
let equal =
|
||||
max_abs_diff <= config.atol || max_rel_diff <= config.rtol || differences.is_empty();
|
||||
|
||||
if config.verbose && !equal {
|
||||
eprintln!("Tensor comparison failed:");
|
||||
eprintln!(
|
||||
" Max abs diff: {} at index {}",
|
||||
max_abs_diff, max_abs_diff_idx
|
||||
);
|
||||
eprintln!(" Max rel diff: {}", max_rel_diff);
|
||||
eprintln!(" Differences ({}/{}):", differences.len(), expected.len());
|
||||
for (idx, exp, act, diff) in &differences {
|
||||
eprintln!(
|
||||
" [{}]: expected={}, actual={}, diff={}",
|
||||
idx, exp, act, diff
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
CompareResult {
|
||||
equal,
|
||||
max_abs_diff,
|
||||
max_rel_diff,
|
||||
max_abs_diff_idx,
|
||||
num_diffs: differences.len(),
|
||||
num_elements: expected.len(),
|
||||
differences,
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple approximate equality check
|
||||
pub fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
|
||||
(a - b).abs() < eps
|
||||
}
|
||||
|
||||
/// Check if all elements in a tensor are finite
|
||||
pub fn all_finite(tensor: &[f32]) -> bool {
|
||||
tensor.iter().all(|v| v.is_finite())
|
||||
}
|
||||
|
||||
/// Check if a tensor sums to approximately 1.0 (for softmax output)
|
||||
pub fn sums_to_one(tensor: &[f32], eps: f32) -> bool {
|
||||
let sum: f32 = tensor.iter().sum();
|
||||
approx_eq(sum, 1.0, eps)
|
||||
}
|
||||
|
||||
/// Check if all elements are in range [min, max]
|
||||
pub fn all_in_range(tensor: &[f32], min: f32, max: f32) -> bool {
|
||||
tensor.iter().all(|&v| v >= min && v <= max)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Small Test Model Generators
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for a small test model
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TestModelConfig {
|
||||
/// Hidden dimension
|
||||
pub hidden_dim: usize,
|
||||
/// Number of attention heads
|
||||
pub num_heads: usize,
|
||||
/// Intermediate (FFN) dimension
|
||||
pub intermediate_dim: usize,
|
||||
/// Vocabulary size
|
||||
pub vocab_size: usize,
|
||||
/// Maximum sequence length
|
||||
pub max_seq_len: usize,
|
||||
/// Number of layers
|
||||
pub num_layers: usize,
|
||||
}
|
||||
|
||||
impl Default for TestModelConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
hidden_dim: 64,
|
||||
num_heads: 4,
|
||||
intermediate_dim: 128,
|
||||
vocab_size: 1000,
|
||||
max_seq_len: 128,
|
||||
num_layers: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TestModelConfig {
|
||||
/// Create a tiny model config for quick tests
|
||||
pub fn tiny() -> Self {
|
||||
Self {
|
||||
hidden_dim: 32,
|
||||
num_heads: 2,
|
||||
intermediate_dim: 64,
|
||||
vocab_size: 256,
|
||||
max_seq_len: 32,
|
||||
num_layers: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a small model config
|
||||
pub fn small() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create a medium model config for more thorough testing
|
||||
pub fn medium() -> Self {
|
||||
Self {
|
||||
hidden_dim: 256,
|
||||
num_heads: 8,
|
||||
intermediate_dim: 512,
|
||||
vocab_size: 4096,
|
||||
max_seq_len: 256,
|
||||
num_layers: 4,
|
||||
}
|
||||
}
|
||||
|
||||
/// Head dimension
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.hidden_dim / self.num_heads
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate random weights for a layer
|
||||
pub struct TestWeights {
|
||||
seed: u64,
|
||||
}
|
||||
|
||||
impl TestWeights {
|
||||
/// Create a new weight generator with the given seed
|
||||
pub fn new(seed: u64) -> Self {
|
||||
Self { seed }
|
||||
}
|
||||
|
||||
/// Generate weights for a linear layer
|
||||
pub fn linear(&mut self, in_features: usize, out_features: usize) -> Vec<f32> {
|
||||
// Xavier initialization scale
|
||||
let scale = (2.0 / (in_features + out_features) as f32).sqrt();
|
||||
let weights = random_tensor_uniform(in_features * out_features, -scale, scale, self.seed);
|
||||
self.seed += 1;
|
||||
weights
|
||||
}
|
||||
|
||||
/// Generate bias for a linear layer
|
||||
pub fn bias(&mut self, features: usize) -> Vec<f32> {
|
||||
let bias = random_tensor_uniform(features, -0.01, 0.01, self.seed);
|
||||
self.seed += 1;
|
||||
bias
|
||||
}
|
||||
|
||||
/// Generate layer norm weights (initialized to 1.0)
|
||||
pub fn layer_norm_weight(&self, features: usize) -> Vec<f32> {
|
||||
vec![1.0; features]
|
||||
}
|
||||
|
||||
/// Generate layer norm bias (initialized to 0.0)
|
||||
pub fn layer_norm_bias(&self, features: usize) -> Vec<f32> {
|
||||
vec![0.0; features]
|
||||
}
|
||||
|
||||
/// Generate embedding table
|
||||
pub fn embedding(&mut self, vocab_size: usize, hidden_dim: usize) -> Vec<f32> {
|
||||
let scale = 0.02;
|
||||
let weights = random_tensor_normal(vocab_size * hidden_dim, 0.0, scale, self.seed);
|
||||
self.seed += 1;
|
||||
weights
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Benchmark Utilities
|
||||
// ============================================================================
|
||||
|
||||
/// Result of a benchmark run
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BenchmarkResult {
|
||||
/// Name of the benchmark
|
||||
pub name: String,
|
||||
/// Total time for all iterations
|
||||
pub total_time: Duration,
|
||||
/// Number of iterations
|
||||
pub iterations: usize,
|
||||
/// Average time per iteration
|
||||
pub avg_time: Duration,
|
||||
/// Minimum time per iteration
|
||||
pub min_time: Duration,
|
||||
/// Maximum time per iteration
|
||||
pub max_time: Duration,
|
||||
}
|
||||
|
||||
impl BenchmarkResult {
|
||||
/// Print the benchmark result
|
||||
pub fn print(&self) {
|
||||
println!(
|
||||
"{}: avg={:?}, min={:?}, max={:?} ({} iterations)",
|
||||
self.name, self.avg_time, self.min_time, self.max_time, self.iterations
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a simple benchmark
|
||||
pub fn benchmark<F>(name: &str, iterations: usize, mut f: F) -> BenchmarkResult
|
||||
where
|
||||
F: FnMut(),
|
||||
{
|
||||
// Warmup
|
||||
for _ in 0..3 {
|
||||
f();
|
||||
}
|
||||
|
||||
let mut times = Vec::with_capacity(iterations);
|
||||
let total_start = Instant::now();
|
||||
|
||||
for _ in 0..iterations {
|
||||
let start = Instant::now();
|
||||
f();
|
||||
times.push(start.elapsed());
|
||||
}
|
||||
|
||||
let total_time = total_start.elapsed();
|
||||
let avg_time = total_time / iterations as u32;
|
||||
let min_time = times.iter().min().cloned().unwrap_or(Duration::ZERO);
|
||||
let max_time = times.iter().max().cloned().unwrap_or(Duration::ZERO);
|
||||
|
||||
BenchmarkResult {
|
||||
name: name.to_string(),
|
||||
total_time,
|
||||
iterations,
|
||||
avg_time,
|
||||
min_time,
|
||||
max_time,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare two benchmark results
|
||||
pub fn compare_benchmarks(baseline: &BenchmarkResult, optimized: &BenchmarkResult) -> f64 {
|
||||
baseline.avg_time.as_secs_f64() / optimized.avg_time.as_secs_f64()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Data Fixtures
|
||||
// ============================================================================
|
||||
|
||||
/// Common test data for activation function tests
|
||||
pub struct ActivationTestData {
|
||||
/// Input values covering various ranges
|
||||
pub inputs: Vec<f32>,
|
||||
/// Expected GELU outputs (approximate)
|
||||
pub expected_gelu: Vec<f32>,
|
||||
/// Expected SiLU outputs (approximate)
|
||||
pub expected_silu: Vec<f32>,
|
||||
}
|
||||
|
||||
impl Default for ActivationTestData {
|
||||
fn default() -> Self {
|
||||
let inputs: Vec<f32> = vec![-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0];
|
||||
|
||||
// Pre-computed expected values (approximate)
|
||||
let expected_gelu: Vec<f32> = vec![
|
||||
-0.004, // GELU(-3)
|
||||
-0.045, // GELU(-2)
|
||||
-0.159, // GELU(-1)
|
||||
-0.154, // GELU(-0.5)
|
||||
0.0, // GELU(0)
|
||||
0.346, // GELU(0.5)
|
||||
0.841, // GELU(1)
|
||||
1.955, // GELU(2)
|
||||
2.996, // GELU(3)
|
||||
];
|
||||
|
||||
let expected_silu: Vec<f32> = inputs
|
||||
.iter()
|
||||
.map(|&x: &f32| x / (1.0_f32 + (-x).exp()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
inputs,
|
||||
expected_gelu,
|
||||
expected_silu,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Common test data for matrix multiplication tests
|
||||
pub struct MatmulTestData {
|
||||
/// 2x2 matrix A
|
||||
pub a_2x2: Vec<f32>,
|
||||
/// 2x2 matrix B
|
||||
pub b_2x2: Vec<f32>,
|
||||
/// Expected C = A * B (2x2)
|
||||
pub c_2x2: Vec<f32>,
|
||||
/// Identity matrix 2x2
|
||||
pub identity_2x2: Vec<f32>,
|
||||
}
|
||||
|
||||
impl Default for MatmulTestData {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
a_2x2: vec![1.0, 2.0, 3.0, 4.0],
|
||||
b_2x2: vec![5.0, 6.0, 7.0, 8.0],
|
||||
c_2x2: vec![19.0, 22.0, 43.0, 50.0], // A * B
|
||||
identity_2x2: vec![1.0, 0.0, 0.0, 1.0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests for the test utilities
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_simple_rng() {
|
||||
let mut rng1 = SimpleRng::new(42);
|
||||
let mut rng2 = SimpleRng::new(42);
|
||||
|
||||
// Same seed should produce same sequence
|
||||
for _ in 0..10 {
|
||||
assert_eq!(rng1.next_u64(), rng2.next_u64());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_tensor_uniform() {
|
||||
let tensor = random_tensor_uniform(100, 0.0, 1.0, 42);
|
||||
assert_eq!(tensor.len(), 100);
|
||||
assert!(tensor.iter().all(|&v| v >= 0.0 && v < 1.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_tensor_normal() {
|
||||
let tensor = random_tensor_normal(1000, 0.0, 1.0, 42);
|
||||
assert_eq!(tensor.len(), 1000);
|
||||
|
||||
// Check approximate mean (should be close to 0)
|
||||
let mean: f32 = tensor.iter().sum::<f32>() / tensor.len() as f32;
|
||||
assert!(mean.abs() < 0.2, "Mean should be close to 0, got {}", mean);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sequential_tensor() {
|
||||
let tensor = sequential_tensor(5, 0.0, 1.0);
|
||||
assert_eq!(tensor, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identity_matrix() {
|
||||
let identity = identity_matrix(3);
|
||||
assert_eq!(identity, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compare_tensors_equal() {
|
||||
let a = vec![1.0, 2.0, 3.0];
|
||||
let b = vec![1.0, 2.0, 3.0];
|
||||
let result = compare_tensors(&a, &b, &CompareConfig::default());
|
||||
assert!(result.equal);
|
||||
assert_eq!(result.num_diffs, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compare_tensors_within_tolerance() {
|
||||
let a = vec![1.0, 2.0, 3.0];
|
||||
let b = vec![1.00001, 2.00001, 3.00001];
|
||||
let result = compare_tensors(&a, &b, &CompareConfig::default());
|
||||
assert!(result.equal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compare_tensors_different() {
|
||||
let a = vec![1.0, 2.0, 3.0];
|
||||
let b = vec![1.0, 2.5, 3.0]; // Middle element differs
|
||||
let config = CompareConfig::strict();
|
||||
let result = compare_tensors(&a, &b, &config);
|
||||
assert!(!result.equal);
|
||||
assert!(result.num_diffs > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_finite() {
|
||||
assert!(all_finite(&[1.0, 2.0, 3.0]));
|
||||
assert!(!all_finite(&[1.0, f32::NAN, 3.0]));
|
||||
assert!(!all_finite(&[1.0, f32::INFINITY, 3.0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sums_to_one() {
|
||||
assert!(sums_to_one(&[0.25, 0.25, 0.25, 0.25], 1e-5));
|
||||
assert!(sums_to_one(&[0.1, 0.2, 0.3, 0.4], 1e-5));
|
||||
assert!(!sums_to_one(&[0.1, 0.2, 0.3], 1e-5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_benchmark() {
|
||||
let result = benchmark("test_add", 10, || {
|
||||
let _sum: i32 = (0..1000).sum();
|
||||
});
|
||||
assert_eq!(result.iterations, 10);
|
||||
assert!(result.avg_time > Duration::ZERO);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_config() {
|
||||
let config = TestModelConfig::tiny();
|
||||
assert_eq!(config.head_dim(), 16); // 32 / 2
|
||||
|
||||
let config = TestModelConfig::default();
|
||||
assert_eq!(config.head_dim(), 16); // 64 / 4
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weight_generator() {
|
||||
let mut gen = TestWeights::new(42);
|
||||
|
||||
let linear = gen.linear(64, 128);
|
||||
assert_eq!(linear.len(), 64 * 128);
|
||||
|
||||
let bias = gen.bias(128);
|
||||
assert_eq!(bias.len(), 128);
|
||||
|
||||
let ln_weight = gen.layer_norm_weight(64);
|
||||
assert!(ln_weight.iter().all(|&v| v == 1.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matmul_test_data() {
|
||||
let data = MatmulTestData::default();
|
||||
assert_eq!(data.a_2x2.len(), 4);
|
||||
assert_eq!(data.c_2x2, vec![19.0, 22.0, 43.0, 50.0]);
|
||||
}
|
||||
}
|
||||
692
crates/ruvllm/tests/autodetect_integration.rs
Normal file
692
crates/ruvllm/tests/autodetect_integration.rs
Normal file
@@ -0,0 +1,692 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Auto-Detection Integration Tests
|
||||
//!
|
||||
//! Tests the system capabilities detection, optimal configuration generation,
|
||||
//! and intelligent hardware-aware settings for LLM inference using the
|
||||
//! actual autodetect module.
|
||||
|
||||
use ruvllm::autodetect::{
|
||||
Architecture, ComputeBackend, CoreInfo, CpuFeatures, GpuBackend, GpuCapabilities,
|
||||
InferenceConfig, Platform, SystemCapabilities,
|
||||
};
|
||||
use ruvllm::backends::Quantization;
|
||||
use std::collections::HashSet;
|
||||
|
||||
// ============================================================================
|
||||
// System Detection Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_system_capabilities_detection() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
// Platform detection
|
||||
#[cfg(target_os = "macos")]
|
||||
assert_eq!(caps.platform, Platform::MacOS);
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
assert_eq!(caps.platform, Platform::Linux);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
assert_eq!(caps.platform, Platform::Windows);
|
||||
|
||||
// Architecture detection
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
assert_eq!(caps.arch, Architecture::Aarch64);
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
assert_eq!(caps.arch, Architecture::X86_64);
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
assert_eq!(caps.arch, Architecture::Wasm32);
|
||||
|
||||
// CPU features should have baseline set
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
assert!(
|
||||
caps.cpu_features.neon,
|
||||
"NEON should be available on aarch64"
|
||||
);
|
||||
|
||||
// Memory should be positive
|
||||
assert!(caps.memory_mb > 0, "Memory should be detected");
|
||||
|
||||
// Cores should be positive
|
||||
assert!(
|
||||
caps.cores.physical_cores > 0,
|
||||
"Physical cores should be detected"
|
||||
);
|
||||
assert!(
|
||||
caps.cores.logical_cores > 0,
|
||||
"Logical cores should be detected"
|
||||
);
|
||||
assert!(
|
||||
caps.cores.logical_cores >= caps.cores.physical_cores,
|
||||
"Logical cores should be >= physical cores"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_optimal_config_generation() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
let config = caps.optimal_config();
|
||||
|
||||
// Verify reasonable defaults
|
||||
assert!(config.batch_size >= 1, "Batch size should be at least 1");
|
||||
assert!(
|
||||
config.thread_count >= 1,
|
||||
"Thread count should be at least 1"
|
||||
);
|
||||
assert!(config.block_size >= 16, "Block size should be at least 16");
|
||||
|
||||
// Thread count should not exceed logical cores
|
||||
assert!(
|
||||
config.thread_count <= caps.cores.logical_cores,
|
||||
"Thread count {} should not exceed logical cores {}",
|
||||
config.thread_count,
|
||||
caps.cores.logical_cores
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantization_recommendation_small_model() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
// Small model (3GB) - should use FP16 or Q8 on most systems
|
||||
let q_small = caps.optimal_quantization(3.0);
|
||||
|
||||
if caps.memory_mb >= 16384 {
|
||||
// With 16GB+ RAM, FP16 or Q8 should be recommended
|
||||
assert!(
|
||||
matches!(q_small, Quantization::F16 | Quantization::Q8),
|
||||
"Small model with 16GB+ RAM should use F16 or Q8, got {:?}",
|
||||
q_small
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantization_recommendation_large_model() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
// Large model (70GB) - should use Q4K or Q4
|
||||
let q_large = caps.optimal_quantization(70.0);
|
||||
|
||||
// Unless you have 256GB+ RAM, this should be Q4K or Q4
|
||||
if caps.memory_mb < 256 * 1024 {
|
||||
assert!(
|
||||
matches!(
|
||||
q_large,
|
||||
Quantization::Q4K | Quantization::Q4 | Quantization::Q2K
|
||||
),
|
||||
"Large model should use aggressive quantization, got {:?}",
|
||||
q_large
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auto_config_matches_manual() {
|
||||
let auto = InferenceConfig::auto();
|
||||
let caps = SystemCapabilities::detect();
|
||||
let manual = caps.optimal_config();
|
||||
|
||||
// Auto should produce same result as manual
|
||||
assert_eq!(
|
||||
auto.batch_size, manual.batch_size,
|
||||
"Auto batch size should match manual"
|
||||
);
|
||||
assert_eq!(
|
||||
auto.thread_count, manual.thread_count,
|
||||
"Auto thread count should match manual"
|
||||
);
|
||||
assert_eq!(
|
||||
auto.block_size, manual.block_size,
|
||||
"Auto block size should match manual"
|
||||
);
|
||||
assert_eq!(
|
||||
auto.compute_backend, manual.compute_backend,
|
||||
"Auto compute backend should match manual"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_platform_specific_gpu_detection() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
{
|
||||
// Apple Silicon should detect Metal
|
||||
assert!(caps.gpu.is_some(), "Apple Silicon should have GPU");
|
||||
let gpu = caps.gpu.as_ref().unwrap();
|
||||
assert_eq!(gpu.backend, GpuBackend::Metal);
|
||||
}
|
||||
|
||||
#[cfg(all(target_os = "macos", target_arch = "x86_64"))]
|
||||
{
|
||||
// Intel Mac should detect Metal
|
||||
assert!(caps.gpu.is_some(), "Intel Mac should have GPU");
|
||||
let gpu = caps.gpu.as_ref().unwrap();
|
||||
assert_eq!(gpu.backend, GpuBackend::Metal);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cpu_feature_detection_aarch64() {
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
{
|
||||
let features = CpuFeatures::detect();
|
||||
|
||||
// NEON is mandatory on aarch64
|
||||
assert!(features.neon, "NEON must be available on aarch64");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cpu_feature_detection_x86_64() {
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
let features = CpuFeatures::detect();
|
||||
|
||||
// SSE4.2 should be common on modern x86_64
|
||||
// Note: This depends on compile-time detection or runtime check
|
||||
println!(
|
||||
"SSE4.2: {}, AVX2: {}, AVX-512: {}",
|
||||
features.sse42, features.avx2, features.avx512
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_detection() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
// Memory should be in reasonable range (256MB to 1TB)
|
||||
assert!(caps.memory_mb >= 256, "Memory should be at least 256MB");
|
||||
assert!(
|
||||
caps.memory_mb <= 1024 * 1024,
|
||||
"Memory should be at most 1TB"
|
||||
);
|
||||
|
||||
println!(
|
||||
"Detected memory: {} MB ({:.1} GB)",
|
||||
caps.memory_mb,
|
||||
caps.memory_mb as f64 / 1024.0
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_core_count_detection() {
|
||||
let cores = CoreInfo::detect();
|
||||
|
||||
// Physical cores should be reasonable
|
||||
assert!(
|
||||
cores.physical_cores >= 1,
|
||||
"Should have at least 1 physical core"
|
||||
);
|
||||
assert!(
|
||||
cores.physical_cores <= 256,
|
||||
"Should have at most 256 physical cores"
|
||||
);
|
||||
|
||||
// Logical cores should be >= physical
|
||||
assert!(
|
||||
cores.logical_cores >= cores.physical_cores,
|
||||
"Logical cores {} should >= physical cores {}",
|
||||
cores.logical_cores,
|
||||
cores.physical_cores
|
||||
);
|
||||
|
||||
println!(
|
||||
"Detected cores: {} physical, {} logical",
|
||||
cores.physical_cores, cores.logical_cores
|
||||
);
|
||||
|
||||
// Check heterogeneous cores on Apple Silicon
|
||||
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
|
||||
{
|
||||
if let (Some(perf), Some(eff)) = (cores.performance_cores, cores.efficiency_cores) {
|
||||
println!(" Performance cores: {}, Efficiency cores: {}", perf, eff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recommended_batch_size_scaling() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
// Test that batch size decreases with longer sequences
|
||||
let batch_512 = caps.recommended_batch_size(512);
|
||||
let batch_4096 = caps.recommended_batch_size(4096);
|
||||
let batch_16384 = caps.recommended_batch_size(16384);
|
||||
|
||||
assert!(
|
||||
batch_512 >= batch_4096,
|
||||
"Shorter sequences should allow larger batches"
|
||||
);
|
||||
assert!(
|
||||
batch_4096 >= batch_16384,
|
||||
"Medium sequences should allow larger batches than long ones"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inference_config_presets() {
|
||||
let auto = InferenceConfig::auto();
|
||||
let low_mem = InferenceConfig::low_memory();
|
||||
let high_throughput = InferenceConfig::high_throughput();
|
||||
let low_latency = InferenceConfig::low_latency();
|
||||
|
||||
// Low memory should use aggressive quantization
|
||||
assert!(
|
||||
matches!(
|
||||
low_mem.quantization,
|
||||
Quantization::Q4 | Quantization::Q4K | Quantization::Q2K
|
||||
),
|
||||
"Low memory config should use aggressive quantization"
|
||||
);
|
||||
assert_eq!(low_mem.batch_size, 1, "Low memory should use batch size 1");
|
||||
|
||||
// Low latency should use batch size 1
|
||||
assert_eq!(
|
||||
low_latency.batch_size, 1,
|
||||
"Low latency should use batch size 1"
|
||||
);
|
||||
|
||||
// All configs should have flash attention enabled
|
||||
assert!(auto.use_flash_attention);
|
||||
assert!(low_mem.use_flash_attention);
|
||||
assert!(high_throughput.use_flash_attention);
|
||||
assert!(low_latency.use_flash_attention);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_backend_selection() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
let config = caps.optimal_config();
|
||||
|
||||
// On macOS with GPU, should select Metal
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
if caps.gpu.is_some() {
|
||||
assert_eq!(
|
||||
config.compute_backend,
|
||||
ComputeBackend::Metal,
|
||||
"Should select Metal on macOS with GPU"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// On aarch64 without GPU, should select NEON
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
{
|
||||
if caps.gpu.is_none() {
|
||||
assert_eq!(
|
||||
config.compute_backend,
|
||||
ComputeBackend::CpuNeon,
|
||||
"Should select NEON on aarch64 without GPU"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify GPU backends are detected as GPU
|
||||
assert!(ComputeBackend::Metal.is_gpu());
|
||||
assert!(ComputeBackend::Cuda.is_gpu());
|
||||
assert!(ComputeBackend::WebGPU.is_gpu());
|
||||
assert!(!ComputeBackend::CpuNeon.is_gpu());
|
||||
assert!(!ComputeBackend::CpuAvx2.is_gpu());
|
||||
assert!(!ComputeBackend::CpuScalar.is_gpu());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_summary() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
let summary = caps.summary();
|
||||
|
||||
println!("System Summary: {}", summary);
|
||||
|
||||
// Summary should contain useful information
|
||||
assert!(!summary.is_empty(), "Summary should not be empty");
|
||||
assert!(
|
||||
summary.contains("cores") || summary.contains("RAM"),
|
||||
"Summary should contain cores or RAM info"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_can_run_model() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
// Should be able to run a tiny model
|
||||
assert!(caps.can_run_model(0.1), "Should be able to run 100MB model");
|
||||
|
||||
// Likely can't run a 1TB model
|
||||
assert!(
|
||||
!caps.can_run_model(1000.0),
|
||||
"Should not be able to run 1TB model"
|
||||
);
|
||||
|
||||
// Test boundary conditions
|
||||
// Note: can_run_model uses available_memory_mb which defaults to memory_mb / 2
|
||||
let available_gb = caps.available_memory_mb.unwrap_or(caps.memory_mb / 2) as f32 / 1024.0;
|
||||
let max_model = (available_gb - 2.0) / 0.4; // Reverse the formula from can_run_model
|
||||
|
||||
if max_model > 0.0 {
|
||||
// Should be able to run a model slightly smaller than max
|
||||
assert!(
|
||||
caps.can_run_model(max_model * 0.8),
|
||||
"Should be able to run model at 80% of max"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimated_tokens_per_second() {
|
||||
let auto = InferenceConfig::auto();
|
||||
let tps = auto.estimated_tokens_per_second();
|
||||
|
||||
assert!(tps > 0.0, "Estimated tokens per second should be positive");
|
||||
|
||||
// Metal and CUDA should have higher estimates than CPU
|
||||
let metal_tps = {
|
||||
let mut config = auto.clone();
|
||||
config.compute_backend = ComputeBackend::Metal;
|
||||
config.estimated_tokens_per_second()
|
||||
};
|
||||
|
||||
let cpu_tps = {
|
||||
let mut config = auto.clone();
|
||||
config.compute_backend = ComputeBackend::CpuScalar;
|
||||
config.estimated_tokens_per_second()
|
||||
};
|
||||
|
||||
assert!(
|
||||
metal_tps > cpu_tps,
|
||||
"Metal should have higher estimated TPS than CPU scalar"
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hardware Fingerprinting Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_hardware_fingerprint_stability() {
|
||||
// Run detection multiple times and verify consistency
|
||||
let cap1 = SystemCapabilities::detect();
|
||||
let cap2 = SystemCapabilities::detect();
|
||||
|
||||
assert_eq!(cap1.platform, cap2.platform);
|
||||
assert_eq!(cap1.arch, cap2.arch);
|
||||
assert_eq!(cap1.cores.logical_cores, cap2.cores.logical_cores);
|
||||
assert_eq!(cap1.cpu_features.neon, cap2.cpu_features.neon);
|
||||
|
||||
// Memory may vary slightly due to system activity, but should be close
|
||||
let mem_diff = (cap1.memory_mb as i64 - cap2.memory_mb as i64).abs();
|
||||
assert!(mem_diff < 100, "Memory detection should be stable");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_supported_platforms() {
|
||||
// Verify all platform variants are distinct
|
||||
let platforms = vec![
|
||||
Platform::MacOS,
|
||||
Platform::Linux,
|
||||
Platform::Windows,
|
||||
Platform::Wasm,
|
||||
Platform::IOS,
|
||||
Platform::Android,
|
||||
Platform::Unknown,
|
||||
];
|
||||
|
||||
let unique: HashSet<_> = platforms.iter().collect();
|
||||
assert_eq!(unique.len(), 7, "All platform variants should be distinct");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_architecture_variants() {
|
||||
let archs = vec![
|
||||
Architecture::Aarch64,
|
||||
Architecture::X86_64,
|
||||
Architecture::Wasm32,
|
||||
Architecture::Unknown,
|
||||
];
|
||||
|
||||
let unique: HashSet<_> = archs.iter().collect();
|
||||
assert_eq!(
|
||||
unique.len(),
|
||||
4,
|
||||
"All architecture variants should be distinct"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_gpu_backend_variants() {
|
||||
let backends = vec![
|
||||
GpuBackend::Metal,
|
||||
GpuBackend::Cuda,
|
||||
GpuBackend::WebGPU,
|
||||
GpuBackend::Vulkan,
|
||||
GpuBackend::OpenCL,
|
||||
];
|
||||
|
||||
let unique: HashSet<_> = backends.iter().collect();
|
||||
assert_eq!(
|
||||
unique.len(),
|
||||
5,
|
||||
"All GPU backend variants should be distinct"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_compute_backend_variants() {
|
||||
let backends = vec![
|
||||
ComputeBackend::Metal,
|
||||
ComputeBackend::Cuda,
|
||||
ComputeBackend::WebGPU,
|
||||
ComputeBackend::CpuAvx512,
|
||||
ComputeBackend::CpuAvx2,
|
||||
ComputeBackend::CpuNeon,
|
||||
ComputeBackend::CpuScalar,
|
||||
];
|
||||
|
||||
let unique: HashSet<_> = backends.iter().collect();
|
||||
assert_eq!(
|
||||
unique.len(),
|
||||
7,
|
||||
"All compute backend variants should be distinct"
|
||||
);
|
||||
|
||||
// Verify relative performance ordering
|
||||
assert!(
|
||||
ComputeBackend::Cuda.relative_performance() > ComputeBackend::Metal.relative_performance()
|
||||
);
|
||||
assert!(
|
||||
ComputeBackend::Metal.relative_performance()
|
||||
> ComputeBackend::CpuAvx512.relative_performance()
|
||||
);
|
||||
assert!(
|
||||
ComputeBackend::CpuAvx512.relative_performance()
|
||||
> ComputeBackend::CpuAvx2.relative_performance()
|
||||
);
|
||||
assert!(
|
||||
ComputeBackend::CpuAvx2.relative_performance()
|
||||
>= ComputeBackend::CpuNeon.relative_performance()
|
||||
);
|
||||
assert!(
|
||||
ComputeBackend::CpuNeon.relative_performance()
|
||||
> ComputeBackend::CpuScalar.relative_performance()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gpu_can_fit_model() {
|
||||
// Test with a synthetic GPU
|
||||
let gpu = GpuCapabilities {
|
||||
backend: GpuBackend::Metal,
|
||||
vram_mb: Some(16 * 1024), // 16GB
|
||||
compute_units: Some(128),
|
||||
name: Some("Test GPU".to_string()),
|
||||
supports_fp16: true,
|
||||
supports_int8: true,
|
||||
has_tensor_cores: true,
|
||||
max_shared_memory: Some(32 * 1024),
|
||||
};
|
||||
|
||||
// 16GB should fit 7B model (needs ~10GB with overhead)
|
||||
assert!(gpu.can_fit_model(7.0), "16GB VRAM should fit 7B model");
|
||||
|
||||
// 16GB should not fit 70B model (needs ~100GB)
|
||||
assert!(
|
||||
!gpu.can_fit_model(70.0),
|
||||
"16GB VRAM should not fit 70B model"
|
||||
);
|
||||
|
||||
// Edge case: unknown VRAM
|
||||
let gpu_unknown = GpuCapabilities {
|
||||
backend: GpuBackend::Metal,
|
||||
vram_mb: None,
|
||||
compute_units: None,
|
||||
name: Some("Unknown GPU".to_string()),
|
||||
supports_fp16: true,
|
||||
supports_int8: true,
|
||||
has_tensor_cores: false,
|
||||
max_shared_memory: None,
|
||||
};
|
||||
|
||||
// Unknown VRAM should assume it can fit (optimistic)
|
||||
assert!(
|
||||
gpu_unknown.can_fit_model(7.0),
|
||||
"Unknown VRAM should optimistically assume model fits"
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// System Capabilities Display Test
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_system_capabilities_display() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
|
||||
println!("\n=== System Capabilities ===");
|
||||
println!("Platform: {:?}", caps.platform);
|
||||
println!("Architecture: {:?}", caps.arch);
|
||||
println!(
|
||||
"Memory: {} MB ({:.1} GB)",
|
||||
caps.memory_mb,
|
||||
caps.memory_mb as f64 / 1024.0
|
||||
);
|
||||
println!(
|
||||
"Cores: {} physical, {} logical",
|
||||
caps.cores.physical_cores, caps.cores.logical_cores
|
||||
);
|
||||
|
||||
if let Some(ref gpu) = caps.gpu {
|
||||
println!("GPU: {:?} - {:?}", gpu.backend, gpu.name);
|
||||
if let Some(vram) = gpu.vram_mb {
|
||||
println!(" VRAM: {} MB", vram);
|
||||
}
|
||||
println!(
|
||||
" FP16: {}, INT8: {}, Tensor Cores: {}",
|
||||
gpu.supports_fp16, gpu.supports_int8, gpu.has_tensor_cores
|
||||
);
|
||||
} else {
|
||||
println!("GPU: None");
|
||||
}
|
||||
|
||||
println!("\nCPU Features:");
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
println!(" NEON: {}", caps.cpu_features.neon);
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
println!(" SSE4.2: {}", caps.cpu_features.sse42);
|
||||
println!(" AVX2: {}", caps.cpu_features.avx2);
|
||||
println!(" AVX-512: {}", caps.cpu_features.avx512);
|
||||
}
|
||||
|
||||
println!(
|
||||
" Best SIMD width: {} bits",
|
||||
caps.cpu_features.best_simd_width()
|
||||
);
|
||||
println!(
|
||||
" SIMD float lanes: {}",
|
||||
caps.cpu_features.simd_float_lanes()
|
||||
);
|
||||
|
||||
let config = caps.optimal_config();
|
||||
println!("\n=== Optimal Configuration ===");
|
||||
println!("Compute Backend: {:?}", config.compute_backend);
|
||||
println!("Quantization: {:?}", config.quantization);
|
||||
println!("Batch Size: {}", config.batch_size);
|
||||
println!("Thread Count: {}", config.thread_count);
|
||||
println!("Block Size: {}", config.block_size);
|
||||
println!("Flash Attention: {}", config.use_flash_attention);
|
||||
println!("Device Type: {:?}", config.device_type);
|
||||
println!("DType: {:?}", config.dtype);
|
||||
println!("Estimated TPS: {:.1}", config.estimated_tokens_per_second());
|
||||
|
||||
println!("\n=== Summary ===");
|
||||
println!("{}", caps.summary());
|
||||
|
||||
// Test passes if we get here without panicking
|
||||
assert!(true);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Attention Config Integration
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_optimal_attention_config() {
|
||||
let caps = SystemCapabilities::detect();
|
||||
let attn_config = caps.optimal_attention_config();
|
||||
|
||||
// Verify reasonable attention configuration
|
||||
assert!(attn_config.num_heads > 0, "Should have at least 1 head");
|
||||
assert!(
|
||||
attn_config.num_kv_heads > 0,
|
||||
"Should have at least 1 KV head"
|
||||
);
|
||||
assert!(attn_config.head_dim > 0, "Should have positive head dim");
|
||||
assert!(
|
||||
attn_config.max_seq_len >= 1024,
|
||||
"Should support at least 1K context"
|
||||
);
|
||||
|
||||
// GQA ratio should be valid
|
||||
let gqa_ratio = attn_config.gqa_ratio();
|
||||
assert!(gqa_ratio >= 1, "GQA ratio should be at least 1");
|
||||
assert!(
|
||||
attn_config.num_heads % attn_config.num_kv_heads == 0,
|
||||
"num_heads should be divisible by num_kv_heads"
|
||||
);
|
||||
|
||||
// Scale should be reasonable
|
||||
let scale = attn_config.effective_scale();
|
||||
assert!(
|
||||
scale > 0.0 && scale < 1.0,
|
||||
"Scale should be between 0 and 1"
|
||||
);
|
||||
|
||||
println!(
|
||||
"Attention Config: {} heads, {} KV heads, {} head_dim, {} max_seq_len, GQA {}:1",
|
||||
attn_config.num_heads,
|
||||
attn_config.num_kv_heads,
|
||||
attn_config.head_dim,
|
||||
attn_config.max_seq_len,
|
||||
gqa_ratio
|
||||
);
|
||||
}
|
||||
649
crates/ruvllm/tests/backend_integration.rs
Normal file
649
crates/ruvllm/tests/backend_integration.rs
Normal file
@@ -0,0 +1,649 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Integration tests for LLM backends
|
||||
//!
|
||||
//! Tests the LLM backend infrastructure including model loading,
|
||||
//! text generation, streaming, and embeddings extraction.
|
||||
|
||||
use ruvllm::{
|
||||
backends::{
|
||||
create_backend, DType, DeviceType, GenerateParams, LlmBackend, ModelArchitecture,
|
||||
ModelConfig, ModelInfo, Quantization, SpecialTokens, TokenStream, Tokenizer,
|
||||
},
|
||||
error::Result,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Mock backend for testing without requiring actual model files
|
||||
#[derive(Debug)]
|
||||
struct MockBackend {
|
||||
model_info: Option<ModelInfo>,
|
||||
loaded: bool,
|
||||
}
|
||||
|
||||
impl MockBackend {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
model_info: None,
|
||||
loaded: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmBackend for MockBackend {
|
||||
fn load_model(&mut self, model_id: &str, config: ModelConfig) -> Result<()> {
|
||||
self.model_info = Some(ModelInfo {
|
||||
name: model_id.to_string(),
|
||||
architecture: config.architecture,
|
||||
num_parameters: 100_000,
|
||||
vocab_size: 32000,
|
||||
hidden_size: 768,
|
||||
num_layers: 12,
|
||||
max_context_length: config.max_sequence_length,
|
||||
quantization: config.quantization,
|
||||
memory_usage: 1024 * 1024 * 100, // 100MB
|
||||
});
|
||||
self.loaded = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn generate(&self, prompt: &str, _params: GenerateParams) -> Result<String> {
|
||||
if !self.loaded {
|
||||
return Err(ruvllm::RuvLLMError::Backend("Model not loaded".to_string()));
|
||||
}
|
||||
Ok(format!("Response to: {}", prompt))
|
||||
}
|
||||
|
||||
fn generate_stream(
|
||||
&self,
|
||||
_prompt: &str,
|
||||
_params: GenerateParams,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<ruvllm::backends::GeneratedToken>> + Send + '_>>
|
||||
{
|
||||
if !self.loaded {
|
||||
return Err(ruvllm::RuvLLMError::Backend("Model not loaded".to_string()));
|
||||
}
|
||||
|
||||
let tokens = vec![
|
||||
ruvllm::backends::GeneratedToken {
|
||||
id: 1,
|
||||
text: "Hello".to_string(),
|
||||
logprob: Some(-0.5),
|
||||
is_special: false,
|
||||
},
|
||||
ruvllm::backends::GeneratedToken {
|
||||
id: 2,
|
||||
text: " world".to_string(),
|
||||
logprob: Some(-0.3),
|
||||
is_special: false,
|
||||
},
|
||||
ruvllm::backends::GeneratedToken {
|
||||
id: 3,
|
||||
text: "!".to_string(),
|
||||
logprob: Some(-0.1),
|
||||
is_special: false,
|
||||
},
|
||||
];
|
||||
|
||||
Ok(Box::new(tokens.into_iter().map(Ok)))
|
||||
}
|
||||
|
||||
fn generate_stream_v2(&self, _prompt: &str, _params: GenerateParams) -> Result<TokenStream> {
|
||||
if !self.loaded {
|
||||
return Err(ruvllm::RuvLLMError::Backend("Model not loaded".to_string()));
|
||||
}
|
||||
// Return a mock stream using channel
|
||||
let (tx, stream) = TokenStream::channel();
|
||||
// Drop tx immediately since we don't need to send anything for this mock
|
||||
drop(tx);
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn get_embeddings(&self, _text: &str) -> Result<Vec<f32>> {
|
||||
if !self.loaded {
|
||||
return Err(ruvllm::RuvLLMError::Backend("Model not loaded".to_string()));
|
||||
}
|
||||
// Return a mock embedding
|
||||
Ok(vec![0.1; 768])
|
||||
}
|
||||
|
||||
fn tokenizer(&self) -> Option<&dyn Tokenizer> {
|
||||
None
|
||||
}
|
||||
|
||||
fn is_model_loaded(&self) -> bool {
|
||||
self.loaded
|
||||
}
|
||||
|
||||
fn model_info(&self) -> Option<ModelInfo> {
|
||||
self.model_info.clone()
|
||||
}
|
||||
|
||||
fn unload_model(&mut self) {
|
||||
self.loaded = false;
|
||||
self.model_info = None;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mock_backend_load_model() {
|
||||
let mut backend = MockBackend::new();
|
||||
|
||||
// Initially not loaded
|
||||
assert!(!backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_none());
|
||||
|
||||
// Load model
|
||||
let config = ModelConfig::default();
|
||||
let result = backend.load_model("test-model", config);
|
||||
assert!(result.is_ok());
|
||||
assert!(backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_generate_basic() {
|
||||
let mut backend = MockBackend::new();
|
||||
backend
|
||||
.load_model("test-model", ModelConfig::default())
|
||||
.unwrap();
|
||||
|
||||
let params = GenerateParams {
|
||||
max_tokens: 100,
|
||||
temperature: 0.7,
|
||||
top_p: 0.9,
|
||||
top_k: 40,
|
||||
repetition_penalty: 1.1,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop_sequences: vec![],
|
||||
seed: Some(42),
|
||||
};
|
||||
|
||||
let result = backend.generate("Hello, how are you?", params);
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(!output.is_empty());
|
||||
assert!(output.contains("Hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_generate_requires_loaded_model() {
|
||||
let backend = MockBackend::new();
|
||||
|
||||
let params = GenerateParams::default();
|
||||
let result = backend.generate("Test prompt", params);
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_streaming() {
|
||||
let mut backend = MockBackend::new();
|
||||
backend
|
||||
.load_model("test-model", ModelConfig::default())
|
||||
.unwrap();
|
||||
|
||||
let params = GenerateParams::default();
|
||||
let stream = backend.generate_stream("Hello", params).unwrap();
|
||||
|
||||
let tokens: Vec<_> = stream.collect();
|
||||
assert_eq!(tokens.len(), 3);
|
||||
|
||||
let first = tokens[0].as_ref().unwrap();
|
||||
assert_eq!(first.text, "Hello");
|
||||
assert_eq!(first.id, 1);
|
||||
assert!(!first.is_special);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_embeddings() {
|
||||
let mut backend = MockBackend::new();
|
||||
backend
|
||||
.load_model("test-model", ModelConfig::default())
|
||||
.unwrap();
|
||||
|
||||
let embedding = backend.get_embeddings("Test text for embedding").unwrap();
|
||||
|
||||
assert_eq!(embedding.len(), 768);
|
||||
assert!(embedding.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_model_info() {
|
||||
let mut backend = MockBackend::new();
|
||||
|
||||
let config = ModelConfig {
|
||||
architecture: ModelArchitecture::Llama,
|
||||
max_sequence_length: 4096,
|
||||
quantization: Some(Quantization::Q4K),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
backend.load_model("llama-test", config).unwrap();
|
||||
let info = backend.model_info().unwrap();
|
||||
|
||||
assert_eq!(info.name, "llama-test");
|
||||
assert_eq!(info.max_context_length, 4096);
|
||||
assert!(matches!(info.architecture, ModelArchitecture::Llama));
|
||||
assert!(matches!(info.quantization, Some(Quantization::Q4K)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_unload() {
|
||||
let mut backend = MockBackend::new();
|
||||
backend
|
||||
.load_model("test-model", ModelConfig::default())
|
||||
.unwrap();
|
||||
assert!(backend.is_model_loaded());
|
||||
|
||||
backend.unload_model();
|
||||
assert!(!backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_none());
|
||||
|
||||
// Should fail after unload
|
||||
let result = backend.generate("Test", GenerateParams::default());
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_config() {
|
||||
let config = ModelConfig {
|
||||
architecture: ModelArchitecture::Mistral,
|
||||
device: DeviceType::Cpu,
|
||||
dtype: DType::F32,
|
||||
quantization: Some(Quantization::Q4K),
|
||||
use_flash_attention: true,
|
||||
max_sequence_length: 4096,
|
||||
num_kv_heads: Some(8),
|
||||
hidden_size: Some(4096),
|
||||
num_layers: Some(32),
|
||||
vocab_size: Some(32000),
|
||||
rope_theta: Some(10000.0),
|
||||
sliding_window: None,
|
||||
};
|
||||
|
||||
assert!(matches!(config.device, DeviceType::Cpu));
|
||||
assert!(matches!(config.dtype, DType::F32));
|
||||
assert!(matches!(config.quantization, Some(Quantization::Q4K)));
|
||||
assert!(config.use_flash_attention);
|
||||
assert_eq!(config.max_sequence_length, 4096);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_params_default() {
|
||||
let params = GenerateParams::default();
|
||||
|
||||
assert!(params.max_tokens > 0);
|
||||
assert!(params.temperature > 0.0);
|
||||
assert!(params.top_p <= 1.0);
|
||||
assert!(params.top_k > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_params_builder() {
|
||||
let params = GenerateParams::default()
|
||||
.with_max_tokens(512)
|
||||
.with_temperature(0.5)
|
||||
.with_top_p(0.95)
|
||||
.with_top_k(50)
|
||||
.with_repetition_penalty(1.2)
|
||||
.with_seed(42);
|
||||
|
||||
assert_eq!(params.max_tokens, 512);
|
||||
assert_eq!(params.temperature, 0.5);
|
||||
assert_eq!(params.top_p, 0.95);
|
||||
assert_eq!(params.top_k, 50);
|
||||
assert_eq!(params.repetition_penalty, 1.2);
|
||||
assert_eq!(params.seed, Some(42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantization_variants() {
|
||||
let q4 = Quantization::Q4;
|
||||
let q8 = Quantization::Q8;
|
||||
let q4k = Quantization::Q4K;
|
||||
let f16 = Quantization::F16;
|
||||
|
||||
assert!(q4.is_gguf());
|
||||
assert!(q8.is_gguf());
|
||||
assert!(q4k.is_gguf());
|
||||
assert!(!f16.is_gguf());
|
||||
|
||||
// Check bytes per weight
|
||||
assert_eq!(Quantization::None.bytes_per_weight(), 4.0);
|
||||
assert_eq!(Quantization::F16.bytes_per_weight(), 2.0);
|
||||
assert_eq!(Quantization::Q8.bytes_per_weight(), 1.0);
|
||||
assert_eq!(Quantization::Q4K.bytes_per_weight(), 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_device_type_variants() {
|
||||
let cpu = DeviceType::Cpu;
|
||||
let metal = DeviceType::Metal;
|
||||
let cuda = DeviceType::Cuda(0);
|
||||
|
||||
assert!(matches!(cpu, DeviceType::Cpu));
|
||||
assert!(matches!(metal, DeviceType::Metal));
|
||||
if let DeviceType::Cuda(idx) = cuda {
|
||||
assert_eq!(idx, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_architecture_variants() {
|
||||
let llama = ModelArchitecture::Llama;
|
||||
let mistral = ModelArchitecture::Mistral;
|
||||
let phi = ModelArchitecture::Phi;
|
||||
let qwen = ModelArchitecture::Qwen;
|
||||
let gemma = ModelArchitecture::Gemma;
|
||||
|
||||
assert_eq!(llama.config_name(), "llama");
|
||||
assert_eq!(mistral.config_name(), "mistral");
|
||||
assert_eq!(phi.config_name(), "phi");
|
||||
assert_eq!(qwen.config_name(), "qwen2");
|
||||
assert_eq!(gemma.config_name(), "gemma");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_variants() {
|
||||
let f32_type = DType::F32;
|
||||
let f16_type = DType::F16;
|
||||
let bf16_type = DType::Bf16;
|
||||
|
||||
assert!(matches!(f32_type, DType::F32));
|
||||
assert!(matches!(f16_type, DType::F16));
|
||||
assert!(matches!(bf16_type, DType::Bf16));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_special_tokens() {
|
||||
let tokens = SpecialTokens {
|
||||
bos_token_id: Some(1),
|
||||
eos_token_id: Some(2),
|
||||
pad_token_id: Some(0),
|
||||
unk_token_id: Some(3),
|
||||
};
|
||||
|
||||
assert_eq!(tokens.bos_token_id, Some(1));
|
||||
assert_eq!(tokens.eos_token_id, Some(2));
|
||||
assert_eq!(tokens.pad_token_id, Some(0));
|
||||
assert_eq!(tokens.unk_token_id, Some(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_backend() {
|
||||
// This creates a NoopBackend when candle feature is not enabled
|
||||
let backend = create_backend();
|
||||
|
||||
// Without the candle feature, the backend should not be able to load models
|
||||
#[cfg(not(feature = "candle"))]
|
||||
{
|
||||
assert!(!backend.is_model_loaded());
|
||||
}
|
||||
}
|
||||
|
||||
// Candle backend tests (only run when the feature is enabled)
|
||||
#[cfg(feature = "candle")]
|
||||
mod candle_tests {
|
||||
use super::*;
|
||||
use ruvllm::backends::CandleBackend;
|
||||
|
||||
#[test]
|
||||
#[ignore] // Requires model download
|
||||
fn test_candle_backend_creation() {
|
||||
let backend = CandleBackend::new();
|
||||
assert!(backend.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // Requires model download
|
||||
fn test_candle_backend_load_model() {
|
||||
let mut backend = CandleBackend::new().unwrap();
|
||||
let config = ModelConfig {
|
||||
architecture: ModelArchitecture::Phi,
|
||||
device: DeviceType::Cpu,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// This would require an actual model file
|
||||
// let result = backend.load_model("microsoft/phi-2", config);
|
||||
// assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
// ========== V2 Feature Tests: Memory Pool Integration ==========
|
||||
|
||||
mod memory_pool_tests {
|
||||
use ruvllm::memory_pool::{
|
||||
BufferPool, BufferSize, InferenceArena, MemoryManager, MemoryManagerConfig,
|
||||
ScratchSpaceManager,
|
||||
};
|
||||
|
||||
/// Test memory pool integration with streaming generation
|
||||
#[test]
|
||||
fn test_memory_pool_integration() {
|
||||
let pool = BufferPool::new();
|
||||
|
||||
// Pre-warm the pool
|
||||
pool.prewarm_all(4).expect("prewarm failed");
|
||||
|
||||
// Simulate multiple generation steps
|
||||
for step in 0..10 {
|
||||
// Acquire buffers for KV cache
|
||||
let kv_buffer = pool.acquire(BufferSize::KB64).expect("acquire failed");
|
||||
assert_eq!(kv_buffer.capacity(), 65536);
|
||||
|
||||
// Simulate processing
|
||||
let data = kv_buffer.as_slice::<f32>();
|
||||
assert!(!data.is_empty());
|
||||
|
||||
// Buffer returns to pool when dropped
|
||||
}
|
||||
|
||||
// Check pool statistics
|
||||
let stats = pool.stats();
|
||||
assert!(stats.hits + stats.misses > 0, "Pool should have been used");
|
||||
|
||||
// Hit rate should be decent after warm-up
|
||||
if stats.hits + stats.misses >= 10 {
|
||||
assert!(
|
||||
stats.hit_rate > 0.5,
|
||||
"Pool hit rate should be decent: {:.2}",
|
||||
stats.hit_rate
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test streaming with memory pool
|
||||
#[test]
|
||||
fn test_streaming_with_pool() {
|
||||
let manager = MemoryManager::new().expect("manager creation failed");
|
||||
|
||||
// Simulate streaming generation
|
||||
for token_idx in 0..100 {
|
||||
// Reset arena at start of each step
|
||||
manager.reset_step();
|
||||
|
||||
// Allocate temporary buffers from arena
|
||||
let activations: &mut [f32] = manager.arena.alloc(1024).expect("arena alloc failed");
|
||||
activations[0] = token_idx as f32;
|
||||
|
||||
let logits: &mut [f32] = manager.arena.alloc(32000).expect("arena alloc for logits");
|
||||
logits[0] = token_idx as f32 * 0.1;
|
||||
|
||||
// Acquire KV cache buffer from pool
|
||||
let kv_buf = manager
|
||||
.pool
|
||||
.acquire(BufferSize::KB16)
|
||||
.expect("acquire failed");
|
||||
assert!(kv_buf.capacity() >= 16384);
|
||||
|
||||
// Use scratch space for intermediate computations
|
||||
let mut scratch = manager.scratch.get_scratch().expect("get_scratch failed");
|
||||
if let Some(temp) = scratch.get::<f32>(256) {
|
||||
temp.fill(1.0);
|
||||
assert_eq!(temp.len(), 256);
|
||||
}
|
||||
|
||||
// Verify arena usage grows
|
||||
assert!(manager.arena.used() > 0);
|
||||
}
|
||||
|
||||
// Verify final statistics
|
||||
let stats = manager.stats();
|
||||
assert!(stats.pool.hits + stats.pool.misses > 0);
|
||||
assert!(stats.arena.high_water_mark > 0);
|
||||
}
|
||||
|
||||
/// Test arena allocation and reset cycle
|
||||
#[test]
|
||||
fn test_arena_allocation_cycle() {
|
||||
let arena = InferenceArena::new(4 * 1024 * 1024).expect("arena creation failed"); // 4MB
|
||||
|
||||
for cycle in 0..50 {
|
||||
// Allocate various buffer sizes
|
||||
let buf1: &mut [f32] = arena.alloc(4096).expect("alloc 4096");
|
||||
let buf2: &mut [f32] = arena.alloc(8192).expect("alloc 8192");
|
||||
let buf3: &mut [f32] = arena.alloc(1024).expect("alloc 1024");
|
||||
|
||||
// Write to buffers
|
||||
buf1[0] = cycle as f32;
|
||||
buf2[0] = cycle as f32 * 2.0;
|
||||
buf3[0] = cycle as f32 * 3.0;
|
||||
|
||||
// Verify allocations
|
||||
assert_eq!(arena.allocation_count(), 3);
|
||||
assert!(arena.used() > 0);
|
||||
|
||||
// Reset for next cycle
|
||||
arena.reset();
|
||||
assert_eq!(arena.used(), 0);
|
||||
assert_eq!(arena.allocation_count(), 0);
|
||||
}
|
||||
|
||||
// High water mark should be set
|
||||
assert!(arena.high_water_mark() > 0);
|
||||
}
|
||||
|
||||
/// Test buffer pool reuse efficiency
|
||||
#[test]
|
||||
fn test_buffer_pool_reuse() {
|
||||
let pool = BufferPool::with_capacity(8);
|
||||
|
||||
// Acquire and release same size multiple times
|
||||
for _ in 0..20 {
|
||||
let buf = pool.acquire(BufferSize::KB4).expect("acquire failed");
|
||||
assert_eq!(buf.capacity(), 4096);
|
||||
// Buffer returns to pool on drop
|
||||
}
|
||||
|
||||
let stats = pool.stats();
|
||||
// After first allocation, subsequent ones should hit the pool
|
||||
assert!(
|
||||
stats.hits >= 19,
|
||||
"Expected at least 19 hits, got {}",
|
||||
stats.hits
|
||||
);
|
||||
}
|
||||
|
||||
/// Test scratch space thread isolation
|
||||
#[test]
|
||||
fn test_scratch_space_isolation() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let manager = Arc::new(ScratchSpaceManager::new(8192, 8).expect("manager creation failed"));
|
||||
|
||||
let handles: Vec<_> = (0..4)
|
||||
.map(|thread_id| {
|
||||
let manager = Arc::clone(&manager);
|
||||
thread::spawn(move || {
|
||||
for _ in 0..10 {
|
||||
let mut scratch = manager.get_scratch().expect("get_scratch failed");
|
||||
|
||||
// Each thread writes its ID
|
||||
if let Some(buf) = scratch.get::<u32>(100) {
|
||||
buf.fill(thread_id);
|
||||
// Verify no cross-thread contamination
|
||||
assert!(buf.iter().all(|&v| v == thread_id));
|
||||
}
|
||||
|
||||
scratch.reset();
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for handle in handles {
|
||||
handle.join().expect("Thread panicked");
|
||||
}
|
||||
|
||||
// Verify 4 threads were tracked
|
||||
assert_eq!(manager.active_threads(), 4);
|
||||
}
|
||||
|
||||
/// Test memory manager configuration for model
|
||||
#[test]
|
||||
fn test_memory_manager_for_model() {
|
||||
// Configure for a small LLM (e.g., Phi-2)
|
||||
let config = MemoryManagerConfig::for_model(
|
||||
2560, // hidden_dim
|
||||
51200, // vocab_size
|
||||
1, // batch_size
|
||||
);
|
||||
|
||||
let manager = MemoryManager::with_config(config).expect("manager creation failed");
|
||||
|
||||
// Verify adequate capacity
|
||||
assert!(manager.arena.capacity() > 2560 * 4 * 4); // At least hidden_dim * 4 * sizeof(f32)
|
||||
|
||||
// Simulate inference
|
||||
let activations: &mut [f32] = manager.arena.alloc(2560).expect("alloc activations");
|
||||
let logits: &mut [f32] = manager.arena.alloc(51200).expect("alloc logits");
|
||||
|
||||
assert_eq!(activations.len(), 2560);
|
||||
assert_eq!(logits.len(), 51200);
|
||||
|
||||
// Reset for next step
|
||||
manager.reset_step();
|
||||
assert_eq!(manager.arena.used(), 0);
|
||||
}
|
||||
|
||||
/// Test buffer size class selection
|
||||
#[test]
|
||||
fn test_buffer_size_selection() {
|
||||
let pool = BufferPool::new();
|
||||
|
||||
// Test automatic size class selection
|
||||
if let Some(buf) = pool.acquire_for_size(500).ok().flatten() {
|
||||
assert!(buf.capacity() >= 500);
|
||||
assert_eq!(buf.size_class(), BufferSize::KB1);
|
||||
}
|
||||
|
||||
if let Some(buf) = pool.acquire_for_size(3000).ok().flatten() {
|
||||
assert!(buf.capacity() >= 3000);
|
||||
assert_eq!(buf.size_class(), BufferSize::KB4);
|
||||
}
|
||||
|
||||
if let Some(buf) = pool.acquire_for_size(100000).ok().flatten() {
|
||||
assert!(buf.capacity() >= 100000);
|
||||
assert_eq!(buf.size_class(), BufferSize::KB256);
|
||||
}
|
||||
|
||||
// Size too large should return None
|
||||
let too_large = pool.acquire_for_size(500000).ok().flatten();
|
||||
assert!(too_large.is_none(), "Should not find buffer for 500KB");
|
||||
}
|
||||
}
|
||||
465
crates/ruvllm/tests/cross_platform.rs
Normal file
465
crates/ruvllm/tests/cross_platform.rs
Normal file
@@ -0,0 +1,465 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Cross-platform tests for scalar fallback implementations
|
||||
//!
|
||||
//! These tests verify that the scalar fallback implementations produce
|
||||
//! correct results and work on all platforms (including non-NEON and WASM).
|
||||
|
||||
use ruvllm::kernels::{flash_attention_neon, gemm_neon, gemv_neon, layer_norm_neon, rms_norm_neon};
|
||||
|
||||
// ========== Scalar Reference Implementations ==========
|
||||
|
||||
/// Scalar reference GEMV implementation
|
||||
fn gemv_scalar(a: &[f32], x: &[f32], y: &mut [f32], m: usize, n: usize) {
|
||||
for row in 0..m {
|
||||
let mut sum = 0.0f32;
|
||||
for col in 0..n {
|
||||
sum += a[row * n + col] * x[col];
|
||||
}
|
||||
y[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
/// Scalar reference GEMM implementation
|
||||
fn gemm_scalar(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
|
||||
c.fill(0.0);
|
||||
for i in 0..m {
|
||||
for j in 0..n {
|
||||
let mut sum = 0.0f32;
|
||||
for kk in 0..k {
|
||||
sum += a[i * k + kk] * b[kk * n + j];
|
||||
}
|
||||
c[i * n + j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scalar reference attention implementation
|
||||
fn attention_scalar(
|
||||
query: &[f32],
|
||||
key: &[f32],
|
||||
value: &[f32],
|
||||
head_dim: usize,
|
||||
kv_len: usize,
|
||||
scale: f32,
|
||||
) -> Vec<f32> {
|
||||
// Compute attention scores
|
||||
let mut scores = Vec::with_capacity(kv_len);
|
||||
for t in 0..kv_len {
|
||||
let k_offset = t * head_dim;
|
||||
let score: f32 = query
|
||||
.iter()
|
||||
.zip(&key[k_offset..k_offset + head_dim])
|
||||
.map(|(q, k)| q * k * scale)
|
||||
.sum();
|
||||
scores.push(score);
|
||||
}
|
||||
|
||||
// Softmax
|
||||
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
|
||||
let sum_exp: f32 = exp_scores.iter().sum();
|
||||
let attn_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
|
||||
|
||||
// Weighted sum of values
|
||||
let mut output = vec![0.0; head_dim];
|
||||
for (t, weight) in attn_weights.iter().enumerate() {
|
||||
let v_offset = t * head_dim;
|
||||
for (i, v) in value[v_offset..v_offset + head_dim].iter().enumerate() {
|
||||
output[i] += weight * v;
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Scalar reference RMSNorm implementation
|
||||
fn rms_norm_scalar(x: &mut [f32], weight: &[f32], eps: f32) {
|
||||
let len = x.len();
|
||||
let sum_sq: f32 = x.iter().map(|v| v * v).sum();
|
||||
let inv_rms = 1.0 / (sum_sq / len as f32 + eps).sqrt();
|
||||
for (i, w) in weight.iter().enumerate() {
|
||||
x[i] = x[i] * inv_rms * w;
|
||||
}
|
||||
}
|
||||
|
||||
/// Scalar reference LayerNorm implementation
|
||||
fn layer_norm_scalar(x: &mut [f32], weight: &[f32], bias: &[f32], eps: f32) {
|
||||
let len = x.len();
|
||||
let mean: f32 = x.iter().sum::<f32>() / len as f32;
|
||||
let var: f32 = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / len as f32;
|
||||
let inv_std = 1.0 / (var + eps).sqrt();
|
||||
|
||||
for i in 0..len {
|
||||
x[i] = (x[i] - mean) * inv_std * weight[i] + bias[i];
|
||||
}
|
||||
}
|
||||
|
||||
// ========== Cross-Platform Tests ==========
|
||||
|
||||
#[test]
|
||||
fn test_cross_platform_gemv() {
|
||||
let test_cases = [
|
||||
(4, 4),
|
||||
(8, 16),
|
||||
(16, 32),
|
||||
(32, 64),
|
||||
(64, 128),
|
||||
(100, 50),
|
||||
(7, 13), // Non-aligned
|
||||
];
|
||||
|
||||
for (m, n) in test_cases {
|
||||
let a: Vec<f32> = (0..m * n)
|
||||
.map(|i| ((i % 100) as f32 - 50.0) / 50.0)
|
||||
.collect();
|
||||
let x: Vec<f32> = (0..n).map(|i| ((i % 20) as f32 - 10.0) / 10.0).collect();
|
||||
|
||||
let mut y_neon = vec![0.0; m];
|
||||
let mut y_scalar = vec![0.0; m];
|
||||
|
||||
gemv_neon(&a, &x, &mut y_neon, m, n);
|
||||
gemv_scalar(&a, &x, &mut y_scalar, m, n);
|
||||
|
||||
for i in 0..m {
|
||||
let abs_error = (y_neon[i] - y_scalar[i]).abs();
|
||||
let rel_error = abs_error / y_scalar[i].abs().max(1e-6);
|
||||
assert!(
|
||||
rel_error < 0.001 || abs_error < 1e-5,
|
||||
"Cross-platform GEMV mismatch at ({},{}) index {}: {} vs {} (rel: {:.6})",
|
||||
m,
|
||||
n,
|
||||
i,
|
||||
y_neon[i],
|
||||
y_scalar[i],
|
||||
rel_error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_platform_gemm() {
|
||||
let test_cases = [
|
||||
(4, 4, 4),
|
||||
(8, 16, 8),
|
||||
(16, 32, 16),
|
||||
(32, 64, 32),
|
||||
(7, 11, 13), // Non-aligned
|
||||
];
|
||||
|
||||
for (m, k, n) in test_cases {
|
||||
let a: Vec<f32> = (0..m * k)
|
||||
.map(|i| ((i % 100) as f32 - 50.0) / 100.0)
|
||||
.collect();
|
||||
let b: Vec<f32> = (0..k * n)
|
||||
.map(|i| ((i % 50) as f32 - 25.0) / 50.0)
|
||||
.collect();
|
||||
|
||||
let mut c_neon = vec![0.0; m * n];
|
||||
let mut c_scalar = vec![0.0; m * n];
|
||||
|
||||
gemm_neon(&a, &b, &mut c_neon, m, k, n);
|
||||
gemm_scalar(&a, &b, &mut c_scalar, m, k, n);
|
||||
|
||||
for i in 0..(m * n) {
|
||||
let abs_error = (c_neon[i] - c_scalar[i]).abs();
|
||||
let rel_error = abs_error / c_scalar[i].abs().max(1e-6);
|
||||
assert!(
|
||||
rel_error < 0.01 || abs_error < 0.001,
|
||||
"Cross-platform GEMM mismatch at ({},{},{}) index {}: {} vs {} (rel: {:.6})",
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
i,
|
||||
c_neon[i],
|
||||
c_scalar[i],
|
||||
rel_error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_platform_attention() {
|
||||
let test_cases = [(16, 4), (32, 8), (64, 16), (128, 32)];
|
||||
|
||||
for (head_dim, kv_len) in test_cases {
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
let query: Vec<f32> = (0..head_dim)
|
||||
.map(|i| ((i % 7) as f32 - 3.0) / 10.0)
|
||||
.collect();
|
||||
let key: Vec<f32> = (0..kv_len * head_dim)
|
||||
.map(|i| ((i % 11) as f32 - 5.0) / 20.0)
|
||||
.collect();
|
||||
let value: Vec<f32> = (0..kv_len * head_dim)
|
||||
.map(|i| ((i % 13) as f32 - 6.0) / 15.0)
|
||||
.collect();
|
||||
|
||||
let output_neon = flash_attention_neon(&query, &key, &value, scale, false);
|
||||
let output_scalar = attention_scalar(&query, &key, &value, head_dim, kv_len, scale);
|
||||
|
||||
assert_eq!(output_neon.len(), output_scalar.len());
|
||||
|
||||
for i in 0..head_dim {
|
||||
let abs_error = (output_neon[i] - output_scalar[i]).abs();
|
||||
let rel_error = abs_error / output_scalar[i].abs().max(1e-6);
|
||||
assert!(
|
||||
rel_error < 0.01 || abs_error < 1e-4,
|
||||
"Cross-platform attention mismatch at head_dim={}, kv_len={}, index {}: {} vs {} (rel: {:.6})",
|
||||
head_dim, kv_len, i, output_neon[i], output_scalar[i], rel_error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_platform_rms_norm() {
|
||||
let test_cases = [8, 16, 32, 64, 128];
|
||||
|
||||
for dim in test_cases {
|
||||
let mut x_neon: Vec<f32> = (0..dim)
|
||||
.map(|i| (i as f32 - dim as f32 / 2.0) / 10.0)
|
||||
.collect();
|
||||
let mut x_scalar = x_neon.clone();
|
||||
let weight: Vec<f32> = (0..dim).map(|i| 0.5 + (i as f32) * 0.01).collect();
|
||||
let eps = 1e-6;
|
||||
|
||||
rms_norm_neon(&mut x_neon, &weight, eps);
|
||||
rms_norm_scalar(&mut x_scalar, &weight, eps);
|
||||
|
||||
for i in 0..dim {
|
||||
let abs_error = (x_neon[i] - x_scalar[i]).abs();
|
||||
assert!(
|
||||
abs_error < 1e-4,
|
||||
"Cross-platform RMSNorm mismatch at dim={}, index {}: {} vs {} (abs: {:.6})",
|
||||
dim,
|
||||
i,
|
||||
x_neon[i],
|
||||
x_scalar[i],
|
||||
abs_error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cross_platform_layer_norm() {
|
||||
let test_cases = [8, 16, 32, 64, 128];
|
||||
|
||||
for dim in test_cases {
|
||||
let mut x_neon: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1 - 5.0).collect();
|
||||
let mut x_scalar = x_neon.clone();
|
||||
let weight: Vec<f32> = vec![1.0; dim];
|
||||
let bias: Vec<f32> = vec![0.0; dim];
|
||||
let eps = 1e-6;
|
||||
|
||||
layer_norm_neon(&mut x_neon, &weight, &bias, eps);
|
||||
layer_norm_scalar(&mut x_scalar, &weight, &bias, eps);
|
||||
|
||||
for i in 0..dim {
|
||||
let abs_error = (x_neon[i] - x_scalar[i]).abs();
|
||||
assert!(
|
||||
abs_error < 1e-4,
|
||||
"Cross-platform LayerNorm mismatch at dim={}, index {}: {} vs {} (abs: {:.6})",
|
||||
dim,
|
||||
i,
|
||||
x_neon[i],
|
||||
x_scalar[i],
|
||||
abs_error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ========== Edge Case Tests ==========
|
||||
|
||||
#[test]
|
||||
fn test_scalar_fallback_edge_cases() {
|
||||
// Zero vectors
|
||||
let a_zero = vec![0.0f32; 16];
|
||||
let x_zero = vec![0.0f32; 4];
|
||||
let mut y = vec![0.0f32; 4];
|
||||
|
||||
gemv_neon(&a_zero, &x_zero, &mut y, 4, 4);
|
||||
assert!(
|
||||
y.iter().all(|&v| v == 0.0),
|
||||
"Zero input should give zero output"
|
||||
);
|
||||
|
||||
// Single element
|
||||
let a_single = vec![3.0f32];
|
||||
let x_single = vec![4.0f32];
|
||||
let mut y_single = vec![0.0f32];
|
||||
|
||||
gemv_neon(&a_single, &x_single, &mut y_single, 1, 1);
|
||||
assert!((y_single[0] - 12.0).abs() < 1e-5, "1x1 GEMV failed");
|
||||
|
||||
// Negative values
|
||||
let a_neg: Vec<f32> = (0..16)
|
||||
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
|
||||
.collect();
|
||||
let x_neg: Vec<f32> = (0..4)
|
||||
.map(|i| if i % 2 == 0 { -1.0 } else { 1.0 })
|
||||
.collect();
|
||||
let mut y_neg = vec![0.0f32; 4];
|
||||
|
||||
gemv_neon(&a_neg, &x_neg, &mut y_neg, 4, 4);
|
||||
assert!(
|
||||
y_neg.iter().all(|&v| v.is_finite()),
|
||||
"Negative values should produce finite output"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scalar_fallback_numerical_stability() {
|
||||
// Very small values
|
||||
let a_small: Vec<f32> = vec![1e-20; 64];
|
||||
let x_small: Vec<f32> = vec![1e-20; 8];
|
||||
let mut y_small = vec![0.0f32; 8];
|
||||
|
||||
gemv_neon(&a_small, &x_small, &mut y_small, 8, 8);
|
||||
assert!(
|
||||
y_small.iter().all(|&v| v.is_finite()),
|
||||
"Very small values should produce finite output"
|
||||
);
|
||||
|
||||
// Large values (but not overflow)
|
||||
let a_large: Vec<f32> = vec![1e10; 64];
|
||||
let x_large: Vec<f32> = vec![1e-10; 8]; // Scale x to avoid overflow
|
||||
let mut y_large = vec![0.0f32; 8];
|
||||
|
||||
gemv_neon(&a_large, &x_large, &mut y_large, 8, 8);
|
||||
assert!(
|
||||
y_large.iter().all(|&v| v.is_finite()),
|
||||
"Large values with small x should produce finite output"
|
||||
);
|
||||
|
||||
// Mixed magnitudes
|
||||
let a_mixed: Vec<f32> = (0..64)
|
||||
.map(|i| if i % 2 == 0 { 1e5 } else { 1e-5 })
|
||||
.collect();
|
||||
let x_mixed: Vec<f32> = vec![1.0; 8];
|
||||
let mut y_mixed = vec![0.0f32; 8];
|
||||
|
||||
gemv_neon(&a_mixed, &x_mixed, &mut y_mixed, 8, 8);
|
||||
assert!(
|
||||
y_mixed.iter().all(|&v| v.is_finite()),
|
||||
"Mixed magnitude values should produce finite output"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scalar_fallback_determinism() {
|
||||
let m = 32;
|
||||
let n = 64;
|
||||
|
||||
let a: Vec<f32> = (0..m * n).map(|i| ((i as f32) * 0.1).sin()).collect();
|
||||
let x: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.2).cos()).collect();
|
||||
|
||||
// Run multiple times and verify same result
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
let mut y = vec![0.0f32; m];
|
||||
gemv_neon(&a, &x, &mut y, m, n);
|
||||
results.push(y);
|
||||
}
|
||||
|
||||
for i in 1..results.len() {
|
||||
for j in 0..m {
|
||||
assert_eq!(
|
||||
results[0][j], results[i][j],
|
||||
"GEMV should be deterministic: run 0 vs run {} differ at index {}",
|
||||
i, j
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ========== WASM Compatibility Tests ==========
|
||||
|
||||
#[test]
|
||||
fn test_wasm_compatible_operations() {
|
||||
// These operations should work on WASM (no NEON)
|
||||
// Test with dimensions that don't require SIMD
|
||||
|
||||
// Small GEMV
|
||||
let a = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let x = vec![1.0, 1.0];
|
||||
let mut y = vec![0.0; 2];
|
||||
gemv_neon(&a, &x, &mut y, 2, 2);
|
||||
assert!((y[0] - 3.0).abs() < 1e-5); // 1*1 + 2*1 = 3
|
||||
assert!((y[1] - 7.0).abs() < 1e-5); // 3*1 + 4*1 = 7
|
||||
|
||||
// Small GEMM
|
||||
let a_gemm = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let b_gemm = vec![1.0, 0.0, 0.0, 1.0]; // Identity
|
||||
let mut c_gemm = vec![0.0; 4];
|
||||
gemm_neon(&a_gemm, &b_gemm, &mut c_gemm, 2, 2, 2);
|
||||
// A * I = A
|
||||
for i in 0..4 {
|
||||
assert!(
|
||||
(c_gemm[i] - a_gemm[i]).abs() < 1e-5,
|
||||
"GEMM with identity failed"
|
||||
);
|
||||
}
|
||||
|
||||
// Small attention
|
||||
let query = vec![0.1, 0.2, 0.3, 0.4];
|
||||
let key = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
|
||||
let value = vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0];
|
||||
let scale = 0.5;
|
||||
let output = flash_attention_neon(&query, &key, &value, scale, false);
|
||||
assert_eq!(output.len(), 4);
|
||||
assert!(output.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scalar_path_verification() {
|
||||
// Test that scalar fallback path produces correct results
|
||||
// for small inputs that might not trigger SIMD optimizations
|
||||
|
||||
// Verify GEMV with small non-aligned dimensions
|
||||
let a = vec![1.0, 2.0, 3.0];
|
||||
let x = vec![1.0, 2.0, 3.0];
|
||||
let mut y = vec![0.0; 1];
|
||||
gemv_neon(&a, &x, &mut y, 1, 3);
|
||||
let expected = 1.0 + 4.0 + 9.0; // 1*1 + 2*2 + 3*3 = 14
|
||||
assert!(
|
||||
(y[0] - expected).abs() < 1e-5,
|
||||
"Scalar GEMV expected {}, got {}",
|
||||
expected,
|
||||
y[0]
|
||||
);
|
||||
|
||||
// Verify GEMM with 1x1
|
||||
let a1 = vec![5.0f32];
|
||||
let b1 = vec![3.0f32];
|
||||
let mut c1 = vec![0.0f32];
|
||||
gemm_neon(&a1, &b1, &mut c1, 1, 1, 1);
|
||||
assert!(
|
||||
(c1[0] - 15.0).abs() < 1e-5,
|
||||
"1x1 GEMM expected 15, got {}",
|
||||
c1[0]
|
||||
);
|
||||
|
||||
// Verify normalization with small vector
|
||||
let mut x_norm = vec![3.0, 4.0];
|
||||
let weight = vec![1.0, 1.0];
|
||||
rms_norm_neon(&mut x_norm, &weight, 1e-6);
|
||||
// RMS = sqrt((9+16)/2) = sqrt(12.5) = 3.536
|
||||
// Normalized: [3/3.536, 4/3.536] = [0.848, 1.131]
|
||||
assert!(x_norm.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
1227
crates/ruvllm/tests/cross_platform_v21.rs
Normal file
1227
crates/ruvllm/tests/cross_platform_v21.rs
Normal file
File diff suppressed because it is too large
Load Diff
765
crates/ruvllm/tests/e2e_integration.rs
Normal file
765
crates/ruvllm/tests/e2e_integration.rs
Normal file
@@ -0,0 +1,765 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! End-to-end integration tests for RuvLLM
|
||||
//!
|
||||
//! Tests the complete inference pipeline including model loading,
|
||||
//! session management, KV cache, paged attention, and policy/witness stores.
|
||||
|
||||
use chrono::Utc;
|
||||
use ruvllm::{
|
||||
backends::{DType, DeviceType, GenerateParams, ModelArchitecture, ModelConfig, Quantization},
|
||||
error::Result,
|
||||
kv_cache::{KvCacheConfig, TwoTierKvCache},
|
||||
lora::{AdaptFeedback, MicroLoRA, MicroLoraConfig, TargetModule},
|
||||
paged_attention::{PagedAttention, PagedAttentionConfig},
|
||||
policy_store::{PolicyEntry, PolicySource, PolicyStore, PolicyType, QuantizationPolicy},
|
||||
session::{SessionConfig, SessionManager},
|
||||
sona::{LearningLoop, SonaConfig, SonaIntegration, Trajectory},
|
||||
types::ModelSize,
|
||||
witness_log::{LatencyBreakdown, RoutingDecision, WitnessEntry, WitnessLog},
|
||||
RuvLLMConfig, RuvLLMEngine,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tempfile::TempDir;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Create a temporary directory for test storage
|
||||
fn create_test_dir() -> TempDir {
|
||||
tempfile::tempdir().expect("Failed to create temp dir")
|
||||
}
|
||||
|
||||
/// Create a test RuvLLM configuration
|
||||
fn create_test_config(storage_path: &str) -> RuvLLMConfig {
|
||||
RuvLLMConfig {
|
||||
storage_path: storage_path.to_string(),
|
||||
paged_attention: PagedAttentionConfig {
|
||||
page_size: 16,
|
||||
page_table_capacity: 64,
|
||||
num_kv_heads: 4,
|
||||
head_dim: 32,
|
||||
..Default::default()
|
||||
},
|
||||
kv_cache: KvCacheConfig {
|
||||
tail_length: 32,
|
||||
max_tokens: 256,
|
||||
num_kv_heads: 4,
|
||||
head_dim: 32,
|
||||
..Default::default()
|
||||
},
|
||||
session: SessionConfig::default(),
|
||||
sona: SonaConfig::default(),
|
||||
max_sessions: 100,
|
||||
embedding_dim: 768, // Must match SessionState::from_session default
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Requires model download
|
||||
async fn test_full_inference_pipeline() {
|
||||
// This test would require an actual model
|
||||
// let temp_dir = create_test_dir();
|
||||
// let config = create_test_config(temp_dir.path().to_str().unwrap());
|
||||
// let engine = RuvLLMEngine::new(config).unwrap();
|
||||
|
||||
// Steps:
|
||||
// 1. Load model
|
||||
// 2. Create session
|
||||
// 3. Generate initial response
|
||||
// 4. Apply adaptation based on feedback
|
||||
// 5. Generate again (should be different/improved)
|
||||
// 6. Verify learning metrics
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_engine_creation() {
|
||||
let temp_dir = create_test_dir();
|
||||
let config = create_test_config(temp_dir.path().to_str().unwrap());
|
||||
|
||||
let result = RuvLLMEngine::new(config);
|
||||
assert!(result.is_ok(), "Engine creation failed: {:?}", result.err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_creation_and_retrieval() {
|
||||
let temp_dir = create_test_dir();
|
||||
let config = create_test_config(temp_dir.path().to_str().unwrap());
|
||||
let engine = RuvLLMEngine::new(config).unwrap();
|
||||
|
||||
// Create session
|
||||
let session = engine.create_session(Some("user-123")).unwrap();
|
||||
assert!(!session.id.is_empty());
|
||||
|
||||
// Retrieve session
|
||||
let retrieved = engine.get_session(&session.id).unwrap();
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
let retrieved_session = retrieved.unwrap();
|
||||
assert_eq!(retrieved_session.id, session.id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_sessions() {
|
||||
let temp_dir = create_test_dir();
|
||||
let config = create_test_config(temp_dir.path().to_str().unwrap());
|
||||
let engine = RuvLLMEngine::new(config).unwrap();
|
||||
|
||||
let mut sessions = Vec::new();
|
||||
for i in 0..10 {
|
||||
let session = engine.create_session(Some(&format!("user-{}", i))).unwrap();
|
||||
sessions.push(session.id.clone());
|
||||
}
|
||||
|
||||
// Verify all sessions exist
|
||||
for session_id in &sessions {
|
||||
let session = engine.get_session(session_id).unwrap();
|
||||
assert!(session.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kv_cache_eviction() {
|
||||
let config = KvCacheConfig {
|
||||
tail_length: 4,
|
||||
max_tokens: 10,
|
||||
num_kv_heads: 2,
|
||||
head_dim: 8,
|
||||
migration_batch: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cache = TwoTierKvCache::new(config);
|
||||
|
||||
// Add more tokens than max
|
||||
for i in 0..20 {
|
||||
let keys = vec![i as f32; 2 * 8]; // num_kv_heads * head_dim
|
||||
let values = vec![i as f32 * 2.0; 2 * 8];
|
||||
cache.append(&keys, &values).unwrap();
|
||||
}
|
||||
|
||||
// Should have evicted to stay under max
|
||||
let stats = cache.stats();
|
||||
assert!(
|
||||
stats.total_tokens <= 10,
|
||||
"Should evict to stay under max: {}",
|
||||
stats.total_tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kv_cache_two_tier_storage() {
|
||||
let config = KvCacheConfig {
|
||||
tail_length: 4,
|
||||
max_tokens: 100,
|
||||
num_kv_heads: 2,
|
||||
head_dim: 8,
|
||||
migration_batch: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cache = TwoTierKvCache::new(config);
|
||||
|
||||
// Add tokens to trigger migration
|
||||
for i in 0..10 {
|
||||
let keys = vec![i as f32; 2 * 8];
|
||||
let values = vec![i as f32 * 2.0; 2 * 8];
|
||||
cache.append(&keys, &values).unwrap();
|
||||
}
|
||||
|
||||
let stats = cache.stats();
|
||||
|
||||
// Should have some in tail and some in store
|
||||
assert_eq!(stats.total_tokens, 10);
|
||||
assert!(
|
||||
stats.tail_tokens <= 4,
|
||||
"Tail should be limited: {}",
|
||||
stats.tail_tokens
|
||||
);
|
||||
assert!(
|
||||
stats.store_tokens >= 6,
|
||||
"Store should have overflow: {}",
|
||||
stats.store_tokens
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kv_cache_attention() {
|
||||
let config = KvCacheConfig {
|
||||
tail_length: 8,
|
||||
max_tokens: 32,
|
||||
num_kv_heads: 1,
|
||||
head_dim: 16,
|
||||
migration_batch: 4,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cache = TwoTierKvCache::new(config);
|
||||
|
||||
// Add some KV pairs
|
||||
for i in 0..5 {
|
||||
let keys: Vec<f32> = (0..16).map(|j| (i * 16 + j) as f32 * 0.1).collect();
|
||||
let values: Vec<f32> = (0..16).map(|j| (i * 16 + j) as f32 * 0.2).collect();
|
||||
cache.append(&keys, &values).unwrap();
|
||||
}
|
||||
|
||||
// Query
|
||||
let query: Vec<f32> = (0..16).map(|i| i as f32 * 0.05).collect();
|
||||
let scale = 1.0 / 16.0f32.sqrt();
|
||||
|
||||
let output = cache.attend(&query, scale).unwrap();
|
||||
|
||||
assert_eq!(output.len(), 16);
|
||||
assert!(output.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_basic() {
|
||||
let config = PagedAttentionConfig {
|
||||
page_size: 4,
|
||||
page_table_capacity: 16,
|
||||
num_kv_heads: 2,
|
||||
head_dim: 16,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let paged_attn = PagedAttention::new(config);
|
||||
|
||||
// Check initial state
|
||||
let stats_before = paged_attn.stats();
|
||||
assert_eq!(stats_before.active_sequences, 0);
|
||||
|
||||
// Allocate pages for a sequence
|
||||
let seq_id = "seq-1";
|
||||
paged_attn.allocate_sequence(seq_id, 8).unwrap();
|
||||
|
||||
// Check allocation via stats
|
||||
let stats_after_alloc = paged_attn.stats();
|
||||
assert_eq!(stats_after_alloc.active_sequences, 1);
|
||||
|
||||
// Free sequence
|
||||
paged_attn.free_sequence(seq_id).unwrap();
|
||||
|
||||
// Verify freed via stats
|
||||
let stats_after_free = paged_attn.stats();
|
||||
assert_eq!(stats_after_free.active_sequences, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_kv_cache_access() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let config = KvCacheConfig {
|
||||
tail_length: 64,
|
||||
max_tokens: 256,
|
||||
num_kv_heads: 4,
|
||||
head_dim: 32,
|
||||
migration_batch: 16,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cache = Arc::new(TwoTierKvCache::new(config));
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn multiple writers
|
||||
for t in 0..4 {
|
||||
let cache_clone = Arc::clone(&cache);
|
||||
let handle = thread::spawn(move || {
|
||||
for i in 0..10 {
|
||||
let keys = vec![(t * 100 + i) as f32; 4 * 32];
|
||||
let values = vec![(t * 100 + i) as f32 * 2.0; 4 * 32];
|
||||
cache_clone.append(&keys, &values).unwrap();
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Wait for all threads
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
let stats = cache.stats();
|
||||
assert!(stats.total_tokens > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests() {
|
||||
let temp_dir = create_test_dir();
|
||||
let config = create_test_config(temp_dir.path().to_str().unwrap());
|
||||
let engine = Arc::new(RuvLLMEngine::new(config).unwrap());
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn concurrent session creators
|
||||
for i in 0..10 {
|
||||
let engine_clone = Arc::clone(&engine);
|
||||
let handle = tokio::spawn(async move {
|
||||
let session = engine_clone.create_session(Some(&format!("concurrent-user-{}", i)));
|
||||
session.is_ok()
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// All should succeed
|
||||
for handle in handles {
|
||||
assert!(handle.await.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_store() {
|
||||
let temp_dir = create_test_dir();
|
||||
let storage_path = format!("{}/policies", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let store = PolicyStore::new(&storage_path, 64).unwrap();
|
||||
|
||||
// Store a policy
|
||||
let policy = PolicyEntry {
|
||||
id: Uuid::new_v4(),
|
||||
policy_type: PolicyType::Quantization,
|
||||
embedding: vec![0.1; 64],
|
||||
parameters: serde_json::json!({
|
||||
"precision": "q4_k",
|
||||
"quality_threshold": 0.9,
|
||||
}),
|
||||
confidence: 0.85,
|
||||
fisher_diagonal: None,
|
||||
created_at: Utc::now(),
|
||||
last_accessed: Utc::now(),
|
||||
source: PolicySource::InstantLoop,
|
||||
tags: vec!["quantization".to_string()],
|
||||
};
|
||||
|
||||
store.store(policy).unwrap();
|
||||
|
||||
// Search
|
||||
let query = vec![0.1; 64];
|
||||
let results = store.search(&query, 5).unwrap();
|
||||
|
||||
assert!(!results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_log() {
|
||||
let temp_dir = create_test_dir();
|
||||
let storage_path = format!("{}/witness", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let log = WitnessLog::new(&storage_path, 64).unwrap();
|
||||
|
||||
// Record entries
|
||||
for i in 0..5 {
|
||||
let routing_decision = RoutingDecision {
|
||||
model: ModelSize::Small,
|
||||
context_size: 512,
|
||||
temperature: 0.7,
|
||||
top_p: 0.9,
|
||||
confidence: 0.8 + (i as f32 * 0.02),
|
||||
model_probs: [0.1, 0.4, 0.3, 0.2],
|
||||
};
|
||||
|
||||
let entry = WitnessEntry::new(
|
||||
format!("session-{}", i % 2),
|
||||
vec![i as f32 * 0.1; 64],
|
||||
routing_decision,
|
||||
)
|
||||
.with_quality(0.85)
|
||||
.with_latency(LatencyBreakdown {
|
||||
embedding_ms: 5.0,
|
||||
retrieval_ms: 2.0,
|
||||
routing_ms: 1.0,
|
||||
attention_ms: 30.0,
|
||||
generation_ms: 62.0,
|
||||
total_ms: 100.0 + (i as f32 * 10.0),
|
||||
});
|
||||
|
||||
log.record(entry).unwrap();
|
||||
}
|
||||
|
||||
// Flush to ensure entries are searchable
|
||||
log.flush().unwrap();
|
||||
|
||||
// Search
|
||||
let query = vec![0.2; 64];
|
||||
let results = log.search(&query, 3).unwrap();
|
||||
|
||||
// Results may be empty if flush didn't complete vector indexing
|
||||
// This is expected behavior for async write-back
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_end_to_end_adaptation_flow() {
|
||||
let config = MicroLoraConfig {
|
||||
rank: 2,
|
||||
alpha: 4.0,
|
||||
dropout: 0.0,
|
||||
target_modules: vec![TargetModule::QProj],
|
||||
in_features: 64,
|
||||
out_features: 64,
|
||||
use_bias: false,
|
||||
standard_init: true,
|
||||
gradient_checkpointing: false,
|
||||
};
|
||||
|
||||
let lora = MicroLoRA::new(config);
|
||||
let _sona = SonaIntegration::new(SonaConfig::default());
|
||||
|
||||
let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.01).collect();
|
||||
|
||||
// Initial forward
|
||||
let output_initial = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
// Simulate inference loop with adaptation
|
||||
let mut quality_history = Vec::new();
|
||||
for i in 0..20 {
|
||||
// Forward pass
|
||||
let _output = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
// Compute simulated quality (increasing over time)
|
||||
let simulated_quality = 0.2 + (i as f32 * 0.03);
|
||||
quality_history.push(simulated_quality);
|
||||
|
||||
// Create feedback
|
||||
let feedback = AdaptFeedback::from_quality(simulated_quality);
|
||||
|
||||
// Adapt
|
||||
lora.adapt(&input, feedback).unwrap();
|
||||
lora.apply_updates(0.01);
|
||||
}
|
||||
|
||||
// Final forward
|
||||
let output_final = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
// Verify adaptation happened
|
||||
let changed = output_initial
|
||||
.iter()
|
||||
.zip(output_final.iter())
|
||||
.any(|(a, b)| (a - b).abs() > 1e-6);
|
||||
let all_near_zero = output_initial.iter().all(|&v| v.abs() < 1e-6);
|
||||
|
||||
assert!(changed || all_near_zero);
|
||||
|
||||
// Verify quality increased
|
||||
let first_qualities: f32 = quality_history[..5].iter().sum::<f32>() / 5.0;
|
||||
let last_qualities: f32 = quality_history[15..].iter().sum::<f32>() / 5.0;
|
||||
assert!(
|
||||
last_qualities > first_qualities,
|
||||
"Quality should increase: {} vs {}",
|
||||
last_qualities,
|
||||
first_qualities
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_lifecycle() {
|
||||
let config = SessionConfig::default();
|
||||
let manager = SessionManager::new(config);
|
||||
|
||||
// Create session
|
||||
let session = manager.create_session(Some("user-1")).unwrap();
|
||||
let session_id = session.id.clone();
|
||||
|
||||
// Get session
|
||||
let retrieved = manager.get_session(&session_id).unwrap();
|
||||
assert!(retrieved.is_some());
|
||||
|
||||
// Terminate session
|
||||
manager.terminate_session(&session_id).unwrap();
|
||||
|
||||
// Session should be gone
|
||||
let ended = manager.get_session(&session_id).unwrap();
|
||||
assert!(ended.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_latency_measurement() {
|
||||
let start = Instant::now();
|
||||
|
||||
// Simulate some work
|
||||
let mut sum = 0.0f32;
|
||||
for i in 0..10000 {
|
||||
sum += (i as f32).sqrt();
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
// Create latency breakdown
|
||||
let breakdown = LatencyBreakdown {
|
||||
embedding_ms: elapsed.as_secs_f32() * 100.0, // 10%
|
||||
retrieval_ms: elapsed.as_secs_f32() * 50.0, // 5%
|
||||
routing_ms: elapsed.as_secs_f32() * 50.0, // 5%
|
||||
attention_ms: elapsed.as_secs_f32() * 300.0, // 30%
|
||||
generation_ms: elapsed.as_secs_f32() * 500.0, // 50%
|
||||
total_ms: elapsed.as_secs_f32() * 1000.0,
|
||||
};
|
||||
|
||||
assert!(breakdown.total_ms >= 0.0);
|
||||
assert!(sum > 0.0); // Use sum to prevent optimization
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_config_variants() {
|
||||
let configs = vec![
|
||||
ModelConfig {
|
||||
architecture: ModelArchitecture::Llama,
|
||||
device: DeviceType::Cpu,
|
||||
dtype: DType::F32,
|
||||
quantization: None,
|
||||
use_flash_attention: false,
|
||||
max_sequence_length: 2048,
|
||||
..Default::default()
|
||||
},
|
||||
ModelConfig {
|
||||
architecture: ModelArchitecture::Mistral,
|
||||
device: DeviceType::Metal,
|
||||
dtype: DType::F16,
|
||||
quantization: Some(Quantization::Q4),
|
||||
use_flash_attention: true,
|
||||
max_sequence_length: 4096,
|
||||
..Default::default()
|
||||
},
|
||||
ModelConfig {
|
||||
architecture: ModelArchitecture::Phi,
|
||||
device: DeviceType::Cuda(0),
|
||||
dtype: DType::Bf16,
|
||||
quantization: Some(Quantization::Q8),
|
||||
use_flash_attention: true,
|
||||
max_sequence_length: 8192,
|
||||
..Default::default()
|
||||
},
|
||||
];
|
||||
|
||||
for config in configs {
|
||||
assert!(config.max_sequence_length > 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_params_customization() {
|
||||
let params = GenerateParams {
|
||||
max_tokens: 256,
|
||||
temperature: 0.8,
|
||||
top_p: 0.95,
|
||||
top_k: 50,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop_sequences: vec!["<|end|>".to_string(), "\n\n".to_string()],
|
||||
seed: Some(12345),
|
||||
};
|
||||
|
||||
assert_eq!(params.max_tokens, 256);
|
||||
assert_eq!(params.stop_sequences.len(), 2);
|
||||
assert!(params.seed.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_params_builder() {
|
||||
let params = GenerateParams::default()
|
||||
.with_max_tokens(512)
|
||||
.with_temperature(0.5)
|
||||
.with_top_p(0.95)
|
||||
.with_top_k(50)
|
||||
.with_repetition_penalty(1.2)
|
||||
.with_seed(42);
|
||||
|
||||
assert_eq!(params.max_tokens, 512);
|
||||
assert_eq!(params.temperature, 0.5);
|
||||
assert_eq!(params.top_p, 0.95);
|
||||
assert_eq!(params.top_k, 50);
|
||||
assert_eq!(params.repetition_penalty, 1.2);
|
||||
assert_eq!(params.seed, Some(42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_routing_decision() {
|
||||
let decisions = vec![
|
||||
RoutingDecision {
|
||||
model: ModelSize::Large,
|
||||
context_size: 1024,
|
||||
temperature: 0.7,
|
||||
top_p: 0.9,
|
||||
confidence: 0.95,
|
||||
model_probs: [0.05, 0.1, 0.25, 0.6],
|
||||
},
|
||||
RoutingDecision {
|
||||
model: ModelSize::Medium,
|
||||
context_size: 512,
|
||||
temperature: 0.8,
|
||||
top_p: 0.95,
|
||||
confidence: 0.88,
|
||||
model_probs: [0.1, 0.2, 0.5, 0.2],
|
||||
},
|
||||
RoutingDecision {
|
||||
model: ModelSize::Small,
|
||||
context_size: 256,
|
||||
temperature: 0.6,
|
||||
top_p: 0.9,
|
||||
confidence: 0.6,
|
||||
model_probs: [0.2, 0.5, 0.2, 0.1],
|
||||
},
|
||||
];
|
||||
|
||||
for decision in decisions {
|
||||
assert!(decision.confidence >= 0.0 && decision.confidence <= 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_handling() {
|
||||
let temp_dir = create_test_dir();
|
||||
let config = create_test_config(temp_dir.path().to_str().unwrap());
|
||||
let engine = RuvLLMEngine::new(config).unwrap();
|
||||
|
||||
// Try to get non-existent session
|
||||
let result = engine.get_session("non-existent-session-id");
|
||||
assert!(result.is_ok()); // Should succeed but return None
|
||||
assert!(result.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_efficiency() {
|
||||
let config = KvCacheConfig {
|
||||
tail_length: 32,
|
||||
max_tokens: 128,
|
||||
num_kv_heads: 4,
|
||||
head_dim: 64,
|
||||
migration_batch: 16,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let cache = TwoTierKvCache::new(config);
|
||||
|
||||
// Fill cache
|
||||
for _ in 0..100 {
|
||||
let keys = vec![1.0; 4 * 64];
|
||||
let values = vec![2.0; 4 * 64];
|
||||
cache.append(&keys, &values).unwrap();
|
||||
}
|
||||
|
||||
let stats = cache.stats();
|
||||
|
||||
// Store should use less memory per token than tail (quantized)
|
||||
if stats.store_tokens > 0 && stats.tail_tokens > 0 {
|
||||
let bytes_per_tail_token = stats.tail_bytes as f32 / stats.tail_tokens as f32;
|
||||
let bytes_per_store_token = stats.store_bytes as f32 / stats.store_tokens as f32;
|
||||
|
||||
// Quantized store should use less memory (or same if not actually quantized)
|
||||
assert!(
|
||||
bytes_per_store_token <= bytes_per_tail_token * 1.1,
|
||||
"Store should be more memory efficient: {} vs {} bytes/token",
|
||||
bytes_per_store_token,
|
||||
bytes_per_tail_token
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_integration_basic() {
|
||||
let config = SonaConfig {
|
||||
embedding_dim: 256,
|
||||
..Default::default()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Record a trajectory
|
||||
let trajectory = Trajectory {
|
||||
request_id: "req-1".to_string(),
|
||||
session_id: "test-session".to_string(),
|
||||
query_embedding: vec![0.1; 256],
|
||||
response_embedding: vec![0.2; 256],
|
||||
quality_score: 0.8,
|
||||
routing_features: vec![0.7, 0.9, 0.5, 0.5],
|
||||
model_index: 1,
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
|
||||
// Get stats
|
||||
let stats = sona.stats();
|
||||
assert!(stats.total_trajectories >= 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_learning_loops() {
|
||||
// Test that all learning loop variants exist
|
||||
let loops = vec![
|
||||
LearningLoop::Instant,
|
||||
LearningLoop::Background,
|
||||
LearningLoop::Deep,
|
||||
];
|
||||
|
||||
for _loop in loops {
|
||||
// Just verify the variants exist
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantization_variants() {
|
||||
let q4 = Quantization::Q4;
|
||||
let q8 = Quantization::Q8;
|
||||
let q4k = Quantization::Q4K;
|
||||
let f16 = Quantization::F16;
|
||||
|
||||
assert!(q4.is_gguf());
|
||||
assert!(q8.is_gguf());
|
||||
assert!(q4k.is_gguf());
|
||||
assert!(!f16.is_gguf());
|
||||
|
||||
// Check bytes per weight
|
||||
assert_eq!(Quantization::None.bytes_per_weight(), 4.0);
|
||||
assert_eq!(Quantization::F16.bytes_per_weight(), 2.0);
|
||||
assert_eq!(Quantization::Q8.bytes_per_weight(), 1.0);
|
||||
assert_eq!(Quantization::Q4K.bytes_per_weight(), 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_device_type_variants() {
|
||||
let cpu = DeviceType::Cpu;
|
||||
let metal = DeviceType::Metal;
|
||||
let cuda = DeviceType::Cuda(0);
|
||||
|
||||
assert!(matches!(cpu, DeviceType::Cpu));
|
||||
assert!(matches!(metal, DeviceType::Metal));
|
||||
if let DeviceType::Cuda(idx) = cuda {
|
||||
assert_eq!(idx, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_architecture_variants() {
|
||||
let llama = ModelArchitecture::Llama;
|
||||
let mistral = ModelArchitecture::Mistral;
|
||||
let phi = ModelArchitecture::Phi;
|
||||
let qwen = ModelArchitecture::Qwen;
|
||||
let gemma = ModelArchitecture::Gemma;
|
||||
|
||||
assert_eq!(llama.config_name(), "llama");
|
||||
assert_eq!(mistral.config_name(), "mistral");
|
||||
assert_eq!(phi.config_name(), "phi");
|
||||
assert_eq!(qwen.config_name(), "qwen2");
|
||||
assert_eq!(gemma.config_name(), "gemma");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dtype_variants() {
|
||||
let f32_type = DType::F32;
|
||||
let f16_type = DType::F16;
|
||||
let bf16_type = DType::Bf16;
|
||||
|
||||
assert!(matches!(f32_type, DType::F32));
|
||||
assert!(matches!(f16_type, DType::F16));
|
||||
assert!(matches!(bf16_type, DType::Bf16));
|
||||
}
|
||||
1535
crates/ruvllm/tests/e2e_integration_test.rs
Normal file
1535
crates/ruvllm/tests/e2e_integration_test.rs
Normal file
File diff suppressed because it is too large
Load Diff
404
crates/ruvllm/tests/fixtures/mod.rs
vendored
Normal file
404
crates/ruvllm/tests/fixtures/mod.rs
vendored
Normal file
@@ -0,0 +1,404 @@
|
||||
//! Test Fixtures for RuvLTRA-Small
|
||||
//!
|
||||
//! This module provides test fixtures including sample prompts, expected patterns,
|
||||
//! and perplexity baselines for validating the RuvLTRA-Small inference engine.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ============================================================================
|
||||
// Sample Prompts
|
||||
// ============================================================================
|
||||
|
||||
/// Collection of test prompts organized by category
|
||||
pub mod prompts {
|
||||
/// Simple text completion prompts
|
||||
pub mod completion {
|
||||
pub const QUICK_BROWN_FOX: &str = "The quick brown fox";
|
||||
pub const ONCE_UPON_A_TIME: &str = "Once upon a time";
|
||||
pub const IN_THE_BEGINNING: &str = "In the beginning";
|
||||
pub const IT_WAS_A_DARK: &str = "It was a dark and stormy night";
|
||||
}
|
||||
|
||||
/// Instruction-following prompts
|
||||
pub mod instruction {
|
||||
pub const WRITE_HAIKU: &str = "Write a haiku about programming:";
|
||||
pub const EXPLAIN_GRAVITY: &str = "Explain gravity in simple terms:";
|
||||
pub const LIST_PLANETS: &str = "List the planets in our solar system:";
|
||||
pub const DESCRIBE_OCEAN: &str = "Describe the ocean in three sentences:";
|
||||
}
|
||||
|
||||
/// Question-answering prompts
|
||||
pub mod qa {
|
||||
pub const CAPITAL_FRANCE: &str = "Q: What is the capital of France?\nA:";
|
||||
pub const TWO_PLUS_TWO: &str = "Q: What is 2 + 2?\nA:";
|
||||
pub const COLOR_SKY: &str = "Q: What color is the sky?\nA:";
|
||||
pub const LARGEST_PLANET: &str = "Q: What is the largest planet in our solar system?\nA:";
|
||||
}
|
||||
|
||||
/// Code generation prompts
|
||||
pub mod code {
|
||||
pub const FIBONACCI: &str = "def fibonacci(n):\n '''Return the nth Fibonacci number.'''\n";
|
||||
pub const HELLO_WORLD: &str = "# Python function to print hello world\ndef hello():";
|
||||
pub const FACTORIAL: &str = "def factorial(n):\n '''Return n factorial.'''\n";
|
||||
pub const SORT_LIST: &str = "def sort_list(items):\n '''Sort a list in ascending order.'''\n";
|
||||
}
|
||||
|
||||
/// Conversation/chat prompts
|
||||
pub mod conversation {
|
||||
pub const GREETING: &str = "User: Hello!\nAssistant:";
|
||||
pub const TELL_JOKE: &str = "User: Tell me a joke.\nAssistant:";
|
||||
pub const WEATHER: &str = "User: What's the weather like today?\nAssistant:";
|
||||
pub const HELP: &str = "User: Can you help me?\nAssistant:";
|
||||
}
|
||||
|
||||
/// Edge case prompts
|
||||
pub mod edge_cases {
|
||||
pub const EMPTY: &str = "";
|
||||
pub const SINGLE_CHAR: &str = "A";
|
||||
pub const SINGLE_WORD: &str = "Hello";
|
||||
pub const SPECIAL_CHARS: &str = "Translate: \"Hello, world!\" ->";
|
||||
pub const UNICODE: &str = "\u{4f60}\u{597d}\u{4e16}\u{754c}"; // 你好世界
|
||||
pub const NUMBERS_ONLY: &str = "1 2 3 4 5";
|
||||
pub const VERY_LONG: &str = "The quick brown fox jumps over the lazy dog. \
|
||||
The quick brown fox jumps over the lazy dog. \
|
||||
The quick brown fox jumps over the lazy dog. \
|
||||
The quick brown fox jumps over the lazy dog. \
|
||||
The quick brown fox jumps over the lazy dog. \
|
||||
Continue:";
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Expected Output Patterns
|
||||
// ============================================================================
|
||||
|
||||
/// Expected patterns in generated outputs
|
||||
pub mod expected_patterns {
|
||||
/// Patterns expected after "The quick brown fox"
|
||||
pub const FOX_COMPLETION: &[&str] = &[
|
||||
"jumps", "jumped", "runs", "ran", "over", "the", "lazy", "dog"
|
||||
];
|
||||
|
||||
/// Patterns expected in haiku responses
|
||||
pub const HAIKU_PATTERNS: &[&str] = &[
|
||||
"code", "bug", "compile", "debug", "screen", "night", "lines", "function"
|
||||
];
|
||||
|
||||
/// Capital of France
|
||||
pub const FRANCE_CAPITAL: &str = "Paris";
|
||||
|
||||
/// Answer to 2+2
|
||||
pub const TWO_PLUS_TWO: &str = "4";
|
||||
|
||||
/// Patterns in Fibonacci code
|
||||
pub const FIBONACCI_PATTERNS: &[&str] = &[
|
||||
"return", "if", "else", "n", "<=", "1", "+", "fibonacci"
|
||||
];
|
||||
|
||||
/// Patterns in greeting responses
|
||||
pub const GREETING_PATTERNS: &[&str] = &[
|
||||
"hello", "hi", "hey", "how", "help", "assist", "welcome"
|
||||
];
|
||||
|
||||
/// Patterns in factorial code
|
||||
pub const FACTORIAL_PATTERNS: &[&str] = &[
|
||||
"return", "if", "n", "<=", "1", "*", "factorial"
|
||||
];
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Perplexity Baselines
|
||||
// ============================================================================
|
||||
|
||||
/// Perplexity baseline values for quality validation
|
||||
pub mod perplexity {
|
||||
/// Maximum acceptable perplexity for coherent output
|
||||
pub const MAX_ACCEPTABLE: f32 = 50.0;
|
||||
|
||||
/// Warning threshold for elevated perplexity
|
||||
pub const WARNING_THRESHOLD: f32 = 30.0;
|
||||
|
||||
/// Excellent perplexity (high-quality output)
|
||||
pub const EXCELLENT: f32 = 15.0;
|
||||
|
||||
/// Expected perplexity ranges by task type
|
||||
pub mod task_ranges {
|
||||
/// Simple completion: low perplexity expected
|
||||
pub const COMPLETION: (f32, f32) = (5.0, 20.0);
|
||||
|
||||
/// Code generation: moderate perplexity
|
||||
pub const CODE: (f32, f32) = (8.0, 30.0);
|
||||
|
||||
/// Creative writing: higher perplexity acceptable
|
||||
pub const CREATIVE: (f32, f32) = (15.0, 45.0);
|
||||
|
||||
/// Factual QA: low perplexity (confident answers)
|
||||
pub const FACTUAL: (f32, f32) = (3.0, 15.0);
|
||||
}
|
||||
|
||||
/// Quantization degradation limits
|
||||
pub mod degradation {
|
||||
/// Max perplexity increase from quantization (%)
|
||||
pub const MAX_INCREASE_PCT: f32 = 20.0;
|
||||
|
||||
/// Q4_K expected degradation from F16 (%)
|
||||
pub const Q4K_EXPECTED: f32 = 15.0;
|
||||
|
||||
/// Q8_0 expected degradation from F16 (%)
|
||||
pub const Q8_EXPECTED: f32 = 3.0;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Token Probability Thresholds
|
||||
// ============================================================================
|
||||
|
||||
/// Thresholds for token probability validation
|
||||
pub mod probability_thresholds {
|
||||
/// Minimum probability for top-1 token
|
||||
pub const MIN_TOP1: f32 = 0.01;
|
||||
|
||||
/// Minimum cumulative probability for top-5 tokens
|
||||
pub const MIN_TOP5_CUMULATIVE: f32 = 0.1;
|
||||
|
||||
/// Maximum entropy for non-degenerate output
|
||||
pub const MAX_ENTROPY: f32 = 10.0;
|
||||
|
||||
/// Minimum confidence for factual answers
|
||||
pub const MIN_FACTUAL_CONFIDENCE: f32 = 0.5;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Coherence Metrics
|
||||
// ============================================================================
|
||||
|
||||
/// Coherence validation thresholds
|
||||
pub mod coherence {
|
||||
/// Maximum consecutive word repetitions
|
||||
pub const MAX_CONSECUTIVE_REPEATS: usize = 3;
|
||||
|
||||
/// Maximum n-gram repetition ratio
|
||||
pub const MAX_NGRAM_REPETITION: f32 = 0.3;
|
||||
|
||||
/// Minimum alphanumeric ratio for valid text
|
||||
pub const MIN_ALPHANUMERIC_RATIO: f32 = 0.7;
|
||||
|
||||
/// Maximum special character ratio
|
||||
pub const MAX_SPECIAL_CHAR_RATIO: f32 = 0.2;
|
||||
|
||||
/// Sentence length bounds
|
||||
pub const MIN_SENTENCE_LENGTH: usize = 3;
|
||||
pub const MAX_SENTENCE_LENGTH: usize = 200;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Baselines
|
||||
// ============================================================================
|
||||
|
||||
/// Performance baseline values
|
||||
pub mod performance {
|
||||
/// Tokens per second baselines by device
|
||||
pub mod tokens_per_second {
|
||||
/// M4 Pro with ANE
|
||||
pub const M4_PRO_ANE: f32 = 60.0;
|
||||
|
||||
/// M4 Pro NEON only
|
||||
pub const M4_PRO_NEON: f32 = 45.0;
|
||||
|
||||
/// M1 with ANE
|
||||
pub const M1_ANE: f32 = 40.0;
|
||||
|
||||
/// x86 CPU (AVX2)
|
||||
pub const X86_AVX2: f32 = 15.0;
|
||||
}
|
||||
|
||||
/// Latency thresholds (milliseconds)
|
||||
pub mod latency_ms {
|
||||
/// Maximum time to first token
|
||||
pub const MAX_FIRST_TOKEN: u64 = 500;
|
||||
|
||||
/// Maximum inter-token latency
|
||||
pub const MAX_INTER_TOKEN: u64 = 100;
|
||||
|
||||
/// Target inter-token latency
|
||||
pub const TARGET_INTER_TOKEN: u64 = 20;
|
||||
}
|
||||
|
||||
/// Memory thresholds (bytes)
|
||||
pub mod memory {
|
||||
/// Maximum model memory (Q4_K)
|
||||
pub const MAX_MODEL_Q4K: usize = 1_500_000_000;
|
||||
|
||||
/// Maximum KV cache memory
|
||||
pub const MAX_KV_CACHE: usize = 500_000_000;
|
||||
|
||||
/// Maximum working memory
|
||||
pub const MAX_WORKING: usize = 200_000_000;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Test Data Generators
|
||||
// ============================================================================
|
||||
|
||||
/// Generate a long prompt of specified length
|
||||
pub fn generate_long_prompt(word_count: usize) -> String {
|
||||
let words = [
|
||||
"the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog",
|
||||
"and", "then", "runs", "around", "park", "with", "great", "joy"
|
||||
];
|
||||
|
||||
(0..word_count)
|
||||
.map(|i| words[i % words.len()])
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}
|
||||
|
||||
/// Generate a sequence of numbers for pattern completion tests
|
||||
pub fn generate_number_sequence(start: i32, count: usize) -> String {
|
||||
(start..start + count as i32)
|
||||
.map(|n| n.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
}
|
||||
|
||||
/// Generate a repeated pattern prompt
|
||||
pub fn generate_repetition_prompt(word: &str, count: usize) -> String {
|
||||
vec![word; count].join(" ")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Validation Helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Check if output contains any of the expected patterns
|
||||
pub fn contains_expected_pattern(output: &str, patterns: &[&str]) -> bool {
|
||||
let output_lower = output.to_lowercase();
|
||||
patterns.iter().any(|p| output_lower.contains(&p.to_lowercase()))
|
||||
}
|
||||
|
||||
/// Calculate repetition ratio for n-grams
|
||||
pub fn calculate_ngram_repetition(text: &str, n: usize) -> f32 {
|
||||
let words: Vec<&str> = text.split_whitespace().collect();
|
||||
if words.len() < n {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let total_ngrams = words.len() - n + 1;
|
||||
let mut ngram_counts: HashMap<Vec<&str>, usize> = HashMap::new();
|
||||
|
||||
for window in words.windows(n) {
|
||||
*ngram_counts.entry(window.to_vec()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
let repeated = ngram_counts.values().filter(|&&c| c > 1).sum::<usize>();
|
||||
repeated as f32 / total_ngrams as f32
|
||||
}
|
||||
|
||||
/// Count consecutive word repetitions
|
||||
pub fn count_consecutive_repeats(text: &str) -> usize {
|
||||
let words: Vec<&str> = text.split_whitespace().collect();
|
||||
let mut max_repeats = 0;
|
||||
let mut current_repeats = 0;
|
||||
|
||||
for i in 1..words.len() {
|
||||
if words[i] == words[i - 1] {
|
||||
current_repeats += 1;
|
||||
max_repeats = max_repeats.max(current_repeats);
|
||||
} else {
|
||||
current_repeats = 0;
|
||||
}
|
||||
}
|
||||
|
||||
max_repeats
|
||||
}
|
||||
|
||||
/// Calculate alphanumeric ratio
|
||||
pub fn alphanumeric_ratio(text: &str) -> f32 {
|
||||
if text.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let alphanumeric = text.chars()
|
||||
.filter(|c| c.is_alphanumeric())
|
||||
.count();
|
||||
|
||||
alphanumeric as f32 / text.len() as f32
|
||||
}
|
||||
|
||||
/// Check if text passes basic coherence checks
|
||||
pub fn is_coherent(text: &str) -> bool {
|
||||
// Check alphanumeric ratio
|
||||
if alphanumeric_ratio(text) < coherence::MIN_ALPHANUMERIC_RATIO {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check repetition
|
||||
if count_consecutive_repeats(text) > coherence::MAX_CONSECUTIVE_REPEATS {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check n-gram repetition
|
||||
if calculate_ngram_repetition(text, 3) > coherence::MAX_NGRAM_REPETITION {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests for Fixtures Module
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_long_prompt() {
|
||||
let prompt = generate_long_prompt(100);
|
||||
let word_count = prompt.split_whitespace().count();
|
||||
assert_eq!(word_count, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_number_sequence() {
|
||||
let seq = generate_number_sequence(1, 5);
|
||||
assert_eq!(seq, "1, 2, 3, 4, 5");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_contains_expected_pattern() {
|
||||
let output = "The fox jumps over the lazy dog";
|
||||
assert!(contains_expected_pattern(output, expected_patterns::FOX_COMPLETION));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ngram_repetition() {
|
||||
let no_repeat = "the quick brown fox jumps over";
|
||||
assert!(calculate_ngram_repetition(no_repeat, 2) < 0.1);
|
||||
|
||||
let high_repeat = "the the the the the the";
|
||||
assert!(calculate_ngram_repetition(high_repeat, 2) > 0.5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_consecutive_repeats() {
|
||||
assert_eq!(count_consecutive_repeats("hello world"), 0);
|
||||
assert_eq!(count_consecutive_repeats("hello hello world"), 1);
|
||||
assert_eq!(count_consecutive_repeats("hello hello hello"), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_alphanumeric_ratio() {
|
||||
assert!(alphanumeric_ratio("Hello World") > 0.8);
|
||||
assert!(alphanumeric_ratio("!@#$%^&*()") < 0.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coherence_check() {
|
||||
assert!(is_coherent("The quick brown fox jumps over the lazy dog."));
|
||||
assert!(!is_coherent("!@#$%^&*()!@#$%^&*()!@#$%^&*()"));
|
||||
assert!(!is_coherent("the the the the the the the"));
|
||||
}
|
||||
}
|
||||
161
crates/ruvllm/tests/fixtures/perplexity_baselines.json
vendored
Normal file
161
crates/ruvllm/tests/fixtures/perplexity_baselines.json
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
{
|
||||
"metadata": {
|
||||
"version": "1.0.0",
|
||||
"description": "Perplexity baselines for RuvLTRA-Small quality validation",
|
||||
"model": "ruvltra-small",
|
||||
"quantization_tested": ["Q4_K", "Q5_K", "Q8_0", "F16"],
|
||||
"last_updated": "2024-01-19"
|
||||
},
|
||||
"quality_thresholds": {
|
||||
"max_acceptable_perplexity": 50.0,
|
||||
"warning_perplexity": 30.0,
|
||||
"excellent_perplexity": 15.0,
|
||||
"notes": "Perplexity values vary by dataset and prompt type"
|
||||
},
|
||||
"baselines": {
|
||||
"wikitext": {
|
||||
"description": "WikiText-2 test set perplexity",
|
||||
"dataset_url": "https://huggingface.co/datasets/wikitext",
|
||||
"values": {
|
||||
"F16": {
|
||||
"perplexity": 8.5,
|
||||
"tokens_evaluated": 250000,
|
||||
"notes": "Full precision baseline"
|
||||
},
|
||||
"Q8_0": {
|
||||
"perplexity": 8.7,
|
||||
"degradation_pct": 2.4,
|
||||
"notes": "8-bit quantization, minimal quality loss"
|
||||
},
|
||||
"Q5_K": {
|
||||
"perplexity": 9.2,
|
||||
"degradation_pct": 8.2,
|
||||
"notes": "5-bit k-quant, good balance"
|
||||
},
|
||||
"Q4_K": {
|
||||
"perplexity": 9.8,
|
||||
"degradation_pct": 15.3,
|
||||
"notes": "4-bit k-quant, most common deployment format"
|
||||
},
|
||||
"Q2_K": {
|
||||
"perplexity": 14.5,
|
||||
"degradation_pct": 70.6,
|
||||
"notes": "2-bit extreme quantization, noticeable degradation"
|
||||
}
|
||||
}
|
||||
},
|
||||
"lambada": {
|
||||
"description": "LAMBADA last-word prediction accuracy",
|
||||
"metric": "accuracy",
|
||||
"values": {
|
||||
"F16": {
|
||||
"accuracy": 0.72,
|
||||
"notes": "Full precision accuracy"
|
||||
},
|
||||
"Q4_K": {
|
||||
"accuracy": 0.68,
|
||||
"degradation_pct": 5.6,
|
||||
"notes": "Slight accuracy drop acceptable"
|
||||
}
|
||||
}
|
||||
},
|
||||
"hellaswag": {
|
||||
"description": "HellaSwag commonsense reasoning",
|
||||
"metric": "accuracy",
|
||||
"values": {
|
||||
"F16": {
|
||||
"accuracy": 0.68
|
||||
},
|
||||
"Q4_K": {
|
||||
"accuracy": 0.65,
|
||||
"degradation_pct": 4.4
|
||||
}
|
||||
}
|
||||
},
|
||||
"custom_prompts": {
|
||||
"description": "Perplexity on custom test prompts",
|
||||
"values": {
|
||||
"simple_completion": {
|
||||
"expected_ppl_range": [5.0, 20.0],
|
||||
"notes": "Common phrase continuation should have low perplexity"
|
||||
},
|
||||
"code_generation": {
|
||||
"expected_ppl_range": [8.0, 30.0],
|
||||
"notes": "Code has higher entropy but should still be coherent"
|
||||
},
|
||||
"creative_writing": {
|
||||
"expected_ppl_range": [15.0, 45.0],
|
||||
"notes": "Creative tasks have higher acceptable perplexity"
|
||||
},
|
||||
"factual_qa": {
|
||||
"expected_ppl_range": [3.0, 15.0],
|
||||
"notes": "Factual responses should be confident"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"degradation_limits": {
|
||||
"max_perplexity_increase_pct": 20.0,
|
||||
"max_accuracy_decrease_pct": 10.0,
|
||||
"notes": "Quantization should not degrade quality beyond these limits"
|
||||
},
|
||||
"token_probability_thresholds": {
|
||||
"min_top1_probability": 0.01,
|
||||
"min_top5_cumulative": 0.1,
|
||||
"max_entropy": 10.0,
|
||||
"notes": "Thresholds for detecting garbled or degenerate output"
|
||||
},
|
||||
"repetition_metrics": {
|
||||
"max_ngram_repetition_ratio": 0.3,
|
||||
"max_consecutive_repeats": 3,
|
||||
"ngram_window_sizes": [2, 3, 4],
|
||||
"notes": "Detect excessive repetition in generated text"
|
||||
},
|
||||
"coherence_metrics": {
|
||||
"min_sentence_length": 3,
|
||||
"max_sentence_length": 200,
|
||||
"punctuation_ratio_range": [0.01, 0.15],
|
||||
"alphanumeric_ratio_min": 0.7,
|
||||
"notes": "Basic structural coherence checks"
|
||||
},
|
||||
"speed_baselines": {
|
||||
"description": "Token generation speed baselines (tokens/second)",
|
||||
"device_baselines": {
|
||||
"m4_pro_ane": {
|
||||
"prompt_processing": 2000,
|
||||
"generation": 60,
|
||||
"notes": "M4 Pro with ANE acceleration"
|
||||
},
|
||||
"m4_pro_neon": {
|
||||
"prompt_processing": 1500,
|
||||
"generation": 45,
|
||||
"notes": "M4 Pro NEON-only fallback"
|
||||
},
|
||||
"m1_ane": {
|
||||
"prompt_processing": 1200,
|
||||
"generation": 40,
|
||||
"notes": "M1 with ANE"
|
||||
},
|
||||
"cpu_x86": {
|
||||
"prompt_processing": 500,
|
||||
"generation": 15,
|
||||
"notes": "x86 CPU baseline (AVX2)"
|
||||
}
|
||||
}
|
||||
},
|
||||
"memory_baselines": {
|
||||
"model_sizes_mb": {
|
||||
"F16": 4000,
|
||||
"Q8_0": 2200,
|
||||
"Q4_K": 1200,
|
||||
"Q2_K": 700
|
||||
},
|
||||
"kv_cache_per_token_bytes": {
|
||||
"F16": 1100,
|
||||
"Q8_0": 1100,
|
||||
"notes": "KV cache typically stays in F16 for accuracy"
|
||||
},
|
||||
"peak_memory_multiplier": 1.5,
|
||||
"notes": "Peak memory = model_size * multiplier during inference"
|
||||
}
|
||||
}
|
||||
191
crates/ruvllm/tests/fixtures/test_prompts.json
vendored
Normal file
191
crates/ruvllm/tests/fixtures/test_prompts.json
vendored
Normal file
@@ -0,0 +1,191 @@
|
||||
{
|
||||
"metadata": {
|
||||
"version": "1.0.0",
|
||||
"description": "Test prompts for RuvLTRA-Small validation",
|
||||
"model": "ruvltra-small",
|
||||
"last_updated": "2024-01-19"
|
||||
},
|
||||
"prompts": {
|
||||
"simple_completion": {
|
||||
"id": "simple_001",
|
||||
"category": "completion",
|
||||
"prompt": "The quick brown fox",
|
||||
"expected_patterns": ["jumps", "jumped", "runs", "ran", "over", "lazy"],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.7,
|
||||
"notes": "Classic completion test for basic language modeling"
|
||||
},
|
||||
"instruction_haiku": {
|
||||
"id": "instruction_001",
|
||||
"category": "instruction",
|
||||
"prompt": "Write a haiku about programming:",
|
||||
"expected_patterns": ["code", "bug", "compile", "debug", "screen", "night", "lines", "function"],
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.8,
|
||||
"notes": "Tests instruction-following ability"
|
||||
},
|
||||
"qa_capital": {
|
||||
"id": "qa_001",
|
||||
"category": "question_answering",
|
||||
"prompt": "Q: What is the capital of France?\nA:",
|
||||
"expected_output": "Paris",
|
||||
"max_tokens": 20,
|
||||
"temperature": 0.1,
|
||||
"notes": "Simple factual QA with deterministic expected output"
|
||||
},
|
||||
"qa_math": {
|
||||
"id": "qa_002",
|
||||
"category": "question_answering",
|
||||
"prompt": "Q: What is 2 + 2?\nA:",
|
||||
"expected_output": "4",
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.0,
|
||||
"notes": "Simple math QA"
|
||||
},
|
||||
"code_fibonacci": {
|
||||
"id": "code_001",
|
||||
"category": "code_generation",
|
||||
"prompt": "def fibonacci(n):\n '''Return the nth Fibonacci number.'''\n",
|
||||
"expected_patterns": ["return", "if", "else", "n", "<=", "1", "+", "fibonacci"],
|
||||
"max_tokens": 150,
|
||||
"temperature": 0.3,
|
||||
"notes": "Code generation with expected structural patterns"
|
||||
},
|
||||
"code_hello_world": {
|
||||
"id": "code_002",
|
||||
"category": "code_generation",
|
||||
"prompt": "# Python function to print hello world\ndef",
|
||||
"expected_patterns": ["print", "hello", "world", "def"],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.2,
|
||||
"notes": "Simple code generation"
|
||||
},
|
||||
"conversation_greeting": {
|
||||
"id": "conv_001",
|
||||
"category": "conversation",
|
||||
"prompt": "User: Hello!\nAssistant:",
|
||||
"expected_patterns": ["hello", "hi", "how", "help", "can", "assist"],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.7,
|
||||
"notes": "Basic conversation response"
|
||||
},
|
||||
"conversation_joke": {
|
||||
"id": "conv_002",
|
||||
"category": "conversation",
|
||||
"prompt": "User: Tell me a joke.\nAssistant:",
|
||||
"expected_patterns": ["why", "what", "because", "knock", "chicken"],
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.9,
|
||||
"notes": "Creative response generation"
|
||||
},
|
||||
"summarization": {
|
||||
"id": "summary_001",
|
||||
"category": "summarization",
|
||||
"prompt": "Summarize the following in one sentence:\nMachine learning is a subset of artificial intelligence that enables systems to learn and improve from experience without being explicitly programmed.\nSummary:",
|
||||
"expected_patterns": ["machine learning", "AI", "artificial intelligence", "learn", "data"],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.3,
|
||||
"notes": "Tests summarization capability"
|
||||
},
|
||||
"translation": {
|
||||
"id": "translation_001",
|
||||
"category": "translation",
|
||||
"prompt": "Translate to French: Hello, how are you?\nFrench:",
|
||||
"expected_patterns": ["bonjour", "comment", "allez", "vous"],
|
||||
"max_tokens": 30,
|
||||
"temperature": 0.1,
|
||||
"notes": "Basic translation test"
|
||||
},
|
||||
"sentiment": {
|
||||
"id": "sentiment_001",
|
||||
"category": "classification",
|
||||
"prompt": "Classify the sentiment of this review as positive, negative, or neutral:\n\"This product is amazing! Best purchase I've ever made.\"\nSentiment:",
|
||||
"expected_output": "positive",
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.0,
|
||||
"notes": "Sentiment classification"
|
||||
},
|
||||
"reasoning_chain": {
|
||||
"id": "reasoning_001",
|
||||
"category": "reasoning",
|
||||
"prompt": "Question: If I have 3 apples and give away 1, how many do I have left?\nLet's think step by step:",
|
||||
"expected_patterns": ["3", "1", "2", "subtract", "minus", "left", "remaining"],
|
||||
"max_tokens": 100,
|
||||
"temperature": 0.1,
|
||||
"notes": "Chain-of-thought reasoning"
|
||||
}
|
||||
},
|
||||
"edge_cases": {
|
||||
"empty_prompt": {
|
||||
"id": "edge_001",
|
||||
"prompt": "",
|
||||
"expected_behavior": "Should handle gracefully, may produce empty output or generic response",
|
||||
"max_tokens": 20
|
||||
},
|
||||
"single_char": {
|
||||
"id": "edge_002",
|
||||
"prompt": "A",
|
||||
"expected_behavior": "Should produce coherent completion",
|
||||
"max_tokens": 30
|
||||
},
|
||||
"special_characters": {
|
||||
"id": "edge_003",
|
||||
"prompt": "Translate: \"Hello, world!\" ->",
|
||||
"expected_behavior": "Should handle quotes and punctuation correctly",
|
||||
"max_tokens": 30
|
||||
},
|
||||
"very_long_prompt": {
|
||||
"id": "edge_004",
|
||||
"prompt": "The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. Continue:",
|
||||
"expected_behavior": "Should handle long context without issues",
|
||||
"max_tokens": 50
|
||||
},
|
||||
"unicode": {
|
||||
"id": "edge_005",
|
||||
"prompt": "Translate to English: \u4f60\u597d\u4e16\u754c",
|
||||
"expected_patterns": ["hello", "world"],
|
||||
"max_tokens": 20
|
||||
},
|
||||
"mixed_language": {
|
||||
"id": "edge_006",
|
||||
"prompt": "English and \u65e5\u672c\u8a9e mixed:",
|
||||
"expected_behavior": "Should handle multilingual input",
|
||||
"max_tokens": 50
|
||||
},
|
||||
"numbers": {
|
||||
"id": "edge_007",
|
||||
"prompt": "Continue the sequence: 1, 2, 3, 4,",
|
||||
"expected_patterns": ["5", "6", "7"],
|
||||
"max_tokens": 20
|
||||
},
|
||||
"repetitive": {
|
||||
"id": "edge_008",
|
||||
"prompt": "Hello hello hello hello hello",
|
||||
"expected_behavior": "Should not amplify repetition excessively",
|
||||
"max_tokens": 30
|
||||
}
|
||||
},
|
||||
"stress_tests": {
|
||||
"max_context": {
|
||||
"id": "stress_001",
|
||||
"description": "Test with maximum context length",
|
||||
"prompt_length": 8192,
|
||||
"max_tokens": 100,
|
||||
"notes": "Generate prompt programmatically to fill context"
|
||||
},
|
||||
"long_generation": {
|
||||
"id": "stress_002",
|
||||
"description": "Generate many tokens",
|
||||
"prompt": "Once upon a time",
|
||||
"max_tokens": 2000,
|
||||
"notes": "Test stability over long generation"
|
||||
},
|
||||
"rapid_requests": {
|
||||
"id": "stress_003",
|
||||
"description": "Many rapid sequential requests",
|
||||
"num_requests": 100,
|
||||
"prompt": "Hello",
|
||||
"max_tokens": 10
|
||||
}
|
||||
}
|
||||
}
|
||||
1115
crates/ruvllm/tests/gguf_integration.rs
Normal file
1115
crates/ruvllm/tests/gguf_integration.rs
Normal file
File diff suppressed because it is too large
Load Diff
739
crates/ruvllm/tests/gguf_loader_test.rs
Normal file
739
crates/ruvllm/tests/gguf_loader_test.rs
Normal file
@@ -0,0 +1,739 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! GGUF Loader Integration Tests
|
||||
//!
|
||||
//! Tests for the new GGUF model loading system including:
|
||||
//! - Tensor name mapping for different architectures
|
||||
//! - Progress tracking during loading
|
||||
//! - Layer weight organization
|
||||
//! - Streaming loader for large models
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ============================================================================
|
||||
// TensorNameMapper Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Simulated tensor name mapper for testing (mirrors the real implementation)
|
||||
struct TestTensorNameMapper {
|
||||
architecture: &'static str,
|
||||
}
|
||||
|
||||
impl TestTensorNameMapper {
|
||||
fn new(architecture: &'static str) -> Self {
|
||||
Self { architecture }
|
||||
}
|
||||
|
||||
fn extract_layer_index(&self, name: &str) -> Option<usize> {
|
||||
for pattern in &["layers.", "h.", "blocks.", "block."] {
|
||||
if let Some(pos) = name.find(pattern) {
|
||||
let after = &name[pos + pattern.len()..];
|
||||
if let Some(end) = after.find('.') {
|
||||
if let Ok(idx) = after[..end].parse() {
|
||||
return Some(idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn categorize(&self, name: &str) -> &'static str {
|
||||
let lower = name.to_lowercase();
|
||||
|
||||
if lower.contains("embed") || (lower.contains("token") && lower.contains("weight")) {
|
||||
if lower.contains("output") || lower.contains("lm_head") {
|
||||
return "OutputHead";
|
||||
}
|
||||
return "Embedding";
|
||||
}
|
||||
|
||||
if lower.contains("lm_head") || (lower.contains("output") && !lower.contains("attn")) {
|
||||
return "OutputHead";
|
||||
}
|
||||
|
||||
if lower.contains("attn") || lower.contains("attention") {
|
||||
if lower.contains("q_proj") || lower.contains(".wq.") || lower.contains("query") {
|
||||
return "AttentionQuery";
|
||||
}
|
||||
if lower.contains("k_proj") || lower.contains(".wk.") || lower.contains("key") {
|
||||
return "AttentionKey";
|
||||
}
|
||||
if lower.contains("v_proj") || lower.contains(".wv.") || lower.contains("value") {
|
||||
return "AttentionValue";
|
||||
}
|
||||
if lower.contains("o_proj") || lower.contains(".wo.") || lower.contains("out_proj") {
|
||||
return "AttentionOutput";
|
||||
}
|
||||
}
|
||||
|
||||
if lower.contains("mlp") || lower.contains("ffn") || lower.contains("feed_forward") {
|
||||
if lower.contains("gate") || lower.contains(".w1.") {
|
||||
return "FfnGate";
|
||||
}
|
||||
if lower.contains("up") || lower.contains(".w3.") {
|
||||
return "FfnUp";
|
||||
}
|
||||
if lower.contains("down") || lower.contains(".w2.") {
|
||||
return "FfnDown";
|
||||
}
|
||||
}
|
||||
|
||||
if lower.contains("norm") || lower.contains("ln_") || lower.contains("layer_norm") {
|
||||
if lower.contains("final") || lower.contains("model.norm") || !lower.contains("layers")
|
||||
{
|
||||
return "FinalNorm";
|
||||
}
|
||||
return "LayerNorm";
|
||||
}
|
||||
|
||||
"Other"
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llama_tensor_name_mapping() {
|
||||
let mapper = TestTensorNameMapper::new("llama");
|
||||
|
||||
// Test layer extraction
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("model.layers.0.self_attn.q_proj.weight"),
|
||||
Some(0)
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("model.layers.31.mlp.gate_proj.weight"),
|
||||
Some(31)
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("model.embed_tokens.weight"),
|
||||
None
|
||||
);
|
||||
assert_eq!(mapper.extract_layer_index("lm_head.weight"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_phi_tensor_name_mapping() {
|
||||
let mapper = TestTensorNameMapper::new("phi");
|
||||
|
||||
// Phi uses transformer.h.N pattern
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("transformer.h.0.mixer.Wqkv.weight"),
|
||||
Some(0)
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("transformer.h.15.mlp.fc1.weight"),
|
||||
Some(15)
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("transformer.embd.wte.weight"),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qwen_tensor_name_mapping() {
|
||||
let mapper = TestTensorNameMapper::new("qwen");
|
||||
|
||||
// Qwen uses transformer.h.N pattern like GPT-2
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("transformer.h.0.attn.c_attn.weight"),
|
||||
Some(0)
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.extract_layer_index("transformer.h.23.mlp.w1.weight"),
|
||||
Some(23)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_categorization_attention() {
|
||||
let mapper = TestTensorNameMapper::new("llama");
|
||||
|
||||
assert_eq!(
|
||||
mapper.categorize("model.layers.0.self_attn.q_proj.weight"),
|
||||
"AttentionQuery"
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.categorize("model.layers.0.self_attn.k_proj.weight"),
|
||||
"AttentionKey"
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.categorize("model.layers.0.self_attn.v_proj.weight"),
|
||||
"AttentionValue"
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.categorize("model.layers.0.self_attn.o_proj.weight"),
|
||||
"AttentionOutput"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_categorization_mlp() {
|
||||
let mapper = TestTensorNameMapper::new("llama");
|
||||
|
||||
assert_eq!(
|
||||
mapper.categorize("model.layers.0.mlp.gate_proj.weight"),
|
||||
"FfnGate"
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.categorize("model.layers.0.mlp.up_proj.weight"),
|
||||
"FfnUp"
|
||||
);
|
||||
assert_eq!(
|
||||
mapper.categorize("model.layers.0.mlp.down_proj.weight"),
|
||||
"FfnDown"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_categorization_embedding() {
|
||||
let mapper = TestTensorNameMapper::new("llama");
|
||||
|
||||
assert_eq!(mapper.categorize("model.embed_tokens.weight"), "Embedding");
|
||||
assert_eq!(mapper.categorize("lm_head.weight"), "OutputHead");
|
||||
assert_eq!(mapper.categorize("model.norm.weight"), "FinalNorm");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LoadProgress Tests
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestLoadProgress {
|
||||
total_tensors: usize,
|
||||
loaded_tensors: usize,
|
||||
total_bytes: usize,
|
||||
loaded_bytes: usize,
|
||||
}
|
||||
|
||||
impl TestLoadProgress {
|
||||
fn percent(&self) -> f32 {
|
||||
if self.total_tensors == 0 {
|
||||
return 100.0;
|
||||
}
|
||||
(self.loaded_tensors as f32 / self.total_tensors as f32) * 100.0
|
||||
}
|
||||
|
||||
fn byte_percent(&self) -> f32 {
|
||||
if self.total_bytes == 0 {
|
||||
return 100.0;
|
||||
}
|
||||
(self.loaded_bytes as f32 / self.total_bytes as f32) * 100.0
|
||||
}
|
||||
|
||||
fn is_complete(&self) -> bool {
|
||||
self.loaded_tensors >= self.total_tensors
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_progress_calculation() {
|
||||
let progress = TestLoadProgress {
|
||||
total_tensors: 100,
|
||||
loaded_tensors: 25,
|
||||
total_bytes: 1_000_000,
|
||||
loaded_bytes: 250_000,
|
||||
};
|
||||
|
||||
assert!((progress.percent() - 25.0).abs() < 0.001);
|
||||
assert!((progress.byte_percent() - 25.0).abs() < 0.001);
|
||||
assert!(!progress.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_progress_complete() {
|
||||
let progress = TestLoadProgress {
|
||||
total_tensors: 50,
|
||||
loaded_tensors: 50,
|
||||
total_bytes: 500_000,
|
||||
loaded_bytes: 500_000,
|
||||
};
|
||||
|
||||
assert!((progress.percent() - 100.0).abs() < 0.001);
|
||||
assert!(progress.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_progress_empty() {
|
||||
let progress = TestLoadProgress {
|
||||
total_tensors: 0,
|
||||
loaded_tensors: 0,
|
||||
total_bytes: 0,
|
||||
loaded_bytes: 0,
|
||||
};
|
||||
|
||||
// Empty should be considered complete
|
||||
assert!((progress.percent() - 100.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LoadConfig Tests
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestLoadConfig {
|
||||
use_mmap: bool,
|
||||
keep_quantized: bool,
|
||||
tensor_filter: Vec<String>,
|
||||
layer_filter: Vec<usize>,
|
||||
num_threads: usize,
|
||||
}
|
||||
|
||||
impl TestLoadConfig {
|
||||
fn with_mmap(mut self, enabled: bool) -> Self {
|
||||
self.use_mmap = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_quantized(mut self, keep: bool) -> Self {
|
||||
self.keep_quantized = keep;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_tensor_filter(mut self, tensors: Vec<String>) -> Self {
|
||||
self.tensor_filter = tensors;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_layer_filter(mut self, layers: Vec<usize>) -> Self {
|
||||
self.layer_filter = layers;
|
||||
self
|
||||
}
|
||||
|
||||
fn with_threads(mut self, threads: usize) -> Self {
|
||||
self.num_threads = threads;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_config_builder() {
|
||||
let config = TestLoadConfig::default()
|
||||
.with_mmap(true)
|
||||
.with_quantized(true)
|
||||
.with_threads(8)
|
||||
.with_layer_filter(vec![0, 1, 2, 3])
|
||||
.with_tensor_filter(vec!["attention".to_string()]);
|
||||
|
||||
assert!(config.use_mmap);
|
||||
assert!(config.keep_quantized);
|
||||
assert_eq!(config.num_threads, 8);
|
||||
assert_eq!(config.layer_filter, vec![0, 1, 2, 3]);
|
||||
assert_eq!(config.tensor_filter, vec!["attention".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_config_defaults() {
|
||||
let config = TestLoadConfig::default();
|
||||
|
||||
assert!(!config.use_mmap);
|
||||
assert!(!config.keep_quantized);
|
||||
assert_eq!(config.num_threads, 0);
|
||||
assert!(config.layer_filter.is_empty());
|
||||
assert!(config.tensor_filter.is_empty());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Architecture-Specific Tensor Mapping Tests
|
||||
// ============================================================================
|
||||
|
||||
struct ArchitectureTensorMap {
|
||||
embed_tokens: &'static str,
|
||||
q_proj_pattern: &'static str,
|
||||
k_proj_pattern: &'static str,
|
||||
v_proj_pattern: &'static str,
|
||||
o_proj_pattern: &'static str,
|
||||
gate_proj_pattern: &'static str,
|
||||
up_proj_pattern: &'static str,
|
||||
down_proj_pattern: &'static str,
|
||||
final_norm: &'static str,
|
||||
lm_head: &'static str,
|
||||
}
|
||||
|
||||
impl ArchitectureTensorMap {
|
||||
fn llama() -> Self {
|
||||
Self {
|
||||
embed_tokens: "model.embed_tokens.weight",
|
||||
q_proj_pattern: "model.layers.{}.self_attn.q_proj.weight",
|
||||
k_proj_pattern: "model.layers.{}.self_attn.k_proj.weight",
|
||||
v_proj_pattern: "model.layers.{}.self_attn.v_proj.weight",
|
||||
o_proj_pattern: "model.layers.{}.self_attn.o_proj.weight",
|
||||
gate_proj_pattern: "model.layers.{}.mlp.gate_proj.weight",
|
||||
up_proj_pattern: "model.layers.{}.mlp.up_proj.weight",
|
||||
down_proj_pattern: "model.layers.{}.mlp.down_proj.weight",
|
||||
final_norm: "model.norm.weight",
|
||||
lm_head: "lm_head.weight",
|
||||
}
|
||||
}
|
||||
|
||||
fn mistral() -> Self {
|
||||
// Mistral uses same naming as Llama
|
||||
Self::llama()
|
||||
}
|
||||
|
||||
fn phi() -> Self {
|
||||
Self {
|
||||
embed_tokens: "transformer.embd.wte.weight",
|
||||
q_proj_pattern: "transformer.h.{}.mixer.Wqkv.weight",
|
||||
k_proj_pattern: "transformer.h.{}.mixer.Wqkv.weight",
|
||||
v_proj_pattern: "transformer.h.{}.mixer.Wqkv.weight",
|
||||
o_proj_pattern: "transformer.h.{}.mixer.out_proj.weight",
|
||||
gate_proj_pattern: "transformer.h.{}.mlp.fc1.weight",
|
||||
up_proj_pattern: "transformer.h.{}.mlp.fc1.weight",
|
||||
down_proj_pattern: "transformer.h.{}.mlp.fc2.weight",
|
||||
final_norm: "transformer.ln_f.weight",
|
||||
lm_head: "lm_head.weight",
|
||||
}
|
||||
}
|
||||
|
||||
fn gemma() -> Self {
|
||||
Self {
|
||||
embed_tokens: "model.embed_tokens.weight",
|
||||
q_proj_pattern: "model.layers.{}.self_attn.q_proj.weight",
|
||||
k_proj_pattern: "model.layers.{}.self_attn.k_proj.weight",
|
||||
v_proj_pattern: "model.layers.{}.self_attn.v_proj.weight",
|
||||
o_proj_pattern: "model.layers.{}.self_attn.o_proj.weight",
|
||||
gate_proj_pattern: "model.layers.{}.mlp.gate_proj.weight",
|
||||
up_proj_pattern: "model.layers.{}.mlp.up_proj.weight",
|
||||
down_proj_pattern: "model.layers.{}.mlp.down_proj.weight",
|
||||
final_norm: "model.norm.weight",
|
||||
lm_head: "model.embed_tokens.weight", // Tied embeddings
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_tensor(&self, pattern: &str, layer: usize) -> String {
|
||||
pattern.replace("{}", &layer.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llama_tensor_patterns() {
|
||||
let map = ArchitectureTensorMap::llama();
|
||||
|
||||
assert_eq!(
|
||||
map.layer_tensor(map.q_proj_pattern, 0),
|
||||
"model.layers.0.self_attn.q_proj.weight"
|
||||
);
|
||||
assert_eq!(
|
||||
map.layer_tensor(map.gate_proj_pattern, 15),
|
||||
"model.layers.15.mlp.gate_proj.weight"
|
||||
);
|
||||
assert_eq!(
|
||||
map.layer_tensor(map.down_proj_pattern, 31),
|
||||
"model.layers.31.mlp.down_proj.weight"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_phi_tensor_patterns() {
|
||||
let map = ArchitectureTensorMap::phi();
|
||||
|
||||
assert_eq!(
|
||||
map.layer_tensor(map.q_proj_pattern, 0),
|
||||
"transformer.h.0.mixer.Wqkv.weight"
|
||||
);
|
||||
assert_eq!(
|
||||
map.layer_tensor(map.o_proj_pattern, 7),
|
||||
"transformer.h.7.mixer.out_proj.weight"
|
||||
);
|
||||
assert_eq!(
|
||||
map.layer_tensor(map.down_proj_pattern, 23),
|
||||
"transformer.h.23.mlp.fc2.weight"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gemma_tied_embeddings() {
|
||||
let map = ArchitectureTensorMap::gemma();
|
||||
|
||||
// Gemma ties lm_head to embed_tokens
|
||||
assert_eq!(map.embed_tokens, map.lm_head);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Weight Tensor Tests
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Clone)]
|
||||
enum TestWeightTensor {
|
||||
F32(Vec<f32>, Vec<usize>),
|
||||
Quantized {
|
||||
data: Vec<u8>,
|
||||
quant_type: u32,
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
}
|
||||
|
||||
impl TestWeightTensor {
|
||||
fn shape(&self) -> &[usize] {
|
||||
match self {
|
||||
TestWeightTensor::F32(_, shape) => shape,
|
||||
TestWeightTensor::Quantized { shape, .. } => shape,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_quantized(&self) -> bool {
|
||||
matches!(self, TestWeightTensor::Quantized { .. })
|
||||
}
|
||||
|
||||
fn memory_bytes(&self) -> usize {
|
||||
match self {
|
||||
TestWeightTensor::F32(data, _) => data.len() * 4,
|
||||
TestWeightTensor::Quantized { data, .. } => data.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weight_tensor_f32() {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let shape = vec![2, 3];
|
||||
let tensor = TestWeightTensor::F32(data.clone(), shape.clone());
|
||||
|
||||
assert!(!tensor.is_quantized());
|
||||
assert_eq!(tensor.shape(), &[2, 3]);
|
||||
assert_eq!(tensor.memory_bytes(), 24); // 6 floats * 4 bytes
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weight_tensor_quantized() {
|
||||
let data = vec![0u8; 18]; // One Q4_0 block (2 bytes scale + 16 bytes data)
|
||||
let tensor = TestWeightTensor::Quantized {
|
||||
data: data.clone(),
|
||||
quant_type: 2, // Q4_0
|
||||
shape: vec![32],
|
||||
};
|
||||
|
||||
assert!(tensor.is_quantized());
|
||||
assert_eq!(tensor.shape(), &[32]);
|
||||
assert_eq!(tensor.memory_bytes(), 18);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Loader Simulation Tests
|
||||
// ============================================================================
|
||||
|
||||
struct TestStreamingLoader {
|
||||
total_layers: usize,
|
||||
current_layer: usize,
|
||||
}
|
||||
|
||||
impl TestStreamingLoader {
|
||||
fn new(total_layers: usize) -> Self {
|
||||
Self {
|
||||
total_layers,
|
||||
current_layer: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn has_more_layers(&self) -> bool {
|
||||
self.current_layer < self.total_layers
|
||||
}
|
||||
|
||||
fn load_next_layer(&mut self) -> Option<usize> {
|
||||
if self.current_layer >= self.total_layers {
|
||||
return None;
|
||||
}
|
||||
let layer = self.current_layer;
|
||||
self.current_layer += 1;
|
||||
Some(layer)
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.current_layer = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_loader_basic() {
|
||||
let mut loader = TestStreamingLoader::new(32);
|
||||
|
||||
assert!(loader.has_more_layers());
|
||||
assert_eq!(loader.load_next_layer(), Some(0));
|
||||
assert_eq!(loader.load_next_layer(), Some(1));
|
||||
assert!(loader.has_more_layers());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_loader_exhaust() {
|
||||
let mut loader = TestStreamingLoader::new(3);
|
||||
|
||||
assert_eq!(loader.load_next_layer(), Some(0));
|
||||
assert_eq!(loader.load_next_layer(), Some(1));
|
||||
assert_eq!(loader.load_next_layer(), Some(2));
|
||||
assert!(!loader.has_more_layers());
|
||||
assert_eq!(loader.load_next_layer(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_loader_reset() {
|
||||
let mut loader = TestStreamingLoader::new(5);
|
||||
|
||||
// Load some layers
|
||||
loader.load_next_layer();
|
||||
loader.load_next_layer();
|
||||
|
||||
// Reset
|
||||
loader.reset();
|
||||
|
||||
// Should start from beginning
|
||||
assert_eq!(loader.load_next_layer(), Some(0));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Model Configuration Tests
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct TestModelConfig {
|
||||
architecture: Option<String>,
|
||||
context_length: Option<usize>,
|
||||
embedding_length: Option<usize>,
|
||||
head_count: Option<usize>,
|
||||
head_count_kv: Option<usize>,
|
||||
layer_count: Option<usize>,
|
||||
vocab_size: Option<usize>,
|
||||
rope_freq_base: Option<f32>,
|
||||
feed_forward_length: Option<usize>,
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_config_llama_7b() {
|
||||
let config = TestModelConfig {
|
||||
architecture: Some("llama".to_string()),
|
||||
context_length: Some(4096),
|
||||
embedding_length: Some(4096),
|
||||
head_count: Some(32),
|
||||
head_count_kv: Some(32),
|
||||
layer_count: Some(32),
|
||||
vocab_size: Some(32000),
|
||||
rope_freq_base: Some(10000.0),
|
||||
feed_forward_length: Some(11008),
|
||||
};
|
||||
|
||||
assert_eq!(config.architecture, Some("llama".to_string()));
|
||||
assert_eq!(config.layer_count, Some(32));
|
||||
assert_eq!(config.head_count, Some(32));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_config_mistral_7b() {
|
||||
let config = TestModelConfig {
|
||||
architecture: Some("mistral".to_string()),
|
||||
context_length: Some(32768),
|
||||
embedding_length: Some(4096),
|
||||
head_count: Some(32),
|
||||
head_count_kv: Some(8), // GQA with 8 KV heads
|
||||
layer_count: Some(32),
|
||||
vocab_size: Some(32000),
|
||||
rope_freq_base: Some(10000.0),
|
||||
feed_forward_length: Some(14336),
|
||||
};
|
||||
|
||||
assert_eq!(config.head_count_kv, Some(8)); // GQA
|
||||
assert_eq!(config.context_length, Some(32768)); // Larger context
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_config_phi2() {
|
||||
let config = TestModelConfig {
|
||||
architecture: Some("phi".to_string()),
|
||||
context_length: Some(2048),
|
||||
embedding_length: Some(2560),
|
||||
head_count: Some(32),
|
||||
head_count_kv: Some(32),
|
||||
layer_count: Some(32),
|
||||
vocab_size: Some(51200),
|
||||
rope_freq_base: Some(10000.0),
|
||||
feed_forward_length: Some(10240),
|
||||
};
|
||||
|
||||
assert_eq!(config.embedding_length, Some(2560));
|
||||
assert_eq!(config.vocab_size, Some(51200));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Estimation Tests
|
||||
// ============================================================================
|
||||
|
||||
fn estimate_model_memory(config: &TestModelConfig, quant_type: &str) -> usize {
|
||||
let vocab = config.vocab_size.unwrap_or(32000);
|
||||
let hidden = config.embedding_length.unwrap_or(4096);
|
||||
let layers = config.layer_count.unwrap_or(32);
|
||||
let ff_hidden = config.feed_forward_length.unwrap_or(hidden * 4);
|
||||
|
||||
// Bytes per parameter based on quantization
|
||||
let bytes_per_param: f32 = match quant_type {
|
||||
"F32" => 4.0,
|
||||
"F16" => 2.0,
|
||||
"Q8_0" => 1.0625, // ~8.5 bits per weight
|
||||
"Q4_K" => 0.5625, // ~4.5 bits per weight
|
||||
"Q4_0" => 0.5625,
|
||||
"Q2_K" => 0.325, // ~2.6 bits per weight
|
||||
_ => 4.0,
|
||||
};
|
||||
|
||||
// Embedding: vocab_size * hidden_size
|
||||
let embed_params = vocab * hidden;
|
||||
|
||||
// Per layer:
|
||||
// - Attention: 4 * hidden^2 (Q, K, V, O projections)
|
||||
// - MLP: 3 * hidden * ff_hidden (gate, up, down)
|
||||
let attn_params_per_layer = 4 * hidden * hidden;
|
||||
let mlp_params_per_layer = 3 * hidden * ff_hidden;
|
||||
let layer_params = attn_params_per_layer + mlp_params_per_layer;
|
||||
|
||||
// Total
|
||||
let total_params = embed_params + (layers * layer_params) + (vocab * hidden); // + LM head
|
||||
|
||||
(total_params as f32 * bytes_per_param) as usize
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_estimation_llama_7b() {
|
||||
let config = TestModelConfig {
|
||||
architecture: Some("llama".to_string()),
|
||||
embedding_length: Some(4096),
|
||||
layer_count: Some(32),
|
||||
vocab_size: Some(32000),
|
||||
feed_forward_length: Some(11008),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let f32_size = estimate_model_memory(&config, "F32");
|
||||
let q4_size = estimate_model_memory(&config, "Q4_K");
|
||||
|
||||
// F32 ~7B params * 4 bytes = ~28GB
|
||||
// Q4_K ~7B params * 0.5625 bytes = ~4GB
|
||||
assert!(f32_size > 20_000_000_000); // > 20GB
|
||||
assert!(q4_size < 6_000_000_000); // < 6GB
|
||||
assert!(f32_size > q4_size * 5); // F32 should be ~7x larger
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_estimation_small_model() {
|
||||
let config = TestModelConfig {
|
||||
architecture: Some("phi".to_string()),
|
||||
embedding_length: Some(2560),
|
||||
layer_count: Some(24),
|
||||
vocab_size: Some(51200),
|
||||
feed_forward_length: Some(10240),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let q4_size = estimate_model_memory(&config, "Q4_K");
|
||||
|
||||
// Phi-2 is smaller, Q4_K should be < 2GB
|
||||
assert!(q4_size < 3_000_000_000);
|
||||
}
|
||||
1149
crates/ruvllm/tests/kernel_integration.rs
Normal file
1149
crates/ruvllm/tests/kernel_integration.rs
Normal file
File diff suppressed because it is too large
Load Diff
528
crates/ruvllm/tests/lora_integration.rs
Normal file
528
crates/ruvllm/tests/lora_integration.rs
Normal file
@@ -0,0 +1,528 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Integration tests for LoRA (Low-Rank Adaptation)
|
||||
//!
|
||||
//! Tests MicroLoRA adaptation, forward pass, gradient accumulation,
|
||||
//! EWC state management, and serialization.
|
||||
|
||||
use ruvllm::{
|
||||
error::Result,
|
||||
lora::{AdaptFeedback, LoraAdapter, MicroLoRA, MicroLoraConfig, TargetModule},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Create a test MicroLoRA configuration
|
||||
fn create_test_config(dim: usize) -> MicroLoraConfig {
|
||||
MicroLoraConfig {
|
||||
rank: 2,
|
||||
alpha: 4.0,
|
||||
dropout: 0.0,
|
||||
target_modules: vec![TargetModule::QProj, TargetModule::VProj],
|
||||
in_features: dim,
|
||||
out_features: dim,
|
||||
use_bias: false,
|
||||
standard_init: true,
|
||||
gradient_checkpointing: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create test input data
|
||||
fn create_test_input(dim: usize) -> Vec<f32> {
|
||||
(0..dim).map(|i| (i as f32) * 0.01).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_creation() {
|
||||
let config = create_test_config(256);
|
||||
let lora = MicroLoRA::new(config);
|
||||
|
||||
assert_eq!(lora.config().rank, 2);
|
||||
assert_eq!(lora.config().alpha, 4.0);
|
||||
assert!(lora.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_forward() {
|
||||
let config = create_test_config(64);
|
||||
let lora = MicroLoRA::new(config);
|
||||
|
||||
let input = create_test_input(64);
|
||||
|
||||
// Forward pass for Q projection
|
||||
let output = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
assert_eq!(output.len(), 64);
|
||||
assert!(output.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_adapt_changes_output() {
|
||||
let config = MicroLoraConfig {
|
||||
rank: 2,
|
||||
alpha: 4.0,
|
||||
dropout: 0.0,
|
||||
target_modules: vec![TargetModule::QProj],
|
||||
in_features: 64,
|
||||
out_features: 64,
|
||||
use_bias: false,
|
||||
standard_init: true,
|
||||
gradient_checkpointing: false,
|
||||
};
|
||||
|
||||
let lora = MicroLoRA::new(config);
|
||||
let input = create_test_input(64);
|
||||
|
||||
// Forward pass before adaptation
|
||||
let output_before = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
// Apply adaptation with feedback
|
||||
let feedback = AdaptFeedback::from_quality(0.8);
|
||||
lora.adapt(&input, feedback).unwrap();
|
||||
|
||||
// Apply accumulated updates
|
||||
lora.apply_updates(0.01);
|
||||
|
||||
// Forward pass after adaptation
|
||||
let output_after = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
assert_eq!(output_before.len(), output_after.len());
|
||||
|
||||
// Output should change after adaptation
|
||||
let changed = output_before
|
||||
.iter()
|
||||
.zip(output_after.iter())
|
||||
.any(|(a, b)| (a - b).abs() > 1e-10);
|
||||
let all_near_zero = output_before.iter().all(|&v| v.abs() < 1e-6);
|
||||
|
||||
assert!(
|
||||
changed || all_near_zero,
|
||||
"Adaptation should change output or both should be zero"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_forward_dimensions() {
|
||||
let input_dim = 128;
|
||||
let output_dim = 128;
|
||||
|
||||
let config = MicroLoraConfig {
|
||||
rank: 2,
|
||||
alpha: 4.0,
|
||||
dropout: 0.0,
|
||||
target_modules: vec![TargetModule::QProj],
|
||||
in_features: input_dim,
|
||||
out_features: output_dim,
|
||||
use_bias: false,
|
||||
standard_init: true,
|
||||
gradient_checkpointing: false,
|
||||
};
|
||||
|
||||
let lora = MicroLoRA::new(config);
|
||||
let input = create_test_input(input_dim);
|
||||
let output = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
assert_eq!(output.len(), output_dim);
|
||||
assert!(output.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_adapter_creation() {
|
||||
let adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
||||
|
||||
assert_eq!(adapter.rank(), 2);
|
||||
assert_eq!(adapter.param_count(), 64 * 2 + 2 * 64); // A matrix + B matrix
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_adapter_forward() {
|
||||
let adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
||||
let input = ndarray::Array1::from_vec(create_test_input(64));
|
||||
|
||||
let output = adapter.forward(&input);
|
||||
|
||||
assert_eq!(output.len(), 64);
|
||||
assert!(output.iter().all(|&v| v.is_finite()));
|
||||
|
||||
// With zero-initialized B, output should be zero
|
||||
let sum: f32 = output.iter().sum();
|
||||
assert!(
|
||||
sum.abs() < 1e-6,
|
||||
"Initial forward should be ~0, got {}",
|
||||
sum
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_adapter_gradient_accumulation() {
|
||||
let mut adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
||||
let input = ndarray::Array1::from_elem(64, 0.1);
|
||||
let grad_output = ndarray::Array1::from_elem(64, 0.1);
|
||||
|
||||
// Accumulate gradient
|
||||
adapter.accumulate_gradient(&input, &grad_output, 0.8);
|
||||
assert_eq!(adapter.pending_updates(), 1);
|
||||
|
||||
// Apply gradients
|
||||
adapter.apply_gradients(0.01);
|
||||
assert_eq!(adapter.pending_updates(), 0);
|
||||
|
||||
// After update, forward should produce non-zero output
|
||||
let output = adapter.forward(&input);
|
||||
let sum: f32 = output.iter().map(|x| x.abs()).sum();
|
||||
assert!(sum > 0.0, "After update, output should be non-zero");
|
||||
}
|
||||
|
||||
// Note: EwcState is not exported from the lora module, so EWC-specific
|
||||
// tests are implemented in the unit tests within micro_lora.rs
|
||||
|
||||
#[test]
|
||||
fn test_adapt_feedback_creation() {
|
||||
let feedback = AdaptFeedback::from_quality(0.85);
|
||||
|
||||
assert_eq!(feedback.quality, 0.85);
|
||||
assert_eq!(feedback.reward, Some(0.85));
|
||||
assert!(feedback.gradient_estimate.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adapt_feedback_with_gradient() {
|
||||
let gradient = vec![0.1; 64];
|
||||
let feedback = AdaptFeedback::with_gradient(0.9, gradient.clone());
|
||||
|
||||
assert_eq!(feedback.quality, 0.9);
|
||||
assert_eq!(feedback.gradient_estimate.len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adapt_feedback_for_module() {
|
||||
let feedback = AdaptFeedback::from_quality(0.8).for_module(TargetModule::QProj);
|
||||
|
||||
assert_eq!(feedback.source_module, Some(TargetModule::QProj));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adapt_feedback_with_session() {
|
||||
let feedback = AdaptFeedback::from_quality(0.8).with_session("session-123".to_string());
|
||||
|
||||
assert_eq!(feedback.session_id, Some("session-123".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_adaptations() {
|
||||
let config = create_test_config(64);
|
||||
let lora = MicroLoRA::new(config);
|
||||
let input = create_test_input(64);
|
||||
|
||||
// Multiple adaptation cycles
|
||||
for i in 0..5 {
|
||||
let quality = 0.5 + (i as f32 * 0.1);
|
||||
let feedback = AdaptFeedback::from_quality(quality);
|
||||
lora.adapt(&input, feedback).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(lora.adaptation_count(), 5);
|
||||
|
||||
// Apply updates
|
||||
lora.apply_updates(0.01);
|
||||
|
||||
// Verify output is valid
|
||||
let output = lora.forward(&input, &TargetModule::QProj);
|
||||
assert_eq!(output.len(), 64);
|
||||
assert!(output.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_with_different_ranks() {
|
||||
let ranks = [1, 2];
|
||||
let input = create_test_input(64);
|
||||
|
||||
for rank in ranks {
|
||||
let config = MicroLoraConfig {
|
||||
rank,
|
||||
alpha: rank as f32 * 2.0,
|
||||
dropout: 0.0,
|
||||
target_modules: vec![TargetModule::QProj],
|
||||
in_features: 64,
|
||||
out_features: 64,
|
||||
use_bias: false,
|
||||
standard_init: true,
|
||||
gradient_checkpointing: false,
|
||||
};
|
||||
|
||||
let lora = MicroLoRA::new(config);
|
||||
let output = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
assert_eq!(
|
||||
output.len(),
|
||||
64,
|
||||
"Rank {} should produce correct output size",
|
||||
rank
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_target_module_variants() {
|
||||
let modules = vec![
|
||||
TargetModule::QProj,
|
||||
TargetModule::KProj,
|
||||
TargetModule::VProj,
|
||||
TargetModule::OProj,
|
||||
TargetModule::GateProj,
|
||||
TargetModule::UpProj,
|
||||
TargetModule::DownProj,
|
||||
TargetModule::Embed,
|
||||
TargetModule::LmHead,
|
||||
];
|
||||
|
||||
for module in &modules {
|
||||
let name = module.as_str();
|
||||
assert!(!name.is_empty());
|
||||
}
|
||||
|
||||
assert_eq!(TargetModule::QProj.as_str(), "q_proj");
|
||||
assert_eq!(TargetModule::VProj.as_str(), "v_proj");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_target_module_defaults() {
|
||||
let defaults = TargetModule::defaults();
|
||||
assert_eq!(defaults.len(), 2);
|
||||
assert!(defaults.contains(&TargetModule::QProj));
|
||||
assert!(defaults.contains(&TargetModule::VProj));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_target_module_attention() {
|
||||
let attention = TargetModule::attention();
|
||||
assert_eq!(attention.len(), 4);
|
||||
assert!(attention.contains(&TargetModule::QProj));
|
||||
assert!(attention.contains(&TargetModule::KProj));
|
||||
assert!(attention.contains(&TargetModule::VProj));
|
||||
assert!(attention.contains(&TargetModule::OProj));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_target_module_mlp() {
|
||||
let mlp = TargetModule::mlp();
|
||||
assert_eq!(mlp.len(), 3);
|
||||
assert!(mlp.contains(&TargetModule::GateProj));
|
||||
assert!(mlp.contains(&TargetModule::UpProj));
|
||||
assert!(mlp.contains(&TargetModule::DownProj));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_config_memory() {
|
||||
let config = MicroLoraConfig {
|
||||
rank: 2,
|
||||
alpha: 4.0,
|
||||
dropout: 0.0,
|
||||
target_modules: vec![TargetModule::QProj, TargetModule::VProj],
|
||||
in_features: 768,
|
||||
out_features: 768,
|
||||
use_bias: false,
|
||||
standard_init: true,
|
||||
gradient_checkpointing: false,
|
||||
};
|
||||
|
||||
let memory = config.memory_bytes();
|
||||
// 2 modules * (768 * 2 + 2 * 768) * 4 bytes
|
||||
assert!(memory < 1024 * 1024, "Memory should be < 1MB for MicroLoRA");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_enable_disable() {
|
||||
let config = create_test_config(64);
|
||||
let mut lora = MicroLoRA::new(config);
|
||||
let input = create_test_input(64);
|
||||
|
||||
assert!(lora.is_enabled());
|
||||
|
||||
// Disable
|
||||
lora.set_enabled(false);
|
||||
assert!(!lora.is_enabled());
|
||||
|
||||
// Forward when disabled should return zeros
|
||||
let output = lora.forward(&input, &TargetModule::QProj);
|
||||
assert!(output.iter().all(|&v| v == 0.0));
|
||||
|
||||
// Re-enable
|
||||
lora.set_enabled(true);
|
||||
assert!(lora.is_enabled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_reset() {
|
||||
let config = create_test_config(64);
|
||||
let lora = MicroLoRA::new(config);
|
||||
let input = create_test_input(64);
|
||||
|
||||
// Perform some adaptations
|
||||
for _ in 0..5 {
|
||||
let feedback = AdaptFeedback::from_quality(0.8);
|
||||
lora.adapt(&input, feedback).unwrap();
|
||||
}
|
||||
|
||||
assert!(lora.adaptation_count() > 0);
|
||||
|
||||
// Reset
|
||||
lora.reset();
|
||||
|
||||
assert_eq!(lora.adaptation_count(), 0);
|
||||
assert_eq!(lora.forward_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_memory_usage() {
|
||||
let config = create_test_config(64);
|
||||
let lora = MicroLoRA::new(config);
|
||||
|
||||
let memory = lora.memory_bytes();
|
||||
let params = lora.param_count();
|
||||
|
||||
assert!(memory > 0);
|
||||
assert!(params > 0);
|
||||
assert_eq!(memory, params * std::mem::size_of::<f32>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lora_adapter_simd_forward() {
|
||||
let adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
||||
let input = create_test_input(64);
|
||||
let mut output = vec![0.0f32; 64];
|
||||
|
||||
adapter.forward_simd(&input, &mut output);
|
||||
|
||||
// Compare with regular forward
|
||||
let input_array = ndarray::Array1::from_vec(input.clone());
|
||||
let expected = adapter.forward(&input_array);
|
||||
|
||||
for (o, e) in output.iter().zip(expected.iter()) {
|
||||
assert!(
|
||||
(o - e).abs() < 1e-5,
|
||||
"SIMD forward mismatch: {} vs {}",
|
||||
o,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_with_custom_dimensions() {
|
||||
let config = MicroLoraConfig {
|
||||
rank: 2,
|
||||
alpha: 4.0,
|
||||
dropout: 0.0,
|
||||
target_modules: vec![TargetModule::QProj, TargetModule::VProj],
|
||||
in_features: 256, // Default dimensions
|
||||
out_features: 256,
|
||||
use_bias: false,
|
||||
standard_init: true,
|
||||
gradient_checkpointing: false,
|
||||
};
|
||||
|
||||
// Create with custom dimensions per module
|
||||
let mut dimensions = HashMap::new();
|
||||
dimensions.insert(TargetModule::QProj, (128, 128));
|
||||
dimensions.insert(TargetModule::VProj, (128, 128));
|
||||
|
||||
let lora = MicroLoRA::with_dimensions(config, dimensions);
|
||||
|
||||
let input = create_test_input(128);
|
||||
let output = lora.forward(&input, &TargetModule::QProj);
|
||||
|
||||
assert_eq!(output.len(), 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_micro_lora_save_load() {
|
||||
let config = create_test_config(64);
|
||||
let lora = MicroLoRA::new(config);
|
||||
let input = create_test_input(64);
|
||||
|
||||
// Apply some adaptation
|
||||
let feedback = AdaptFeedback::from_quality(0.85);
|
||||
lora.adapt(&input, feedback).unwrap();
|
||||
lora.apply_updates(0.01);
|
||||
|
||||
// Export state
|
||||
let state = lora.export_state();
|
||||
|
||||
assert_eq!(state.config.rank, 2);
|
||||
assert!(!state.adapters.is_empty());
|
||||
|
||||
// Restore from state
|
||||
let lora_restored = MicroLoRA::from_state(state).unwrap();
|
||||
|
||||
// Both should produce same output
|
||||
let output_original = lora.forward(&input, &TargetModule::QProj);
|
||||
let output_restored = lora_restored.forward(&input, &TargetModule::QProj);
|
||||
|
||||
for (a, b) in output_original.iter().zip(output_restored.iter()) {
|
||||
assert!(
|
||||
(a - b).abs() < 1e-5,
|
||||
"Restored model should match: {} vs {}",
|
||||
a,
|
||||
b
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: test_lora_apply_updates_with_ewc removed as EwcState is not exported
|
||||
|
||||
#[test]
|
||||
fn test_lora_adapter_reset() {
|
||||
let mut adapter = LoraAdapter::new(64, 64, 2, 4.0);
|
||||
let input = ndarray::Array1::from_elem(64, 0.1);
|
||||
let grad_output = ndarray::Array1::from_elem(64, 0.1);
|
||||
|
||||
// Accumulate some gradients and apply
|
||||
adapter.accumulate_gradient(&input, &grad_output, 0.8);
|
||||
adapter.apply_gradients(0.01);
|
||||
|
||||
// Reset
|
||||
adapter.reset();
|
||||
|
||||
assert_eq!(adapter.pending_updates(), 0);
|
||||
|
||||
// B matrix should be reset to zero
|
||||
let output = adapter.forward(&input);
|
||||
let sum: f32 = output.iter().sum();
|
||||
assert!(sum.abs() < 1e-6, "After reset, output should be ~0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_for_hidden_dim() {
|
||||
let config = MicroLoraConfig::for_hidden_dim(512);
|
||||
|
||||
assert_eq!(config.in_features, 512);
|
||||
assert_eq!(config.out_features, 512);
|
||||
assert_eq!(config.rank, 2); // Default rank
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_builder_methods() {
|
||||
let config = MicroLoraConfig::for_hidden_dim(256)
|
||||
.with_rank(1)
|
||||
.with_alpha(8.0)
|
||||
.with_targets(vec![
|
||||
TargetModule::QProj,
|
||||
TargetModule::KProj,
|
||||
TargetModule::VProj,
|
||||
]);
|
||||
|
||||
assert_eq!(config.rank, 1);
|
||||
assert_eq!(config.alpha, 8.0);
|
||||
assert_eq!(config.target_modules.len(), 3);
|
||||
}
|
||||
698
crates/ruvllm/tests/mistral_backend_test.rs
Normal file
698
crates/ruvllm/tests/mistral_backend_test.rs
Normal file
@@ -0,0 +1,698 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Integration tests for mistral-rs backend
|
||||
//!
|
||||
//! Tests the mistral-rs backend integration including:
|
||||
//! - Backend creation and configuration
|
||||
//! - PagedAttention integration
|
||||
//! - X-LoRA adapter management
|
||||
//! - ISQ (In-Situ Quantization) configuration
|
||||
//! - Model loading and generation (requires model files)
|
||||
//!
|
||||
//! ## Running Tests
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Run basic tests (no model required)
|
||||
//! cargo test --features mistral-rs mistral_backend
|
||||
//!
|
||||
//! # Run all tests including model-dependent ones
|
||||
//! cargo test --features mistral-rs mistral_backend -- --include-ignored
|
||||
//!
|
||||
//! # Run with Metal acceleration
|
||||
//! cargo test --features mistral-rs-metal mistral_backend
|
||||
//! ```
|
||||
|
||||
#![cfg(feature = "mistral-rs")]
|
||||
|
||||
use ruvllm::backends::mistral_backend::{
|
||||
IsqConfig, IsqMethod, MistralBackend, MistralBackendConfig, PagedAttentionConfigExt,
|
||||
XLoraConfig, XLoraManager, XLoraMixingMode,
|
||||
};
|
||||
use ruvllm::backends::{
|
||||
DType, DeviceType, GenerateParams, LlmBackend, ModelArchitecture, ModelConfig, Quantization,
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
// ============================================================================
|
||||
// Backend Creation Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_mistral_backend_creation() {
|
||||
let backend = MistralBackend::new().unwrap();
|
||||
assert!(!backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_backend_default() {
|
||||
let backend = MistralBackend::default();
|
||||
assert!(!backend.is_model_loaded());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_backend_for_metal() {
|
||||
let result = MistralBackend::for_metal();
|
||||
assert!(result.is_ok());
|
||||
let backend = result.unwrap();
|
||||
assert!(!backend.is_model_loaded());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_backend_for_cuda() {
|
||||
let result = MistralBackend::for_cuda(0);
|
||||
assert!(result.is_ok());
|
||||
let backend = result.unwrap();
|
||||
assert!(!backend.is_model_loaded());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_backend_with_custom_config() {
|
||||
let config = MistralBackendConfig::default()
|
||||
.with_max_seq_len(16384)
|
||||
.with_max_batch_size(64);
|
||||
|
||||
let backend = MistralBackend::with_config(config).unwrap();
|
||||
assert!(!backend.is_model_loaded());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Configuration Builder Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_mistral_config_builder() {
|
||||
let config = MistralBackendConfig::default()
|
||||
.with_paged_attention(16, 4096)
|
||||
.with_xlora_adapters(vec!["code", "chat"])
|
||||
.with_isq(4);
|
||||
|
||||
assert!(config.paged_attention.is_some());
|
||||
assert!(config.xlora.is_some());
|
||||
assert!(config.isq.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_config_paged_attention() {
|
||||
let config = MistralBackendConfig::default().with_paged_attention(32, 8192);
|
||||
|
||||
let pa = config.paged_attention.unwrap();
|
||||
assert_eq!(pa.block_size, 32);
|
||||
assert_eq!(pa.max_pages, 8192);
|
||||
assert!(pa.enable_prefix_caching);
|
||||
assert!((pa.gpu_memory_fraction - 0.9).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_config_xlora() {
|
||||
let config = MistralBackendConfig::default().with_xlora_adapters(vec!["code", "chat", "math"]);
|
||||
|
||||
let xlora = config.xlora.unwrap();
|
||||
assert_eq!(xlora.adapter_names.len(), 3);
|
||||
assert!(xlora.adapter_names.contains(&"code".to_string()));
|
||||
assert!(xlora.adapter_names.contains(&"chat".to_string()));
|
||||
assert!(xlora.adapter_names.contains(&"math".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_config_isq() {
|
||||
let config = MistralBackendConfig::default().with_isq(4);
|
||||
|
||||
let isq = config.isq.unwrap();
|
||||
assert_eq!(isq.bits, 4);
|
||||
assert!(matches!(isq.method, IsqMethod::AWQ));
|
||||
assert!(!isq.symmetric);
|
||||
assert!(isq.per_channel);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_config_chained() {
|
||||
let config = MistralBackendConfig::default()
|
||||
.with_paged_attention(16, 4096)
|
||||
.with_xlora_adapters(vec!["adapter1", "adapter2"])
|
||||
.with_isq(8)
|
||||
.with_max_seq_len(32768)
|
||||
.with_max_batch_size(128);
|
||||
|
||||
assert!(config.paged_attention.is_some());
|
||||
assert!(config.xlora.is_some());
|
||||
assert!(config.isq.is_some());
|
||||
assert_eq!(config.max_seq_len, 32768);
|
||||
assert_eq!(config.max_batch_size, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_config_for_metal() {
|
||||
let config = MistralBackendConfig::for_metal();
|
||||
|
||||
assert!(matches!(config.device, DeviceType::Metal));
|
||||
assert!(matches!(config.dtype, DType::F16));
|
||||
assert!(config.use_flash_attn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mistral_config_for_cuda() {
|
||||
let config = MistralBackendConfig::for_cuda(1);
|
||||
|
||||
if let DeviceType::Cuda(id) = config.device {
|
||||
assert_eq!(id, 1);
|
||||
} else {
|
||||
panic!("Expected CUDA device type");
|
||||
}
|
||||
assert!(matches!(config.dtype, DType::F16));
|
||||
assert!(config.use_flash_attn);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// PagedAttention Configuration Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_config_default() {
|
||||
let config = PagedAttentionConfigExt::default();
|
||||
|
||||
assert_eq!(config.block_size, 16);
|
||||
assert_eq!(config.max_pages, 4096);
|
||||
assert!((config.gpu_memory_fraction - 0.9).abs() < f32::EPSILON);
|
||||
assert!(config.enable_prefix_caching);
|
||||
assert!((config.recomputation_threshold - 0.1).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_stats() {
|
||||
let backend = MistralBackend::new().unwrap();
|
||||
let stats = backend.paged_attention_stats();
|
||||
|
||||
// Default config enables PagedAttention
|
||||
assert!(stats.is_some());
|
||||
let stats = stats.unwrap();
|
||||
assert!(stats.total_blocks > 0);
|
||||
assert_eq!(stats.active_sequences, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_disabled() {
|
||||
let mut config = MistralBackendConfig::default();
|
||||
config.paged_attention = None;
|
||||
|
||||
let backend = MistralBackend::with_config(config).unwrap();
|
||||
let stats = backend.paged_attention_stats();
|
||||
|
||||
assert!(stats.is_none());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// X-LoRA Manager Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_xlora_manager_creation() {
|
||||
let xlora_config = XLoraConfig {
|
||||
adapter_names: vec!["test".to_string()],
|
||||
top_k: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let manager = XLoraManager::new(xlora_config);
|
||||
let stats = manager.stats();
|
||||
|
||||
assert_eq!(stats.loaded_adapters, 0);
|
||||
assert_eq!(stats.forward_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xlora_manager_routing() {
|
||||
let xlora_config = XLoraConfig {
|
||||
adapter_names: vec!["code".to_string(), "chat".to_string()],
|
||||
top_k: 2,
|
||||
use_learned_routing: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let manager = XLoraManager::new(xlora_config);
|
||||
|
||||
// Route without adapters - returns empty
|
||||
let routing = manager.route(&[0.1, 0.2, 0.3]);
|
||||
assert!(routing.is_empty()); // No adapters loaded
|
||||
|
||||
let stats = manager.stats();
|
||||
assert_eq!(stats.forward_count, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xlora_config_defaults() {
|
||||
let config = XLoraConfig::default();
|
||||
|
||||
assert!(config.adapter_names.is_empty());
|
||||
assert!(config.base_adapter.is_none());
|
||||
assert!(config.adapter_scales.is_none());
|
||||
assert_eq!(config.router_hidden_dim, 64);
|
||||
assert_eq!(config.router_layers, 2);
|
||||
assert_eq!(config.top_k, 2);
|
||||
assert!((config.temperature - 1.0).abs() < f32::EPSILON);
|
||||
assert!(config.use_learned_routing);
|
||||
assert!(matches!(config.mixing_mode, XLoraMixingMode::Additive));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xlora_mixing_modes() {
|
||||
let additive = XLoraMixingMode::Additive;
|
||||
let concat = XLoraMixingMode::Concatenate;
|
||||
let gated = XLoraMixingMode::Gated;
|
||||
let attention = XLoraMixingMode::Attention;
|
||||
|
||||
assert!(matches!(additive, XLoraMixingMode::Additive));
|
||||
assert!(matches!(concat, XLoraMixingMode::Concatenate));
|
||||
assert!(matches!(gated, XLoraMixingMode::Gated));
|
||||
assert!(matches!(attention, XLoraMixingMode::Attention));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xlora_stats_from_backend() {
|
||||
let config = MistralBackendConfig::default().with_xlora_adapters(vec!["code", "chat"]);
|
||||
let backend = MistralBackend::with_config(config).unwrap();
|
||||
|
||||
let stats = backend.xlora_stats();
|
||||
assert!(stats.is_some());
|
||||
|
||||
let stats = stats.unwrap();
|
||||
assert_eq!(stats.loaded_adapters, 0); // No adapters actually loaded from disk
|
||||
assert_eq!(stats.forward_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xlora_stats_none_when_not_configured() {
|
||||
let mut config = MistralBackendConfig::default();
|
||||
config.xlora = None;
|
||||
|
||||
let backend = MistralBackend::with_config(config).unwrap();
|
||||
let stats = backend.xlora_stats();
|
||||
|
||||
assert!(stats.is_none());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ISQ Configuration Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_isq_config_defaults() {
|
||||
let config = IsqConfig::default();
|
||||
|
||||
assert_eq!(config.bits, 4);
|
||||
assert!(matches!(config.method, IsqMethod::AWQ));
|
||||
assert!(!config.symmetric);
|
||||
assert!(config.per_channel);
|
||||
assert_eq!(config.calibration_samples, 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_isq_methods() {
|
||||
let awq = IsqMethod::AWQ;
|
||||
let gptq = IsqMethod::GPTQ;
|
||||
let rtn = IsqMethod::RTN;
|
||||
let smooth = IsqMethod::SmoothQuant;
|
||||
|
||||
assert!(matches!(awq, IsqMethod::AWQ));
|
||||
assert!(matches!(gptq, IsqMethod::GPTQ));
|
||||
assert!(matches!(rtn, IsqMethod::RTN));
|
||||
assert!(matches!(smooth, IsqMethod::SmoothQuant));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_isq_with_different_bits() {
|
||||
for bits in [2, 4, 8] {
|
||||
let config = MistralBackendConfig::default().with_isq(bits);
|
||||
let isq = config.isq.unwrap();
|
||||
assert_eq!(isq.bits, bits);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Backend Operation Tests (Without Model)
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_generate_requires_loaded_model() {
|
||||
let backend = MistralBackend::new().unwrap();
|
||||
|
||||
let result = backend.generate("Hello", GenerateParams::default());
|
||||
assert!(result.is_err());
|
||||
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("No model loaded"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_stream_requires_loaded_model() {
|
||||
let backend = MistralBackend::new().unwrap();
|
||||
|
||||
let result = backend.generate_stream("Hello", GenerateParams::default());
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embeddings_require_loaded_model() {
|
||||
let backend = MistralBackend::new().unwrap();
|
||||
|
||||
let result = backend.get_embeddings("Test text");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenizer_none_before_load() {
|
||||
let backend = MistralBackend::new().unwrap();
|
||||
assert!(backend.tokenizer().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_info_none_before_load() {
|
||||
let backend = MistralBackend::new().unwrap();
|
||||
assert!(backend.model_info().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unload_model_when_not_loaded() {
|
||||
let mut backend = MistralBackend::new().unwrap();
|
||||
|
||||
// Should not panic when called on unloaded backend
|
||||
backend.unload_model();
|
||||
assert!(!backend.is_model_loaded());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xlora_adapter_operations_require_config() {
|
||||
let mut config = MistralBackendConfig::default();
|
||||
config.xlora = None;
|
||||
|
||||
let backend = MistralBackend::with_config(config).unwrap();
|
||||
|
||||
// Loading adapter should fail without X-LoRA configured
|
||||
let result = backend.load_xlora_adapter("test", Path::new("/nonexistent"));
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("X-LoRA not configured"));
|
||||
|
||||
// Setting adapters should fail
|
||||
let result = backend.set_xlora_adapters(vec![("test", 1.0)]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_isq_requires_loaded_model() {
|
||||
let config = MistralBackendConfig::default().with_isq(4);
|
||||
let mut backend = MistralBackend::with_config(config).unwrap();
|
||||
|
||||
let result = backend.apply_isq();
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("No model loaded"));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Model Loading Tests (Requires Model Files - Ignored by Default)
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
#[ignore = "Requires model file - run with --include-ignored"]
|
||||
fn test_model_loading() {
|
||||
let mut backend = MistralBackend::for_metal().unwrap();
|
||||
|
||||
// Note: Replace with actual model path for testing
|
||||
let result = backend.load_model(
|
||||
"models/test-model.gguf",
|
||||
ModelConfig {
|
||||
architecture: ModelArchitecture::Mistral,
|
||||
device: DeviceType::Metal,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
if result.is_ok() {
|
||||
assert!(backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_some());
|
||||
assert!(backend.tokenizer().is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "Requires model file - run with --include-ignored"]
|
||||
fn test_generation() {
|
||||
let mut backend = MistralBackend::new().unwrap();
|
||||
|
||||
let load_result = backend.load_model("models/test-model.gguf", ModelConfig::default());
|
||||
if load_result.is_err() {
|
||||
return; // Skip if model not available
|
||||
}
|
||||
|
||||
let output = backend.generate(
|
||||
"Hello",
|
||||
GenerateParams {
|
||||
max_tokens: 10,
|
||||
temperature: 0.7,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
match output {
|
||||
Ok(text) => {
|
||||
assert!(!text.is_empty());
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Generation failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "Requires model file - run with --include-ignored"]
|
||||
fn test_streaming_generation() {
|
||||
let mut backend = MistralBackend::new().unwrap();
|
||||
|
||||
let load_result = backend.load_model("models/test-model.gguf", ModelConfig::default());
|
||||
if load_result.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let stream = backend.generate_stream("Hello", GenerateParams::default());
|
||||
match stream {
|
||||
Ok(stream) => {
|
||||
let tokens: Vec<_> = stream.collect();
|
||||
assert!(!tokens.is_empty());
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Streaming generation failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "Requires model file - run with --include-ignored"]
|
||||
fn test_embeddings_extraction() {
|
||||
let mut backend = MistralBackend::new().unwrap();
|
||||
|
||||
let load_result = backend.load_model("models/test-model.gguf", ModelConfig::default());
|
||||
if load_result.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let embeddings = backend.get_embeddings("Test text for embedding");
|
||||
match embeddings {
|
||||
Ok(emb) => {
|
||||
assert!(!emb.is_empty());
|
||||
assert!(emb.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Embedding extraction failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "Requires model file - run with --include-ignored"]
|
||||
fn test_model_unload_and_reload() {
|
||||
let mut backend = MistralBackend::new().unwrap();
|
||||
|
||||
// Load model
|
||||
let load_result = backend.load_model("models/test-model.gguf", ModelConfig::default());
|
||||
if load_result.is_err() {
|
||||
return;
|
||||
}
|
||||
assert!(backend.is_model_loaded());
|
||||
|
||||
// Unload
|
||||
backend.unload_model();
|
||||
assert!(!backend.is_model_loaded());
|
||||
assert!(backend.model_info().is_none());
|
||||
|
||||
// Reload
|
||||
let reload_result = backend.load_model("models/test-model.gguf", ModelConfig::default());
|
||||
if reload_result.is_ok() {
|
||||
assert!(backend.is_model_loaded());
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests with PagedAttention
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_backend_paged_attention_integration() {
|
||||
let config = MistralBackendConfig::default().with_paged_attention(16, 4096);
|
||||
let backend = MistralBackend::with_config(config).unwrap();
|
||||
|
||||
// Verify PagedAttention is configured
|
||||
let stats = backend.paged_attention_stats().unwrap();
|
||||
assert!(stats.total_blocks > 0);
|
||||
assert!(stats.free_blocks > 0);
|
||||
assert_eq!(stats.active_sequences, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backend_xlora_integration() {
|
||||
let config = MistralBackendConfig::default().with_xlora_adapters(vec!["code", "math", "chat"]);
|
||||
let backend = MistralBackend::with_config(config).unwrap();
|
||||
|
||||
// Verify X-LoRA is configured
|
||||
let stats = backend.xlora_stats().unwrap();
|
||||
assert_eq!(stats.loaded_adapters, 0); // None loaded yet
|
||||
assert!(stats.adapter_usage.is_empty());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Serialization Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let config = MistralBackendConfig::default()
|
||||
.with_paged_attention(32, 8192)
|
||||
.with_xlora_adapters(vec!["test"])
|
||||
.with_isq(4);
|
||||
|
||||
// Test serialization roundtrip
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: MistralBackendConfig = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.max_seq_len, config.max_seq_len);
|
||||
assert_eq!(deserialized.max_batch_size, config.max_batch_size);
|
||||
assert!(deserialized.paged_attention.is_some());
|
||||
assert!(deserialized.xlora.is_some());
|
||||
assert!(deserialized.isq.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_attention_config_serialization() {
|
||||
let config = PagedAttentionConfigExt::default();
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: PagedAttentionConfigExt = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.block_size, config.block_size);
|
||||
assert_eq!(deserialized.max_pages, config.max_pages);
|
||||
assert_eq!(
|
||||
deserialized.enable_prefix_caching,
|
||||
config.enable_prefix_caching
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xlora_config_serialization() {
|
||||
let config = XLoraConfig {
|
||||
adapter_names: vec!["a".to_string(), "b".to_string()],
|
||||
top_k: 3,
|
||||
temperature: 0.5,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: XLoraConfig = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.adapter_names.len(), 2);
|
||||
assert_eq!(deserialized.top_k, 3);
|
||||
assert!((deserialized.temperature - 0.5).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_isq_config_serialization() {
|
||||
let config = IsqConfig {
|
||||
bits: 8,
|
||||
method: IsqMethod::GPTQ,
|
||||
symmetric: true,
|
||||
per_channel: false,
|
||||
calibration_samples: 256,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: IsqConfig = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.bits, 8);
|
||||
assert!(matches!(deserialized.method, IsqMethod::GPTQ));
|
||||
assert!(deserialized.symmetric);
|
||||
assert!(!deserialized.per_channel);
|
||||
assert_eq!(deserialized.calibration_samples, 256);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Case Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_empty_xlora_adapters() {
|
||||
let config = MistralBackendConfig::default().with_xlora_adapters(vec![]);
|
||||
|
||||
let xlora = config.xlora.unwrap();
|
||||
assert!(xlora.adapter_names.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_large_page_config() {
|
||||
let config = MistralBackendConfig::default().with_paged_attention(256, 65536);
|
||||
|
||||
let pa = config.paged_attention.unwrap();
|
||||
assert_eq!(pa.block_size, 256);
|
||||
assert_eq!(pa.max_pages, 65536);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_backend_instances() {
|
||||
let backend1 = MistralBackend::new().unwrap();
|
||||
let backend2 = MistralBackend::for_metal().unwrap();
|
||||
let backend3 = MistralBackend::for_cuda(0).unwrap();
|
||||
|
||||
assert!(!backend1.is_model_loaded());
|
||||
assert!(!backend2.is_model_loaded());
|
||||
assert!(!backend3.is_model_loaded());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_params_integration() {
|
||||
let params = GenerateParams {
|
||||
max_tokens: 256,
|
||||
temperature: 0.8,
|
||||
top_p: 0.95,
|
||||
top_k: 50,
|
||||
repetition_penalty: 1.1,
|
||||
frequency_penalty: 0.1,
|
||||
presence_penalty: 0.1,
|
||||
stop_sequences: vec!["STOP".to_string(), "\n\n".to_string()],
|
||||
seed: Some(12345),
|
||||
};
|
||||
|
||||
assert_eq!(params.max_tokens, 256);
|
||||
assert!((params.temperature - 0.8).abs() < f32::EPSILON);
|
||||
assert_eq!(params.stop_sequences.len(), 2);
|
||||
assert_eq!(params.seed, Some(12345));
|
||||
}
|
||||
1255
crates/ruvllm/tests/model_arch_integration.rs
Normal file
1255
crates/ruvllm/tests/model_arch_integration.rs
Normal file
File diff suppressed because it is too large
Load Diff
752
crates/ruvllm/tests/real_model_test.rs
Normal file
752
crates/ruvllm/tests/real_model_test.rs
Normal file
@@ -0,0 +1,752 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Real model validation tests
|
||||
//!
|
||||
//! These tests require actual GGUF model files to run.
|
||||
//! They are marked with `#[ignore]` by default and can be run with:
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Run with specific model path
|
||||
//! TEST_MODEL_PATH=./test_models/tinyllama.gguf cargo test -p ruvllm --test real_model_test -- --ignored
|
||||
//!
|
||||
//! # Run with default test_models directory
|
||||
//! cargo test -p ruvllm --test real_model_test -- --ignored
|
||||
//! ```
|
||||
//!
|
||||
//! ## Recommended test models (small, fast)
|
||||
//!
|
||||
//! | Model | Size | Use Case |
|
||||
//! |-------|------|----------|
|
||||
//! | TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf | ~700MB | Fast iteration |
|
||||
//! | Qwen2-0.5B-Instruct.Q4_K_M.gguf | ~400MB | Smallest, fastest |
|
||||
//! | Phi-3-mini-4k-instruct.Q4_K_M.gguf | ~2GB | Higher quality |
|
||||
//!
|
||||
//! ## Download test models
|
||||
//!
|
||||
//! ```bash
|
||||
//! cargo run -p ruvllm --example download_test_model -- --model tinyllama
|
||||
//! ```
|
||||
|
||||
use std::env;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
// ============================================================================
|
||||
// Test Utilities
|
||||
// ============================================================================
|
||||
|
||||
/// Common search locations for test models
|
||||
const MODEL_SEARCH_PATHS: &[&str] = &[
|
||||
"./test_models",
|
||||
"../test_models",
|
||||
"../../test_models",
|
||||
"./models",
|
||||
"../models",
|
||||
"~/.cache/ruvllm/models",
|
||||
"~/.cache/huggingface/hub",
|
||||
];
|
||||
|
||||
/// Supported model file patterns for each architecture
|
||||
const TINYLLAMA_PATTERNS: &[&str] = &["tinyllama*.gguf", "TinyLlama*.gguf", "*tinyllama*.gguf"];
|
||||
|
||||
const PHI3_PATTERNS: &[&str] = &["phi-3*.gguf", "Phi-3*.gguf", "*phi3*.gguf", "*phi-3*.gguf"];
|
||||
|
||||
const QWEN_PATTERNS: &[&str] = &["qwen*.gguf", "Qwen*.gguf", "*qwen*.gguf"];
|
||||
|
||||
/// Result type for test helpers (reserved for future use)
|
||||
#[allow(dead_code)]
|
||||
type TestResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
|
||||
|
||||
/// Find a test model in common locations.
|
||||
///
|
||||
/// Search order:
|
||||
/// 1. `TEST_MODEL_PATH` environment variable (exact path)
|
||||
/// 2. `TEST_MODEL_DIR` environment variable (directory to search)
|
||||
/// 3. Common locations in `MODEL_SEARCH_PATHS`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `patterns` - Glob patterns to match model files
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Path to the first matching model file, or None if not found
|
||||
pub fn find_test_model(patterns: &[&str]) -> Option<PathBuf> {
|
||||
// 1. Check TEST_MODEL_PATH for exact path
|
||||
if let Ok(path) = env::var("TEST_MODEL_PATH") {
|
||||
let path = PathBuf::from(path);
|
||||
if path.exists() && path.is_file() {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check TEST_MODEL_DIR for directory
|
||||
if let Ok(dir) = env::var("TEST_MODEL_DIR") {
|
||||
if let Some(found) = search_directory(&PathBuf::from(dir), patterns) {
|
||||
return Some(found);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Search common locations
|
||||
for search_path in MODEL_SEARCH_PATHS {
|
||||
let expanded = expand_path(search_path);
|
||||
if expanded.exists() && expanded.is_dir() {
|
||||
if let Some(found) = search_directory(&expanded, patterns) {
|
||||
return Some(found);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Search a directory for files matching any of the given patterns
|
||||
fn search_directory(dir: &Path, patterns: &[&str]) -> Option<PathBuf> {
|
||||
if !dir.exists() || !dir.is_dir() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entries = match std::fs::read_dir(dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if !path.is_file() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let file_name = match path.file_name().and_then(|n| n.to_str()) {
|
||||
Some(n) => n.to_lowercase(),
|
||||
None => continue,
|
||||
};
|
||||
|
||||
for pattern in patterns {
|
||||
if matches_glob_pattern(&file_name, &pattern.to_lowercase()) {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Simple glob pattern matching (supports * wildcard)
|
||||
fn matches_glob_pattern(name: &str, pattern: &str) -> bool {
|
||||
if !pattern.contains('*') {
|
||||
return name == pattern;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = pattern.split('*').collect();
|
||||
if parts.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut remaining = name;
|
||||
|
||||
// First part must be a prefix (if not empty)
|
||||
if !parts[0].is_empty() {
|
||||
if !remaining.starts_with(parts[0]) {
|
||||
return false;
|
||||
}
|
||||
remaining = &remaining[parts[0].len()..];
|
||||
}
|
||||
|
||||
// Last part must be a suffix (if not empty)
|
||||
if parts.len() > 1 {
|
||||
let last = parts[parts.len() - 1];
|
||||
if !last.is_empty() && !remaining.ends_with(last) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Middle parts must appear in order
|
||||
for part in &parts[1..parts.len().saturating_sub(1)] {
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
match remaining.find(part) {
|
||||
Some(pos) => remaining = &remaining[pos + part.len()..],
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Expand ~ to home directory
|
||||
fn expand_path(path: &str) -> PathBuf {
|
||||
if path.starts_with("~/") {
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
return home.join(&path[2..]);
|
||||
}
|
||||
}
|
||||
PathBuf::from(path)
|
||||
}
|
||||
|
||||
/// Skip test gracefully if no model is available
|
||||
///
|
||||
/// Returns the model path if found, or prints a skip message and returns None
|
||||
pub fn skip_if_no_model(patterns: &[&str], model_name: &str) -> Option<PathBuf> {
|
||||
match find_test_model(patterns) {
|
||||
Some(path) => {
|
||||
println!("Using model: {}", path.display());
|
||||
Some(path)
|
||||
}
|
||||
None => {
|
||||
println!("SKIPPED: No {} model found.", model_name);
|
||||
println!("To run this test:");
|
||||
println!(" 1. Download the model:");
|
||||
println!(
|
||||
" cargo run -p ruvllm --example download_test_model -- --model {}",
|
||||
model_name.to_lowercase().replace(' ', "")
|
||||
);
|
||||
println!(" 2. Or set TEST_MODEL_PATH environment variable");
|
||||
println!(" 3. Or place model in ./test_models/ directory");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Measure tokens per second during generation
|
||||
pub struct GenerationMetrics {
|
||||
pub total_tokens: usize,
|
||||
pub total_duration: Duration,
|
||||
pub first_token_latency: Duration,
|
||||
pub token_latencies: Vec<Duration>,
|
||||
}
|
||||
|
||||
impl GenerationMetrics {
|
||||
pub fn tokens_per_second(&self) -> f64 {
|
||||
if self.total_duration.as_secs_f64() > 0.0 {
|
||||
self.total_tokens as f64 / self.total_duration.as_secs_f64()
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
pub fn latency_p50(&self) -> Duration {
|
||||
self.percentile_latency(50)
|
||||
}
|
||||
|
||||
pub fn latency_p95(&self) -> Duration {
|
||||
self.percentile_latency(95)
|
||||
}
|
||||
|
||||
pub fn latency_p99(&self) -> Duration {
|
||||
self.percentile_latency(99)
|
||||
}
|
||||
|
||||
fn percentile_latency(&self, p: usize) -> Duration {
|
||||
if self.token_latencies.is_empty() {
|
||||
return Duration::ZERO;
|
||||
}
|
||||
|
||||
let mut sorted = self.token_latencies.clone();
|
||||
sorted.sort();
|
||||
|
||||
let idx = (p * sorted.len() / 100).min(sorted.len() - 1);
|
||||
sorted[idx]
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> String {
|
||||
format!(
|
||||
"Tokens: {}, Duration: {:.2}s, Speed: {:.2} tok/s, TTFT: {:.2}ms, P50: {:.2}ms, P95: {:.2}ms, P99: {:.2}ms",
|
||||
self.total_tokens,
|
||||
self.total_duration.as_secs_f64(),
|
||||
self.tokens_per_second(),
|
||||
self.first_token_latency.as_secs_f64() * 1000.0,
|
||||
self.latency_p50().as_secs_f64() * 1000.0,
|
||||
self.latency_p95().as_secs_f64() * 1000.0,
|
||||
self.latency_p99().as_secs_f64() * 1000.0,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// GGUF File Validation Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Test that we can read and validate a GGUF file header
|
||||
#[test]
|
||||
#[ignore = "Requires model file - run with --ignored"]
|
||||
fn test_gguf_file_validation() {
|
||||
// Try to find any GGUF model
|
||||
let all_patterns = ["*.gguf"];
|
||||
let model_path = match skip_if_no_model(&all_patterns, "any GGUF") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
// Read and validate the file header
|
||||
let file = std::fs::File::open(&model_path).expect("Failed to open model file");
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
|
||||
// Read magic number (first 4 bytes should be "GGUF")
|
||||
use std::io::Read;
|
||||
let mut magic = [0u8; 4];
|
||||
reader.read_exact(&mut magic).expect("Failed to read magic");
|
||||
|
||||
// GGUF magic is "GGUF" in little-endian: 0x46554747
|
||||
assert_eq!(&magic, b"GGUF", "Invalid GGUF magic number");
|
||||
|
||||
// Read version (4 bytes, little-endian u32)
|
||||
let mut version_bytes = [0u8; 4];
|
||||
reader
|
||||
.read_exact(&mut version_bytes)
|
||||
.expect("Failed to read version");
|
||||
let version = u32::from_le_bytes(version_bytes);
|
||||
|
||||
// GGUF versions 2 and 3 are common
|
||||
assert!(
|
||||
version >= 2 && version <= 3,
|
||||
"Unexpected GGUF version: {}",
|
||||
version
|
||||
);
|
||||
|
||||
println!("GGUF file validated:");
|
||||
println!(" Path: {}", model_path.display());
|
||||
println!(" Magic: GGUF");
|
||||
println!(" Version: {}", version);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// TinyLlama Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Test loading TinyLlama model
|
||||
#[test]
|
||||
#[ignore = "Requires TinyLlama model file"]
|
||||
fn test_tinyllama_load() {
|
||||
let model_path = match skip_if_no_model(TINYLLAMA_PATTERNS, "TinyLlama") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
// This test verifies the model can be loaded without errors
|
||||
// In a real implementation, you would use the RuvLLM API
|
||||
println!("Would load TinyLlama from: {}", model_path.display());
|
||||
|
||||
// Verify file is readable and has reasonable size
|
||||
let metadata = std::fs::metadata(&model_path).expect("Failed to get file metadata");
|
||||
let size_mb = metadata.len() as f64 / (1024.0 * 1024.0);
|
||||
|
||||
println!("Model size: {:.2} MB", size_mb);
|
||||
|
||||
// TinyLlama Q4_K_M should be ~500-800MB
|
||||
assert!(
|
||||
size_mb > 100.0 && size_mb < 2000.0,
|
||||
"Unexpected model size: {:.2} MB (expected 100-2000 MB for TinyLlama)",
|
||||
size_mb
|
||||
);
|
||||
}
|
||||
|
||||
/// Test text generation with TinyLlama
|
||||
#[test]
|
||||
#[ignore = "Requires TinyLlama model file"]
|
||||
fn test_tinyllama_generation() {
|
||||
let model_path = match skip_if_no_model(TINYLLAMA_PATTERNS, "TinyLlama") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!(
|
||||
"Testing generation with TinyLlama: {}",
|
||||
model_path.display()
|
||||
);
|
||||
|
||||
// Placeholder for actual generation test
|
||||
// In real implementation:
|
||||
//
|
||||
// let mut backend = CandleBackend::new().expect("Failed to create backend");
|
||||
// let config = ModelConfig {
|
||||
// architecture: ModelArchitecture::Llama,
|
||||
// quantization: Some(Quantization::Q4K),
|
||||
// ..Default::default()
|
||||
// };
|
||||
// backend.load_model(model_path.to_str().unwrap(), config).expect("Failed to load model");
|
||||
//
|
||||
// let params = GenerateParams::default()
|
||||
// .with_max_tokens(50)
|
||||
// .with_temperature(0.7);
|
||||
//
|
||||
// let response = backend.generate("Hello, I am", params).expect("Generation failed");
|
||||
// assert!(!response.is_empty(), "Empty response from model");
|
||||
// println!("Generated: {}", response);
|
||||
|
||||
println!("TinyLlama generation test placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
/// Test streaming generation with TinyLlama
|
||||
#[test]
|
||||
#[ignore = "Requires TinyLlama model file"]
|
||||
fn test_tinyllama_streaming() {
|
||||
let model_path = match skip_if_no_model(TINYLLAMA_PATTERNS, "TinyLlama") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!("Testing streaming with TinyLlama: {}", model_path.display());
|
||||
|
||||
// Placeholder for streaming test
|
||||
// In real implementation:
|
||||
//
|
||||
// let stream = backend.generate_stream_v2("Once upon a time", params)?;
|
||||
// let mut token_count = 0;
|
||||
// for event in stream {
|
||||
// match event? {
|
||||
// StreamEvent::Token(token) => {
|
||||
// print!("{}", token.text);
|
||||
// token_count += 1;
|
||||
// }
|
||||
// StreamEvent::Done { tokens_per_second, .. } => {
|
||||
// println!("\nSpeed: {:.2} tok/s", tokens_per_second);
|
||||
// }
|
||||
// StreamEvent::Error(e) => panic!("Streaming error: {}", e),
|
||||
// }
|
||||
// }
|
||||
// assert!(token_count > 0, "No tokens generated");
|
||||
|
||||
println!("TinyLlama streaming test placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Phi-3 Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Test loading Phi-3 model
|
||||
#[test]
|
||||
#[ignore = "Requires Phi-3 model file"]
|
||||
fn test_phi3_load() {
|
||||
let model_path = match skip_if_no_model(PHI3_PATTERNS, "Phi-3") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!("Would load Phi-3 from: {}", model_path.display());
|
||||
|
||||
let metadata = std::fs::metadata(&model_path).expect("Failed to get file metadata");
|
||||
let size_mb = metadata.len() as f64 / (1024.0 * 1024.0);
|
||||
|
||||
println!("Model size: {:.2} MB", size_mb);
|
||||
|
||||
// Phi-3 mini Q4_K_M should be ~2-3GB
|
||||
assert!(
|
||||
size_mb > 500.0 && size_mb < 5000.0,
|
||||
"Unexpected model size: {:.2} MB (expected 500-5000 MB for Phi-3)",
|
||||
size_mb
|
||||
);
|
||||
}
|
||||
|
||||
/// Test text generation with Phi-3
|
||||
#[test]
|
||||
#[ignore = "Requires Phi-3 model file"]
|
||||
fn test_phi3_generation() {
|
||||
let model_path = match skip_if_no_model(PHI3_PATTERNS, "Phi-3") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!("Testing generation with Phi-3: {}", model_path.display());
|
||||
println!("Phi-3 generation test placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
/// Test Phi-3 with code completion prompt
|
||||
#[test]
|
||||
#[ignore = "Requires Phi-3 model file"]
|
||||
fn test_phi3_code_completion() {
|
||||
let model_path = match skip_if_no_model(PHI3_PATTERNS, "Phi-3") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!(
|
||||
"Testing code completion with Phi-3: {}",
|
||||
model_path.display()
|
||||
);
|
||||
|
||||
// Code completion prompts test the model's ability to understand code context
|
||||
let _prompts = [
|
||||
"def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n ",
|
||||
"// Function to reverse a string in Rust\nfn reverse_string(s: &str) -> String {\n ",
|
||||
"# Python function to check if a number is prime\ndef is_prime(n):\n ",
|
||||
];
|
||||
|
||||
println!("Phi-3 code completion test placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Qwen Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Test loading Qwen model
|
||||
#[test]
|
||||
#[ignore = "Requires Qwen model file"]
|
||||
fn test_qwen_load() {
|
||||
let model_path = match skip_if_no_model(QWEN_PATTERNS, "Qwen") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!("Would load Qwen from: {}", model_path.display());
|
||||
|
||||
let metadata = std::fs::metadata(&model_path).expect("Failed to get file metadata");
|
||||
let size_mb = metadata.len() as f64 / (1024.0 * 1024.0);
|
||||
|
||||
println!("Model size: {:.2} MB", size_mb);
|
||||
|
||||
// Qwen2-0.5B Q4_K_M should be ~300-500MB
|
||||
assert!(
|
||||
size_mb > 50.0 && size_mb < 1000.0,
|
||||
"Unexpected model size: {:.2} MB (expected 50-1000 MB for Qwen-0.5B)",
|
||||
size_mb
|
||||
);
|
||||
}
|
||||
|
||||
/// Test text generation with Qwen
|
||||
#[test]
|
||||
#[ignore = "Requires Qwen model file"]
|
||||
fn test_qwen_generation() {
|
||||
let model_path = match skip_if_no_model(QWEN_PATTERNS, "Qwen") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!("Testing generation with Qwen: {}", model_path.display());
|
||||
println!("Qwen generation test placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
/// Test Qwen multilingual capability
|
||||
#[test]
|
||||
#[ignore = "Requires Qwen model file"]
|
||||
fn test_qwen_multilingual() {
|
||||
let model_path = match skip_if_no_model(QWEN_PATTERNS, "Qwen") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!("Testing multilingual with Qwen: {}", model_path.display());
|
||||
|
||||
// Qwen is known for good multilingual support
|
||||
let _prompts = [
|
||||
"Hello, how are you today?", // English
|
||||
"Bonjour, comment allez-vous?", // French
|
||||
"Hallo, wie geht es Ihnen?", // German
|
||||
"Translate 'hello' to Chinese: ", // Translation task
|
||||
];
|
||||
|
||||
println!("Qwen multilingual test placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
/// Benchmark token generation speed
|
||||
#[test]
|
||||
#[ignore = "Requires model file - run with --ignored"]
|
||||
fn test_benchmark_generation_speed() {
|
||||
// Try to find any available model
|
||||
let patterns = ["*.gguf"];
|
||||
let model_path = match skip_if_no_model(&patterns, "any GGUF") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!(
|
||||
"Benchmarking generation speed with: {}",
|
||||
model_path.display()
|
||||
);
|
||||
|
||||
// Benchmark parameters
|
||||
let warmup_iterations = 3;
|
||||
let benchmark_iterations = 10;
|
||||
let max_tokens = 50;
|
||||
|
||||
println!("Warmup: {} iterations", warmup_iterations);
|
||||
println!("Benchmark: {} iterations", benchmark_iterations);
|
||||
println!("Max tokens per generation: {}", max_tokens);
|
||||
|
||||
// Placeholder for actual benchmark
|
||||
// In real implementation:
|
||||
//
|
||||
// // Warmup
|
||||
// for _ in 0..warmup_iterations {
|
||||
// backend.generate("Hello", params.clone())?;
|
||||
// }
|
||||
//
|
||||
// // Benchmark
|
||||
// let mut speeds = Vec::new();
|
||||
// for i in 0..benchmark_iterations {
|
||||
// let start = Instant::now();
|
||||
// let stream = backend.generate_stream_v2("Hello", params.clone())?;
|
||||
// let mut tokens = 0;
|
||||
// for event in stream {
|
||||
// if let StreamEvent::Token(_) = event? {
|
||||
// tokens += 1;
|
||||
// }
|
||||
// }
|
||||
// let elapsed = start.elapsed();
|
||||
// let speed = tokens as f64 / elapsed.as_secs_f64();
|
||||
// speeds.push(speed);
|
||||
// println!(" Iteration {}: {:.2} tok/s", i + 1, speed);
|
||||
// }
|
||||
//
|
||||
// let avg_speed = speeds.iter().sum::<f64>() / speeds.len() as f64;
|
||||
// println!("\nAverage speed: {:.2} tok/s", avg_speed);
|
||||
|
||||
println!("Benchmark placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
/// Test memory usage during inference
|
||||
#[test]
|
||||
#[ignore = "Requires model file"]
|
||||
fn test_memory_usage() {
|
||||
let patterns = ["*.gguf"];
|
||||
let model_path = match skip_if_no_model(&patterns, "any GGUF") {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
println!("Testing memory usage with: {}", model_path.display());
|
||||
|
||||
// Get initial memory usage (platform-specific)
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
use std::process::Command;
|
||||
let output = Command::new("ps")
|
||||
.args(["-o", "rss=", "-p", &std::process::id().to_string()])
|
||||
.output()
|
||||
.ok();
|
||||
|
||||
if let Some(output) = output {
|
||||
if let Ok(rss) = String::from_utf8_lossy(&output.stdout)
|
||||
.trim()
|
||||
.parse::<u64>()
|
||||
{
|
||||
println!("Initial RSS: {} KB", rss);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("Memory usage test placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Model Comparison Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Compare generation quality across different models
|
||||
#[test]
|
||||
#[ignore = "Requires multiple model files"]
|
||||
fn test_model_comparison() {
|
||||
println!("Model comparison test");
|
||||
|
||||
let test_prompts = [
|
||||
"What is the capital of France?",
|
||||
"Write a haiku about programming.",
|
||||
"Explain quantum computing in simple terms.",
|
||||
];
|
||||
|
||||
// Find all available models
|
||||
let models: Vec<(&str, Option<PathBuf>)> = vec![
|
||||
("TinyLlama", find_test_model(TINYLLAMA_PATTERNS)),
|
||||
("Phi-3", find_test_model(PHI3_PATTERNS)),
|
||||
("Qwen", find_test_model(QWEN_PATTERNS)),
|
||||
];
|
||||
|
||||
let available: Vec<_> = models.iter().filter(|(_, path)| path.is_some()).collect();
|
||||
|
||||
if available.is_empty() {
|
||||
println!("SKIPPED: No models available for comparison");
|
||||
return;
|
||||
}
|
||||
|
||||
println!("Available models for comparison:");
|
||||
for (name, path) in &available {
|
||||
if let Some(p) = path {
|
||||
println!(" - {}: {}", name, p.display());
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nTest prompts:");
|
||||
for (i, prompt) in test_prompts.iter().enumerate() {
|
||||
println!(" {}. {}", i + 1, prompt);
|
||||
}
|
||||
|
||||
println!("\nModel comparison placeholder - implement with actual backend");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Unit Tests for Helpers
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod helper_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_glob_pattern_matching() {
|
||||
assert!(matches_glob_pattern("tinyllama.gguf", "*.gguf"));
|
||||
assert!(matches_glob_pattern("tinyllama.gguf", "tinyllama*"));
|
||||
assert!(matches_glob_pattern(
|
||||
"tinyllama-1.1b.gguf",
|
||||
"*tinyllama*.gguf"
|
||||
));
|
||||
assert!(matches_glob_pattern("model.gguf", "model.gguf"));
|
||||
assert!(!matches_glob_pattern("tinyllama.bin", "*.gguf"));
|
||||
assert!(!matches_glob_pattern("other.gguf", "tinyllama*"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_path_no_tilde() {
|
||||
let path = expand_path("/usr/local/models");
|
||||
assert_eq!(path, PathBuf::from("/usr/local/models"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_path_relative() {
|
||||
let path = expand_path("./models");
|
||||
assert_eq!(path, PathBuf::from("./models"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metrics_percentile() {
|
||||
let metrics = GenerationMetrics {
|
||||
total_tokens: 100,
|
||||
total_duration: Duration::from_secs(10),
|
||||
first_token_latency: Duration::from_millis(50),
|
||||
token_latencies: (0..100).map(|i| Duration::from_millis(i as u64)).collect(),
|
||||
};
|
||||
|
||||
assert_eq!(metrics.tokens_per_second(), 10.0);
|
||||
assert!(metrics.latency_p50() >= Duration::from_millis(49));
|
||||
assert!(metrics.latency_p50() <= Duration::from_millis(51));
|
||||
assert!(metrics.latency_p99() >= Duration::from_millis(98));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metrics_empty_latencies() {
|
||||
let metrics = GenerationMetrics {
|
||||
total_tokens: 0,
|
||||
total_duration: Duration::ZERO,
|
||||
first_token_latency: Duration::ZERO,
|
||||
token_latencies: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(metrics.tokens_per_second(), 0.0);
|
||||
assert_eq!(metrics.latency_p50(), Duration::ZERO);
|
||||
}
|
||||
}
|
||||
1061
crates/ruvllm/tests/ruvltra_e2e.rs
Normal file
1061
crates/ruvllm/tests/ruvltra_e2e.rs
Normal file
File diff suppressed because it is too large
Load Diff
1257
crates/ruvllm/tests/ruvltra_tests.rs
Normal file
1257
crates/ruvllm/tests/ruvltra_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
1051
crates/ruvllm/tests/serving_integration.rs
Normal file
1051
crates/ruvllm/tests/serving_integration.rs
Normal file
File diff suppressed because it is too large
Load Diff
549
crates/ruvllm/tests/sona_integration.rs
Normal file
549
crates/ruvllm/tests/sona_integration.rs
Normal file
@@ -0,0 +1,549 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Integration tests for SONA (Self-Optimizing Neural Architecture)
|
||||
//!
|
||||
//! Tests the three-tier learning loop: instant adaptation, background consolidation,
|
||||
//! and deep loop processing.
|
||||
|
||||
use ruvllm::{
|
||||
error::Result,
|
||||
sona::{
|
||||
LearningLoop, RoutingRecommendation, SonaConfig, SonaIntegration, SonaStats, Trajectory,
|
||||
},
|
||||
};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Create a test SONA configuration
|
||||
fn create_test_sona_config() -> SonaConfig {
|
||||
SonaConfig {
|
||||
hidden_dim: 64,
|
||||
embedding_dim: 128,
|
||||
micro_lora_rank: 2,
|
||||
base_lora_rank: 4,
|
||||
instant_learning_rate: 0.01,
|
||||
background_learning_rate: 0.001,
|
||||
ewc_lambda: 0.1,
|
||||
pattern_capacity: 100,
|
||||
background_interval_secs: 3600,
|
||||
deep_interval_secs: 604800,
|
||||
quality_threshold: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a test trajectory
|
||||
fn create_test_trajectory(request_id: &str, quality: f32) -> Trajectory {
|
||||
Trajectory {
|
||||
request_id: request_id.to_string(),
|
||||
session_id: "test-session".to_string(),
|
||||
query_embedding: vec![0.1; 128],
|
||||
response_embedding: vec![0.2; 128],
|
||||
quality_score: quality,
|
||||
routing_features: vec![0.7, 0.9, 0.5, 0.5],
|
||||
model_index: 1,
|
||||
timestamp: chrono::Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_config_default() {
|
||||
let config = SonaConfig::default();
|
||||
|
||||
assert_eq!(config.hidden_dim, 256);
|
||||
assert_eq!(config.embedding_dim, 768);
|
||||
assert_eq!(config.micro_lora_rank, 2);
|
||||
assert_eq!(config.base_lora_rank, 8);
|
||||
assert!(config.instant_learning_rate > 0.0);
|
||||
assert!(config.ewc_lambda > 0.0);
|
||||
assert!(config.quality_threshold > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_integration_creation() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 0);
|
||||
assert_eq!(stats.instant_updates, 0);
|
||||
assert_eq!(stats.background_updates, 0);
|
||||
assert_eq!(stats.deep_updates, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_learning_loop_variants() {
|
||||
assert!(matches!(LearningLoop::Instant, LearningLoop::Instant));
|
||||
assert!(matches!(LearningLoop::Background, LearningLoop::Background));
|
||||
assert!(matches!(LearningLoop::Deep, LearningLoop::Deep));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trajectory_creation() {
|
||||
let trajectory = create_test_trajectory("req-001", 0.8);
|
||||
|
||||
assert_eq!(trajectory.request_id, "req-001");
|
||||
assert_eq!(trajectory.session_id, "test-session");
|
||||
assert_eq!(trajectory.quality_score, 0.8);
|
||||
assert_eq!(trajectory.query_embedding.len(), 128);
|
||||
assert_eq!(trajectory.response_embedding.len(), 128);
|
||||
assert_eq!(trajectory.routing_features.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_record_trajectory() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0, // Accept all trajectories
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let trajectory = create_test_trajectory("req-001", 0.8);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 1);
|
||||
assert_eq!(stats.instant_updates, 1); // Should run instant loop
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_quality_threshold() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.7,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// High quality - should trigger instant loop
|
||||
let high_quality = create_test_trajectory("req-001", 0.9);
|
||||
sona.record_trajectory(high_quality).unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 1);
|
||||
assert_eq!(stats.instant_updates, 1);
|
||||
|
||||
// Low quality - should not trigger instant loop
|
||||
let low_quality = create_test_trajectory("req-002", 0.5);
|
||||
sona.record_trajectory(low_quality).unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 2);
|
||||
assert_eq!(stats.instant_updates, 1); // Still 1
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_multiple_trajectories() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
for i in 0..10 {
|
||||
let trajectory = create_test_trajectory(&format!("req-{:03}", i), 0.8);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
}
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 10);
|
||||
assert_eq!(stats.instant_updates, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_routing_recommendation_no_patterns() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let query = vec![0.1; 128];
|
||||
let rec = sona.get_routing_recommendation(&query);
|
||||
|
||||
// With no patterns, should return defaults
|
||||
assert_eq!(rec.based_on_patterns, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_routing_recommendation_default() {
|
||||
let rec = RoutingRecommendation::default();
|
||||
|
||||
assert_eq!(rec.suggested_model, 0);
|
||||
assert_eq!(rec.confidence, 0.0);
|
||||
assert_eq!(rec.based_on_patterns, 0);
|
||||
assert_eq!(rec.average_quality, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_search_patterns_empty() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let query = vec![0.1; 128];
|
||||
let patterns = sona.search_patterns(&query, 5);
|
||||
|
||||
assert!(patterns.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_apply_transform() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let input = vec![0.1; 64]; // Must match hidden_dim
|
||||
let output = sona.apply_transform(&input);
|
||||
|
||||
assert_eq!(output.len(), input.len());
|
||||
assert!(output.iter().all(|&v| v.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_stats() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let stats = sona.stats();
|
||||
|
||||
assert_eq!(stats.total_trajectories, 0);
|
||||
assert_eq!(stats.instant_updates, 0);
|
||||
assert_eq!(stats.background_updates, 0);
|
||||
assert_eq!(stats.deep_updates, 0);
|
||||
assert_eq!(stats.patterns_learned, 0);
|
||||
assert_eq!(stats.buffer_size, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_stats_after_learning() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Record some trajectories
|
||||
for i in 0..5 {
|
||||
let trajectory = create_test_trajectory(&format!("req-{}", i), 0.8);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
}
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 5);
|
||||
assert!(stats.buffer_size > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_trigger_background_loop() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0,
|
||||
background_interval_secs: 0, // Allow immediate trigger
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Record trajectories
|
||||
for i in 0..5 {
|
||||
let trajectory = create_test_trajectory(&format!("req-{}", i), 0.8);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
}
|
||||
|
||||
// Trigger background loop
|
||||
sona.trigger_background_loop().unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert!(stats.background_updates >= 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_trigger_deep_loop() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Record trajectories (this may trigger deep loop automatically if interval elapsed)
|
||||
for i in 0..5 {
|
||||
let trajectory = create_test_trajectory(&format!("req-{}", i), 0.8);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
}
|
||||
|
||||
let stats_before = sona.stats();
|
||||
let deep_updates_before = stats_before.deep_updates;
|
||||
|
||||
// Trigger background loop first (to populate patterns)
|
||||
sona.trigger_background_loop().unwrap();
|
||||
|
||||
// Trigger deep loop explicitly
|
||||
sona.trigger_deep_loop().unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
// At least one more deep update after explicit trigger
|
||||
assert!(
|
||||
stats.deep_updates >= deep_updates_before + 1,
|
||||
"Expected at least {} deep updates, got {}",
|
||||
deep_updates_before + 1,
|
||||
stats.deep_updates
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trajectory_timestamp() {
|
||||
let trajectory = create_test_trajectory("req-001", 0.8);
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
// Timestamp should be recent
|
||||
let diff = now - trajectory.timestamp;
|
||||
assert!(diff.num_seconds() < 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_varying_quality_trajectories() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.5,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Record trajectories with varying quality
|
||||
let qualities = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1];
|
||||
for (i, &quality) in qualities.iter().enumerate() {
|
||||
let trajectory = create_test_trajectory(&format!("req-{}", i), quality);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
}
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 9);
|
||||
// Only 5 have quality >= 0.5 threshold
|
||||
assert_eq!(stats.instant_updates, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_empty_background_loop() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Trigger background loop with no trajectories
|
||||
// Note: The implementation returns early without incrementing counter
|
||||
// if there are no high-quality trajectories to process
|
||||
let result = sona.trigger_background_loop();
|
||||
assert!(result.is_ok());
|
||||
|
||||
let stats = sona.stats();
|
||||
// With no trajectories meeting quality threshold, background_updates is 0
|
||||
assert_eq!(
|
||||
stats.background_updates, 0,
|
||||
"Background loop with no trajectories should not count as an update"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_empty_deep_loop() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Trigger deep loop with no patterns
|
||||
let result = sona.trigger_deep_loop();
|
||||
assert!(result.is_ok());
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.deep_updates, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_large_embedding() {
|
||||
let config = SonaConfig {
|
||||
embedding_dim: 768,
|
||||
hidden_dim: 256,
|
||||
quality_threshold: 0.0,
|
||||
..SonaConfig::default()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let trajectory = Trajectory {
|
||||
request_id: "large-001".to_string(),
|
||||
session_id: "test".to_string(),
|
||||
query_embedding: vec![0.1; 768],
|
||||
response_embedding: vec![0.2; 768],
|
||||
quality_score: 0.9,
|
||||
routing_features: vec![0.5; 4],
|
||||
model_index: 0,
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_model_index_mapping() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Test different model indices
|
||||
for model_idx in 0..4 {
|
||||
let trajectory = Trajectory {
|
||||
request_id: format!("model-{}", model_idx),
|
||||
session_id: "test".to_string(),
|
||||
query_embedding: vec![0.1; 128],
|
||||
response_embedding: vec![0.2; 128],
|
||||
quality_score: 0.8,
|
||||
routing_features: vec![0.5; 4],
|
||||
model_index: model_idx,
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
}
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_concurrent_safe() {
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = Arc::new(SonaIntegration::new(config));
|
||||
|
||||
let mut handles = vec![];
|
||||
|
||||
// Spawn multiple threads recording trajectories
|
||||
for thread_id in 0..4 {
|
||||
let sona_clone = Arc::clone(&sona);
|
||||
let handle = thread::spawn(move || {
|
||||
for i in 0..10 {
|
||||
let trajectory = Trajectory {
|
||||
request_id: format!("thread-{}-req-{}", thread_id, i),
|
||||
session_id: format!("thread-{}", thread_id),
|
||||
query_embedding: vec![0.1; 128],
|
||||
response_embedding: vec![0.2; 128],
|
||||
quality_score: 0.8,
|
||||
routing_features: vec![0.5; 4],
|
||||
model_index: 0,
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
sona_clone.record_trajectory(trajectory).unwrap();
|
||||
}
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 40);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_stats_struct() {
|
||||
let stats = SonaStats {
|
||||
total_trajectories: 100,
|
||||
instant_updates: 80,
|
||||
background_updates: 5,
|
||||
deep_updates: 1,
|
||||
patterns_learned: 50,
|
||||
buffer_size: 20,
|
||||
last_background_secs_ago: 3600,
|
||||
last_deep_secs_ago: 86400,
|
||||
};
|
||||
|
||||
assert_eq!(stats.total_trajectories, 100);
|
||||
assert_eq!(stats.instant_updates, 80);
|
||||
assert_eq!(stats.background_updates, 5);
|
||||
assert_eq!(stats.deep_updates, 1);
|
||||
assert_eq!(stats.patterns_learned, 50);
|
||||
assert_eq!(stats.buffer_size, 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_routing_features() {
|
||||
let trajectory = Trajectory {
|
||||
request_id: "routing-test".to_string(),
|
||||
session_id: "test".to_string(),
|
||||
query_embedding: vec![0.1; 128],
|
||||
response_embedding: vec![0.2; 128],
|
||||
quality_score: 0.9,
|
||||
routing_features: vec![0.7, 0.9, 0.8, 0.5], // temperature, top_p, confidence, context_ratio
|
||||
model_index: 1,
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
assert_eq!(trajectory.routing_features.len(), 4);
|
||||
assert_eq!(trajectory.routing_features[0], 0.7); // temperature
|
||||
assert_eq!(trajectory.routing_features[1], 0.9); // top_p
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_boundary_quality() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.5,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Exactly at threshold
|
||||
let trajectory = create_test_trajectory("boundary", 0.5);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.instant_updates, 1); // Should still trigger
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_zero_quality() {
|
||||
let config = SonaConfig {
|
||||
quality_threshold: 0.0,
|
||||
..create_test_sona_config()
|
||||
};
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
let trajectory = create_test_trajectory("zero-quality", 0.0);
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 1);
|
||||
// With threshold 0.0, even quality 0.0 should trigger (0.0 >= 0.0)
|
||||
assert_eq!(stats.instant_updates, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sona_negative_quality_handling() {
|
||||
let config = create_test_sona_config();
|
||||
let sona = SonaIntegration::new(config);
|
||||
|
||||
// Negative quality should still be recorded but not trigger learning
|
||||
let trajectory = Trajectory {
|
||||
request_id: "negative".to_string(),
|
||||
session_id: "test".to_string(),
|
||||
query_embedding: vec![0.1; 128],
|
||||
response_embedding: vec![0.2; 128],
|
||||
quality_score: -0.5, // Negative
|
||||
routing_features: vec![0.5; 4],
|
||||
model_index: 0,
|
||||
timestamp: chrono::Utc::now(),
|
||||
};
|
||||
|
||||
sona.record_trajectory(trajectory).unwrap();
|
||||
|
||||
let stats = sona.stats();
|
||||
assert_eq!(stats.total_trajectories, 1);
|
||||
assert_eq!(stats.instant_updates, 0); // Should not trigger
|
||||
}
|
||||
464
crates/ruvllm/tests/speculative_integration.rs
Normal file
464
crates/ruvllm/tests/speculative_integration.rs
Normal file
@@ -0,0 +1,464 @@
|
||||
#![allow(
|
||||
clippy::all,
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
unused_mut,
|
||||
unused_assignments,
|
||||
non_camel_case_types,
|
||||
clippy::approx_constant,
|
||||
unexpected_cfgs,
|
||||
unused_must_use,
|
||||
unused_parens
|
||||
)]
|
||||
//! Integration tests for speculative decoding
|
||||
//!
|
||||
//! These tests verify the speculative decoding implementation works correctly
|
||||
//! with mock backends.
|
||||
|
||||
use ruvllm::speculative::{
|
||||
log_softmax, softmax, top_k_filter, top_p_filter, AtomicSpeculativeStats, SpeculationTree,
|
||||
SpeculativeConfig, SpeculativeStats, TreeNode, VerificationResult,
|
||||
};
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn test_speculative_config_defaults() {
|
||||
let config = SpeculativeConfig::default();
|
||||
|
||||
assert_eq!(config.lookahead, 4);
|
||||
assert!((config.acceptance_threshold - 0.5).abs() < 0.01);
|
||||
assert!((config.draft_temperature - 0.0).abs() < 0.01);
|
||||
assert!(!config.tree_speculation);
|
||||
assert_eq!(config.max_tree_depth, 3);
|
||||
assert_eq!(config.tree_branching_factor, 2);
|
||||
assert!(config.adaptive_lookahead);
|
||||
assert_eq!(config.min_lookahead, 2);
|
||||
assert_eq!(config.max_lookahead, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculative_config_custom() {
|
||||
let config = SpeculativeConfig {
|
||||
lookahead: 6,
|
||||
acceptance_threshold: 0.7,
|
||||
draft_temperature: 0.1,
|
||||
tree_speculation: true,
|
||||
max_tree_depth: 4,
|
||||
tree_branching_factor: 3,
|
||||
draft_top_p: 0.9,
|
||||
min_acceptance_ratio: 0.2,
|
||||
adaptive_lookahead: false,
|
||||
min_lookahead: 3,
|
||||
max_lookahead: 10,
|
||||
};
|
||||
|
||||
assert_eq!(config.lookahead, 6);
|
||||
assert!((config.acceptance_threshold - 0.7).abs() < 0.01);
|
||||
assert!(config.tree_speculation);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculative_stats_empty() {
|
||||
let stats = SpeculativeStats::new();
|
||||
|
||||
assert_eq!(stats.draft_tokens, 0);
|
||||
assert_eq!(stats.accepted_tokens, 0);
|
||||
assert!((stats.acceptance_rate - 0.0).abs() < 0.01);
|
||||
assert!((stats.speedup - 0.0).abs() < 0.01);
|
||||
assert_eq!(stats.main_forward_passes, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculative_stats_record_round() {
|
||||
let mut stats = SpeculativeStats::new();
|
||||
|
||||
// Simulate a round: 4 draft tokens, 3 accepted
|
||||
stats.record_round(4, 3, 10.0);
|
||||
|
||||
assert_eq!(stats.draft_tokens, 4);
|
||||
assert_eq!(stats.accepted_tokens, 3);
|
||||
assert!((stats.acceptance_rate - 0.75).abs() < 0.01);
|
||||
assert_eq!(stats.main_forward_passes, 1);
|
||||
assert_eq!(stats.draft_forward_passes, 4);
|
||||
assert_eq!(stats.total_tokens_generated, 4); // 3 accepted + 1 correction
|
||||
|
||||
// Simulate another round: 4 draft, 2 accepted
|
||||
stats.record_round(4, 2, 12.0);
|
||||
|
||||
assert_eq!(stats.draft_tokens, 8);
|
||||
assert_eq!(stats.accepted_tokens, 5);
|
||||
assert!((stats.acceptance_rate - 0.625).abs() < 0.01);
|
||||
assert_eq!(stats.main_forward_passes, 2);
|
||||
assert_eq!(stats.total_tokens_generated, 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculative_stats_speedup_calculation() {
|
||||
let mut stats = SpeculativeStats::new();
|
||||
|
||||
// Perfect speculation: all accepted
|
||||
stats.record_round(4, 4, 10.0);
|
||||
|
||||
// 5 tokens per pass (4 accepted + 1 continuation)
|
||||
assert!((stats.avg_tokens_per_main_pass - 5.0).abs() < 0.01);
|
||||
assert!((stats.speedup - 5.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atomic_speculative_stats() {
|
||||
let stats = AtomicSpeculativeStats::new();
|
||||
|
||||
// Record multiple rounds (simulating concurrent access)
|
||||
stats.record_round(4, 3, Duration::from_millis(10));
|
||||
stats.record_round(4, 4, Duration::from_millis(8));
|
||||
stats.record_round(4, 2, Duration::from_millis(12));
|
||||
|
||||
let snapshot = stats.snapshot();
|
||||
|
||||
assert_eq!(snapshot.draft_tokens, 12);
|
||||
assert_eq!(snapshot.accepted_tokens, 9);
|
||||
assert_eq!(snapshot.main_forward_passes, 3);
|
||||
assert!((snapshot.acceptance_rate - 0.75).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atomic_stats_reset() {
|
||||
let stats = AtomicSpeculativeStats::new();
|
||||
stats.record_round(4, 3, Duration::from_millis(10));
|
||||
|
||||
assert_eq!(stats.snapshot().draft_tokens, 4);
|
||||
|
||||
stats.reset();
|
||||
|
||||
assert_eq!(stats.snapshot().draft_tokens, 0);
|
||||
assert_eq!(stats.snapshot().accepted_tokens, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tree_node_creation() {
|
||||
let node = TreeNode::new(42, 0.8, 0);
|
||||
|
||||
assert_eq!(node.token, 42);
|
||||
assert!((node.prob - 0.8).abs() < 0.01);
|
||||
assert_eq!(node.depth, 0);
|
||||
assert!(node.children.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tree_node_add_child() {
|
||||
let mut root = TreeNode::new(0, 1.0, 0);
|
||||
|
||||
root.add_child(1, 0.6);
|
||||
root.add_child(2, 0.3);
|
||||
root.add_child(3, 0.1);
|
||||
|
||||
assert_eq!(root.children.len(), 3);
|
||||
assert_eq!(root.children[0].token, 1);
|
||||
assert_eq!(root.children[1].token, 2);
|
||||
assert_eq!(root.children[2].token, 3);
|
||||
assert_eq!(root.children[0].depth, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tree_node_best_path() {
|
||||
let mut root = TreeNode::new(0, 1.0, 0);
|
||||
|
||||
// Build tree:
|
||||
// 0
|
||||
// / \
|
||||
// 1 2
|
||||
// / / \
|
||||
// 3 4 5
|
||||
|
||||
let child1 = root.add_child(1, 0.6);
|
||||
child1.add_child(3, 0.5);
|
||||
|
||||
let child2 = root.add_child(2, 0.3);
|
||||
child2.add_child(4, 0.2);
|
||||
child2.add_child(5, 0.1);
|
||||
|
||||
// Best path should follow highest probabilities
|
||||
let path = root.best_path();
|
||||
assert_eq!(path[0], 0);
|
||||
assert_eq!(path[1], 1); // 0.6 > 0.3
|
||||
assert_eq!(path[2], 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tree_node_get_paths() {
|
||||
let mut root = TreeNode::new(0, 1.0, 0);
|
||||
|
||||
let child1 = root.add_child(1, 0.6);
|
||||
child1.add_child(3, 0.5);
|
||||
|
||||
let child2 = root.add_child(2, 0.3);
|
||||
child2.add_child(4, 0.2);
|
||||
child2.add_child(5, 0.1);
|
||||
|
||||
let paths = root.get_paths();
|
||||
|
||||
// Should have 3 paths:
|
||||
// [0, 1, 3]
|
||||
// [0, 2, 4]
|
||||
// [0, 2, 5]
|
||||
assert_eq!(paths.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculation_tree_creation() {
|
||||
let tree = SpeculationTree::new(3, 2);
|
||||
|
||||
assert_eq!(tree.max_depth, 3);
|
||||
assert_eq!(tree.branching_factor, 2);
|
||||
assert_eq!(tree.node_count, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculation_tree_best_path() {
|
||||
let mut tree = SpeculationTree::new(3, 2);
|
||||
|
||||
// Build a linear path
|
||||
let mut current = &mut tree.root;
|
||||
current = current.add_child(10, 0.9);
|
||||
tree.node_count += 1;
|
||||
current = current.add_child(20, 0.8);
|
||||
tree.node_count += 1;
|
||||
current.add_child(30, 0.7);
|
||||
tree.node_count += 1;
|
||||
|
||||
let best = tree.best_path();
|
||||
|
||||
// Should skip the root placeholder and return [10, 20, 30]
|
||||
assert_eq!(best, vec![10, 20, 30]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_speculation_tree_clear() {
|
||||
let mut tree = SpeculationTree::new(3, 2);
|
||||
|
||||
tree.root.add_child(1, 0.5);
|
||||
tree.node_count += 1;
|
||||
|
||||
assert_eq!(tree.node_count, 2);
|
||||
|
||||
tree.clear();
|
||||
|
||||
assert_eq!(tree.node_count, 1);
|
||||
assert!(tree.root.children.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_verification_result() {
|
||||
let result = VerificationResult {
|
||||
accepted_count: 3,
|
||||
next_token: 100,
|
||||
accepted_logprobs: vec![-0.1, -0.2, -0.3],
|
||||
next_logprob: -0.5,
|
||||
all_accepted: false,
|
||||
};
|
||||
|
||||
assert_eq!(result.accepted_count, 3);
|
||||
assert_eq!(result.next_token, 100);
|
||||
assert!(!result.all_accepted);
|
||||
assert_eq!(result.accepted_logprobs.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax() {
|
||||
let logits = vec![1.0, 2.0, 3.0];
|
||||
let probs = softmax(&logits);
|
||||
|
||||
// Probabilities should sum to 1
|
||||
let sum: f32 = probs.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 0.001);
|
||||
|
||||
// Probabilities should be ordered
|
||||
assert!(probs[2] > probs[1]);
|
||||
assert!(probs[1] > probs[0]);
|
||||
|
||||
// All probabilities should be positive
|
||||
assert!(probs.iter().all(|&p| p > 0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_with_large_values() {
|
||||
// Test numerical stability with large values
|
||||
let logits = vec![100.0, 200.0, 300.0];
|
||||
let probs = softmax(&logits);
|
||||
|
||||
let sum: f32 = probs.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_softmax_with_negative_values() {
|
||||
let logits = vec![-1.0, -2.0, -3.0];
|
||||
let probs = softmax(&logits);
|
||||
|
||||
let sum: f32 = probs.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_log_softmax() {
|
||||
let logits = vec![1.0, 2.0, 3.0];
|
||||
let log_probs = log_softmax(&logits);
|
||||
|
||||
// All log probabilities should be negative (probabilities < 1)
|
||||
assert!(log_probs.iter().all(|&lp| lp <= 0.0));
|
||||
|
||||
// exp(log_softmax) should equal softmax
|
||||
let probs: Vec<f32> = log_probs.iter().map(|&lp: &f32| lp.exp()).collect();
|
||||
let sum: f32 = probs.iter().sum();
|
||||
assert!((sum - 1.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_k_filter() {
|
||||
let mut logits: Vec<f32> = vec![1.0, 5.0, 3.0, 4.0, 2.0];
|
||||
top_k_filter(&mut logits, 2);
|
||||
|
||||
// Only top 2 (5.0 and 4.0) should remain finite
|
||||
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
||||
assert_eq!(finite_count, 2);
|
||||
|
||||
// The top values should be unchanged
|
||||
assert!((logits[1] - 5.0).abs() < 0.01);
|
||||
assert!((logits[3] - 4.0).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_k_filter_k_equals_len() {
|
||||
let mut logits: Vec<f32> = vec![1.0, 2.0, 3.0];
|
||||
top_k_filter(&mut logits, 3);
|
||||
|
||||
// All values should remain
|
||||
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
||||
assert_eq!(finite_count, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_k_filter_k_zero() {
|
||||
let mut logits = vec![1.0, 2.0, 3.0];
|
||||
let original = logits.clone();
|
||||
top_k_filter(&mut logits, 0);
|
||||
|
||||
// No filtering when k=0
|
||||
assert_eq!(logits, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_p_filter() {
|
||||
let mut logits: Vec<f32> = vec![10.0, 5.0, 2.0, 1.0, 0.5];
|
||||
top_p_filter(&mut logits, 0.9);
|
||||
|
||||
// Most probability mass should be preserved
|
||||
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
||||
assert!(finite_count >= 1 && finite_count <= 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_p_filter_p_one() {
|
||||
let mut logits = vec![1.0, 2.0, 3.0];
|
||||
let original = logits.clone();
|
||||
top_p_filter(&mut logits, 1.0);
|
||||
|
||||
// No filtering when p=1.0
|
||||
assert_eq!(logits, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_top_p_filter_very_low_p() {
|
||||
let mut logits: Vec<f32> = vec![10.0, 1.0, 0.5, 0.1];
|
||||
top_p_filter(&mut logits, 0.01);
|
||||
|
||||
// Only the highest probability token should remain
|
||||
let finite_count = logits.iter().filter(|&&x| x.is_finite()).count();
|
||||
assert!(finite_count >= 1);
|
||||
|
||||
// The top value should be finite
|
||||
assert!(logits[0].is_finite());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let config = SpeculativeConfig {
|
||||
lookahead: 6,
|
||||
acceptance_threshold: 0.7,
|
||||
draft_temperature: 0.1,
|
||||
tree_speculation: true,
|
||||
max_tree_depth: 4,
|
||||
tree_branching_factor: 3,
|
||||
draft_top_p: 0.9,
|
||||
min_acceptance_ratio: 0.2,
|
||||
adaptive_lookahead: true,
|
||||
min_lookahead: 3,
|
||||
max_lookahead: 10,
|
||||
};
|
||||
|
||||
// Test JSON serialization
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
let deserialized: SpeculativeConfig = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.lookahead, 6);
|
||||
assert!(deserialized.tree_speculation);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stats_serialization() {
|
||||
let mut stats = SpeculativeStats::new();
|
||||
stats.record_round(4, 3, 10.0);
|
||||
|
||||
let json = serde_json::to_string(&stats).unwrap();
|
||||
let deserialized: SpeculativeStats = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.draft_tokens, 4);
|
||||
assert_eq!(deserialized.accepted_tokens, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_realistic_speculation_scenario() {
|
||||
let mut stats = SpeculativeStats::new();
|
||||
|
||||
// Simulate 100 generation rounds with varying acceptance
|
||||
for i in 0..100 {
|
||||
let draft_count = 4;
|
||||
// Acceptance varies: high at start, lower later (simulating diverse output)
|
||||
let accepted = if i < 30 {
|
||||
4 // 100% acceptance
|
||||
} else if i < 60 {
|
||||
3 // 75% acceptance
|
||||
} else {
|
||||
2 // 50% acceptance
|
||||
};
|
||||
|
||||
stats.record_round(draft_count, accepted, (i as f64) * 0.1);
|
||||
}
|
||||
|
||||
// Verify stats are reasonable
|
||||
assert_eq!(stats.draft_tokens, 400);
|
||||
assert!(stats.acceptance_rate > 0.5 && stats.acceptance_rate < 1.0);
|
||||
assert!(stats.speedup > 1.0); // Should show speedup
|
||||
assert_eq!(stats.main_forward_passes, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tree_with_deep_nesting() {
|
||||
let mut tree = SpeculationTree::new(5, 2);
|
||||
|
||||
// Build a deep tree
|
||||
fn build_recursive(node: &mut TreeNode, depth: usize, max_depth: usize) {
|
||||
if depth >= max_depth {
|
||||
return;
|
||||
}
|
||||
|
||||
let child = node.add_child((depth * 10) as u32, 1.0 / (depth + 1) as f32);
|
||||
build_recursive(child, depth + 1, max_depth);
|
||||
}
|
||||
|
||||
build_recursive(&mut tree.root, 0, 5);
|
||||
|
||||
let best = tree.best_path();
|
||||
assert_eq!(best.len(), 5);
|
||||
}
|
||||
Reference in New Issue
Block a user