Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
644
vendor/ruvector/crates/ruvllm/src/tests/activation_tests.rs
vendored
Normal file
644
vendor/ruvector/crates/ruvllm/src/tests/activation_tests.rs
vendored
Normal file
@@ -0,0 +1,644 @@
|
||||
//! Activation Function Tests
|
||||
//!
|
||||
//! Tests for NEON vs scalar implementations of activation functions:
|
||||
//! SiLU, GELU, ReLU, and Softmax, including correctness and benchmarks.
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
// ============================================================================
|
||||
// SiLU (Swish) Activation Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Scalar reference SiLU (Swish): `x * sigmoid(x)`.
fn silu_reference(x: f32) -> f32 {
    // sigmoid(x) = 1 / (1 + e^-x); keep the single-division form.
    let denom = 1.0 + (-x).exp();
    x / denom
}
|
||||
|
||||
/// Vectorized SiLU for testing
|
||||
fn silu_vec_reference(input: &[f32]) -> Vec<f32> {
|
||||
input.iter().map(|&x| silu_reference(x)).collect()
|
||||
}
|
||||
|
||||
#[test]
fn test_silu_basic_values() {
    // Spot-check SiLU on a handful of representative inputs.
    for x in [0.0f32, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5] {
        let y = silu_reference(x);

        // SiLU(0) = 0 exactly.
        if x == 0.0 {
            assert!(y.abs() < 1e-6, "SiLU(0) should be 0");
        }

        // Finite inputs must map to finite outputs.
        assert!(y.is_finite(), "SiLU({}) should be finite", x);

        // sigmoid(x) < 1, so SiLU(x) < x whenever x > 0.
        if x > 0.0 {
            assert!(y < x, "SiLU({}) should be less than {}", x, x);
        }
    }
}
|
||||
|
||||
#[test]
fn test_silu_vector() {
    let xs = vec![0.0, 0.5, 1.0, 1.5, 2.0, -0.5, -1.0, -1.5];
    let ys = silu_vec_reference(&xs);

    assert_eq!(ys.len(), xs.len());

    // The vectorized path must agree element-wise with the scalar one.
    for (idx, (&x, &y)) in xs.iter().zip(ys.iter()).enumerate() {
        let want = silu_reference(x);
        assert!(
            (y - want).abs() < 1e-6,
            "SiLU mismatch at index {}: got {}, expected {}",
            idx,
            y,
            want
        );
    }
}
|
||||
|
||||
#[test]
fn test_silu_symmetry() {
    // SiLU is not an odd function: silu(-x) != -silu(x).
    // Instead silu(-x) = -x * sigmoid(-x) = -x / (1 + e^x).
    let x = 1.5;
    let pos = silu_reference(x);
    let neg = silu_reference(-x);

    // The magnitudes must differ noticeably, confirming the asymmetry.
    assert!((pos.abs() - neg.abs()).abs() > 0.1);
}
|
||||
|
||||
#[test]
fn test_silu_large_values() {
    // Numerical stability at extreme magnitudes.
    let big = 100.0f32;

    // Large positive x: sigmoid saturates at 1, so SiLU(x) ≈ x.
    assert!((silu_reference(big) - big).abs() < 1e-4);

    // Large negative x: sigmoid vanishes, so SiLU(x) ≈ 0.
    assert!(silu_reference(-big).abs() < 1e-4);
}
|
||||
|
||||
// ============================================================================
|
||||
// GELU Activation Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Reference GELU using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
fn gelu_reference(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const CUBIC_COEFF: f32 = 0.044715;

    let arg = SQRT_2_OVER_PI * (x + CUBIC_COEFF * x * x * x);
    0.5 * x * (1.0 + arg.tanh())
}
|
||||
|
||||
/// Exact GELU (using erf)
|
||||
fn gelu_exact(x: f32) -> f32 {
|
||||
// GELU(x) = x * Phi(x) where Phi is standard normal CDF
|
||||
// = 0.5 * x * (1 + erf(x / sqrt(2)))
|
||||
let sqrt_2 = std::f32::consts::SQRT_2;
|
||||
0.5 * x * (1.0 + erf_approx(x / sqrt_2))
|
||||
}
|
||||
|
||||
/// erf approximation (Abramowitz & Stegun formula 7.1.26).
///
/// erf is odd, so the polynomial is evaluated on |x| and the sign is
/// restored afterwards.
fn erf_approx(x: f32) -> f32 {
    // Coefficients a1..a5 of the rational approximation.
    const A: [f32; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    const P: f32 = 0.3275911;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + P * x);

    // Horner evaluation of ((((a5*t + a4)*t + a3)*t + a2)*t + a1).
    let mut poly = A[4];
    for &a in A[..4].iter().rev() {
        poly = poly * t + a;
    }

    let y = 1.0 - poly * t * (-x * x).exp();

    sign * y
}
|
||||
|
||||
fn gelu_vec_reference(input: &[f32]) -> Vec<f32> {
|
||||
input.iter().map(|&x| gelu_reference(x)).collect()
|
||||
}
|
||||
|
||||
#[test]
fn test_gelu_basic_values() {
    // GELU(0) = 0.
    assert!(gelu_reference(0.0).abs() < 1e-6);

    // Far into the positive tail, GELU(x) ≈ x.
    let big = 5.0;
    assert!((gelu_reference(big) - big).abs() < 0.1);

    // Far into the negative tail, GELU(x) ≈ 0.
    assert!(gelu_reference(-5.0).abs() < 0.1);
}
|
||||
|
||||
#[test]
fn test_gelu_approx_vs_exact() {
    // The tanh approximation should track the erf-based GELU closely.
    for x in [-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0] {
        let approx = gelu_reference(x);
        let exact = gelu_exact(x);

        // Relative error under 1%; denominator floored to avoid /0 at x=0.
        let rel_err = (approx - exact).abs() / exact.abs().max(1e-6);
        assert!(
            rel_err < 0.01,
            "GELU approximation error too large at x={}: approx={}, exact={}",
            x,
            approx,
            exact
        );
    }
}
|
||||
|
||||
#[test]
fn test_gelu_vector() {
    let xs = vec![-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, -3.0, 0.5];
    let ys = gelu_vec_reference(&xs);

    assert_eq!(ys.len(), xs.len());

    // Every output must be a finite float.
    for (idx, &y) in ys.iter().enumerate() {
        assert!(y.is_finite(), "GELU output {} should be finite", idx);
    }
}
|
||||
|
||||
#[test]
fn test_gelu_monotonicity() {
    // GELU should be (approximately) increasing once x is clearly positive.
    let xs: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
    let ys = gelu_vec_reference(&xs);

    for (w, pair) in ys.windows(2).enumerate() {
        let i = w + 1;
        if xs[i] > 0.5 {
            assert!(
                pair[1] >= pair[0] - 1e-6,
                "GELU should be increasing for positive values"
            );
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// ReLU Activation Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Scalar reference ReLU: x for positive inputs, 0.0 otherwise.
/// NaN inputs also map to 0.0, matching `f32::max(x, 0.0)` semantics.
fn relu_reference(x: f32) -> f32 {
    if x > 0.0 {
        x
    } else {
        0.0
    }
}
|
||||
|
||||
fn relu_vec_reference(input: &[f32]) -> Vec<f32> {
|
||||
input.iter().map(|&x| relu_reference(x)).collect()
|
||||
}
|
||||
|
||||
#[test]
fn test_relu_basic() {
    // Positive inputs pass through; non-positive inputs clamp to zero.
    let cases = [
        (5.0f32, 5.0f32),
        (0.0, 0.0),
        (-5.0, 0.0),
        (-0.001, 0.0),
        (0.001, 0.001),
    ];
    for (input, expected) in cases {
        assert_eq!(relu_reference(input), expected);
    }
}
|
||||
|
||||
#[test]
fn test_relu_vector() {
    let input = vec![-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0];
    let want = vec![0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0];

    assert_eq!(relu_vec_reference(&input), want);
}
|
||||
|
||||
#[test]
fn test_relu_is_idempotent() {
    // Applying ReLU twice changes nothing: ReLU(ReLU(x)) = ReLU(x).
    let input = vec![-5.0, -1.0, 0.0, 1.0, 5.0];
    let first = relu_vec_reference(&input);

    assert_eq!(relu_vec_reference(&first), first);
}
|
||||
|
||||
#[test]
fn test_relu_special_values() {
    // +inf stays infinite; -inf clamps to zero.
    assert!(relu_reference(f32::INFINITY).is_infinite());
    assert_eq!(relu_reference(f32::NEG_INFINITY), 0.0);

    // NaN handling may vary between implementations: NaN or 0.0 are
    // both acceptable outcomes.
    let y = relu_reference(f32::NAN);
    assert!(y.is_nan() || y == 0.0);
}
|
||||
|
||||
// ============================================================================
|
||||
// Softmax Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Numerically stable reference softmax.
///
/// Subtracts the max logit before exponentiating so very large inputs
/// cannot overflow `exp`.
fn softmax_reference(logits: &[f32]) -> Vec<f32> {
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / total).collect()
}
|
||||
|
||||
#[test]
fn test_softmax_sum_to_one() {
    let probs = softmax_reference(&[1.0, 2.0, 3.0, 4.0, 5.0]);

    // A softmax output is a probability distribution: it sums to one.
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-6,
        "Softmax should sum to 1.0, got {}",
        total
    );
}
|
||||
|
||||
#[test]
fn test_softmax_all_positive() {
    let probs = softmax_reference(&[-10.0, -5.0, 0.0, 5.0, 10.0]);

    // exp(..) is strictly positive, so every probability must be > 0.
    assert!(
        probs.iter().all(|&p| p > 0.0),
        "All softmax outputs should be positive"
    );
}
|
||||
|
||||
#[test]
fn test_softmax_ordering() {
    let probs = softmax_reference(&[1.0, 2.0, 3.0, 4.0, 5.0]);

    // Softmax is order-preserving: a larger logit yields a larger prob.
    for pair in probs.windows(2) {
        assert!(pair[0] < pair[1], "Higher logit should have higher prob");
    }
}
|
||||
|
||||
#[test]
fn test_softmax_numerical_stability() {
    // Logits around 1000 would overflow exp() without max-subtraction.
    let probs = softmax_reference(&[1000.0, 1001.0, 1002.0]);

    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Softmax should be stable with large inputs"
    );
    assert!(
        probs.iter().all(|p| p.is_finite()),
        "All probs should be finite"
    );
}
|
||||
|
||||
#[test]
fn test_softmax_uniform() {
    // Identical logits must produce a uniform distribution.
    let probs = softmax_reference(&[5.0; 4]);

    for &p in &probs {
        assert!(
            (p - 0.25).abs() < 1e-6,
            "Equal logits should give uniform probs"
        );
    }
}
|
||||
|
||||
#[test]
fn test_softmax_temperature_effect() {
    let logits = [1.0f32, 2.0, 3.0];

    // Softmax of logits scaled by 1/temperature.
    let with_temp = |t: f32| -> Vec<f32> {
        let scaled: Vec<f32> = logits.iter().map(|&x| x / t).collect();
        softmax_reference(&scaled)
    };

    let base = with_temp(1.0); // neutral
    let sharp = with_temp(0.5); // sharper distribution
    let flat = with_temp(2.0); // flatter distribution

    // Lower temperature concentrates probability on the argmax...
    assert!(sharp[2] > base[2], "Lower temp should increase max prob");

    // ...while higher temperature flattens the distribution.
    assert!(flat[0] > base[0], "Higher temp should increase min prob");
}
|
||||
|
||||
// ============================================================================
|
||||
// Leaky ReLU Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Scalar reference Leaky ReLU: identity for positive x, `alpha * x`
/// for non-positive x.
fn leaky_relu_reference(x: f32, alpha: f32) -> f32 {
    if x > 0.0 {
        x
    } else {
        alpha * x
    }
}
|
||||
|
||||
fn leaky_relu_vec_reference(input: &[f32], alpha: f32) -> Vec<f32> {
|
||||
input
|
||||
.iter()
|
||||
.map(|&x| leaky_relu_reference(x, alpha))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
fn test_leaky_relu_basic() {
    let alpha = 0.01;

    // Positive and zero inputs are unchanged.
    assert_eq!(leaky_relu_reference(5.0, alpha), 5.0);
    assert_eq!(leaky_relu_reference(0.0, alpha), 0.0);

    // Negative inputs are scaled by alpha; compare with tolerance.
    assert!((leaky_relu_reference(-5.0, alpha) - (-0.05)).abs() < 1e-6);
}
|
||||
|
||||
#[test]
fn test_leaky_relu_reduces_to_relu() {
    // With alpha = 0 the negative slope vanishes, recovering plain ReLU.
    let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];

    assert_eq!(
        leaky_relu_vec_reference(&input, 0.0),
        relu_vec_reference(&input),
        "Leaky ReLU with alpha=0 should equal ReLU"
    );
}
|
||||
|
||||
#[test]
fn test_leaky_relu_continuity() {
    let alpha = 0.1;
    let eps = 1e-6;

    // Sample just left of, just right of, and exactly at the kink x = 0.
    let left = leaky_relu_reference(-eps, alpha);
    let right = leaky_relu_reference(eps, alpha);
    let center = leaky_relu_reference(0.0, alpha);

    assert!(
        (left - center).abs() < 1e-4,
        "Should be continuous from left"
    );
    assert!(
        (right - center).abs() < 1e-4,
        "Should be continuous from right"
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Comparison Tests (NEON vs Scalar)
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_activation_performance_comparison() {
    const ITERS: usize = 100;

    // Input sweeping roughly [-5, 5).
    let size = 10000;
    let input: Vec<f32> = (0..size).map(|i| (i as f32 - 5000.0) / 1000.0).collect();

    // Warm up caches and branch predictors before timing.
    let _ = relu_vec_reference(&input);
    let _ = silu_vec_reference(&input);
    let _ = gelu_vec_reference(&input);

    // Time ITERS repetitions of a closure.
    let time_it = |f: &dyn Fn()| {
        let start = Instant::now();
        for _ in 0..ITERS {
            f();
        }
        start.elapsed()
    };

    let relu_time = time_it(&|| {
        let _ = relu_vec_reference(&input);
    });
    let silu_time = time_it(&|| {
        let _ = silu_vec_reference(&input);
    });
    let gelu_time = time_it(&|| {
        let _ = gelu_vec_reference(&input);
    });

    // Softmax is timed on a smaller slice.
    let softmax_input = input[..1000].to_vec();
    let softmax_time = time_it(&|| {
        let _ = softmax_reference(&softmax_input);
    });

    // Loose sanity bounds only — exact numbers are for manual inspection.
    assert!(relu_time.as_millis() < 1000, "ReLU should complete quickly");
    assert!(
        silu_time.as_millis() < 2000,
        "SiLU should complete in reasonable time"
    );
    assert!(
        gelu_time.as_millis() < 2000,
        "GELU should complete in reasonable time"
    );
    assert!(
        softmax_time.as_millis() < 1000,
        "Softmax should complete quickly"
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// NEON vs Scalar Correctness Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_neon_softmax_vs_scalar() {
    // The reference softmax must yield a valid probability distribution.
    let logits: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let probs = softmax_reference(&logits);

    // Sums to one.
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Softmax sum should be 1.0, got {}",
        total
    );

    // Every entry lies strictly inside (0, 1).
    assert!(probs.iter().all(|&p| p > 0.0 && p < 1.0));

    // Strictly increasing logits give strictly increasing probs.
    for pair in probs.windows(2) {
        assert!(pair[0] < pair[1]);
    }
}
|
||||
|
||||
#[test]
fn test_neon_softmax_large_array() {
    // 256 evenly spaced logits spanning roughly [-12.8, 12.7].
    let logits: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 10.0).collect();
    let probs = softmax_reference(&logits);

    // Normalization.
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Scalar softmax sum should be 1.0, got {}",
        total
    );

    // Every value is a finite probability in [0, 1].
    assert!(probs
        .iter()
        .all(|&p| (0.0..=1.0).contains(&p) && p.is_finite()));

    // Monotone logits yield monotone probabilities.
    for pair in probs.windows(2) {
        assert!(pair[0] <= pair[1], "Ordering should be preserved");
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Case Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_activation_empty_input() {
    // Every vectorized activation must tolerate an empty slice.
    let empty: [f32; 0] = [];

    assert!(relu_vec_reference(&empty).is_empty());
    assert!(silu_vec_reference(&empty).is_empty());
    assert!(gelu_vec_reference(&empty).is_empty());
}
|
||||
|
||||
#[test]
fn test_activation_single_element() {
    let single = [2.5f32];

    assert_eq!(relu_vec_reference(&single), vec![2.5]);
    assert_eq!(silu_vec_reference(&single).len(), 1);
    assert_eq!(gelu_vec_reference(&single).len(), 1);

    // Softmax of a single logit is trivially [1.0].
    let probs = softmax_reference(&single);
    assert_eq!(probs.len(), 1);
    assert!((probs[0] - 1.0).abs() < 1e-6);
}
|
||||
|
||||
#[test]
fn test_activation_all_negative() {
    let input = [-5.0f32, -4.0, -3.0, -2.0, -1.0];

    // ReLU zeroes every negative input.
    assert!(relu_vec_reference(&input).iter().all(|&y| y == 0.0));

    // SiLU stays strictly negative for negative inputs.
    assert!(silu_vec_reference(&input).iter().all(|&y| y < 0.0));

    // Softmax still normalizes to one.
    let total: f32 = softmax_reference(&input).iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
}
|
||||
|
||||
#[test]
fn test_activation_all_zeros() {
    let zeros = vec![0.0f32; 4];

    // ReLU(0) = 0.
    assert_eq!(relu_vec_reference(&zeros), zeros);

    // SiLU(0) = 0 and GELU(0) = 0 within float tolerance.
    assert!(silu_vec_reference(&zeros).iter().all(|&y| y.abs() < 1e-6));
    assert!(gelu_vec_reference(&zeros).iter().all(|&y| y.abs() < 1e-6));

    // Equal logits produce a uniform softmax.
    assert!(softmax_reference(&zeros)
        .iter()
        .all(|&p| (p - 0.25).abs() < 1e-6));
}
|
||||
|
||||
// ============================================================================
|
||||
// Gradient-like Tests (Derivative Approximation)
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_relu_derivative() {
    let eps = 1e-5;

    // Central-difference estimate of d/dx ReLU at a point.
    let numeric_deriv =
        |x: f32| (relu_reference(x + eps) - relu_reference(x - eps)) / (2.0 * eps);

    // Slope 1 on the positive side, 0 on the negative side.
    assert!((numeric_deriv(2.0) - 1.0).abs() < 0.01);
    assert!(numeric_deriv(-2.0).abs() < 0.01);
}
|
||||
|
||||
#[test]
fn test_silu_derivative_at_zero() {
    let eps = 1e-5;

    // Central difference at x = 0; analytically SiLU'(0) = 0.5.
    let deriv = (silu_reference(eps) - silu_reference(-eps)) / (2.0 * eps);

    assert!(
        (deriv - 0.5).abs() < 0.01,
        "SiLU derivative at 0 should be 0.5"
    );
}
|
||||
|
||||
#[test]
fn test_gelu_derivative_positive() {
    let eps = 1e-5;
    let x = 1.0;

    // Central-difference estimate of GELU'(1).
    let deriv = (gelu_reference(x + eps) - gelu_reference(x - eps)) / (2.0 * eps);

    // For positive x the slope should be near (but not exactly) 1.
    assert!(
        deriv > 0.5 && deriv < 1.5,
        "GELU derivative at x=1 should be near 1"
    );
}
|
||||
889
vendor/ruvector/crates/ruvllm/src/tests/attention_tests.rs
vendored
Normal file
889
vendor/ruvector/crates/ruvllm/src/tests/attention_tests.rs
vendored
Normal file
@@ -0,0 +1,889 @@
|
||||
//! Attention Tests
|
||||
//!
|
||||
//! Tests for Flash Attention, Paged Attention, MQA/GQA implementations,
|
||||
//! output correctness, memory allocation, pre-allocated buffer reuse, and benchmarks.
|
||||
|
||||
use crate::kernels::{
|
||||
flash_attention_auto, flash_attention_neon, flash_attention_v2, grouped_query_attention_neon,
|
||||
multi_query_attention_neon, paged_attention_neon, select_block_size, AttentionConfig,
|
||||
PagedKvCache, BLOCK_SIZE_LARGE, BLOCK_SIZE_MEDIUM, BLOCK_SIZE_SMALL,
|
||||
};
|
||||
use std::time::Instant;
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Reference scalar single-query attention used to validate the NEON
/// kernels.
///
/// `query` holds `head_dim` elements; `key` and `value` hold `kv_len`
/// contiguous rows of `head_dim` elements each. Returns the attended
/// output vector of length `head_dim`.
fn attention_reference(
    query: &[f32],
    key: &[f32],
    value: &[f32],
    head_dim: usize,
    scale: f32,
) -> Vec<f32> {
    let kv_len = key.len() / head_dim;

    // Scaled dot-product score of the query against each key row.
    let scores: Vec<f32> = (0..kv_len)
        .map(|t| {
            let row = &key[t * head_dim..(t + 1) * head_dim];
            query.iter().zip(row).map(|(q, k)| q * k * scale).sum()
        })
        .collect();

    // Numerically stable softmax over the scores.
    let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
    let denom: f32 = exps.iter().sum();

    // Probability-weighted sum of the value rows.
    let mut output = vec![0.0; head_dim];
    for (t, e) in exps.iter().enumerate() {
        let w = e / denom;
        let row = &value[t * head_dim..(t + 1) * head_dim];
        for (o, v) in output.iter_mut().zip(row) {
            *o += w * v;
        }
    }

    output
}
|
||||
|
||||
/// Deterministic pseudo-random (query, key, value) test tensors with
/// entries in [-1, 1], driven by a 64-bit LCG seeded with `seed`.
///
/// Returns `query` of length `head_dim` and `key`/`value` of length
/// `kv_len * head_dim`.
fn generate_test_data(head_dim: usize, kv_len: usize, seed: u64) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let mut rng_state = seed;
    let next_float = |state: &mut u64| -> f32 {
        // Knuth-style 64-bit LCG step.
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        // Use the top 32 bits (>> 32, not >> 33) so the numerator spans
        // the full u32 range; dividing by u32::MAX then maps to [0, 1]
        // and the affine transform to [-1, 1]. A 33-bit shift would
        // leave only 31 bits and confine outputs to [-1, 0].
        ((*state >> 32) as f32) / (u32::MAX as f32) * 2.0 - 1.0
    };

    let query: Vec<f32> = (0..head_dim).map(|_| next_float(&mut rng_state)).collect();
    let key: Vec<f32> = (0..kv_len * head_dim)
        .map(|_| next_float(&mut rng_state))
        .collect();
    let value: Vec<f32> = (0..kv_len * head_dim)
        .map(|_| next_float(&mut rng_state))
        .collect();

    (query, key, value)
}
|
||||
|
||||
/// True when `a` and `b` have equal length and agree element-wise
/// within an absolute difference of `tolerance`.
fn vectors_approx_equal(a: &[f32], b: &[f32], tolerance: f32) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| (x - y).abs() < tolerance)
}
|
||||
|
||||
// ============================================================================
|
||||
// Flash Attention Basic Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_flash_attention_basic() {
    let head_dim = 16;
    let kv_len = 4;

    // Simple ramp patterns for Q, K, V.
    let query: Vec<f32> = (0..head_dim).map(|i| i as f32 * 0.1).collect();
    let key: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32 * 0.01).collect();
    let value: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32 * 0.02).collect();

    let scale = 1.0 / (head_dim as f32).sqrt();
    let output = flash_attention_neon(&query, &key, &value, scale, false);

    assert_eq!(
        output.len(),
        head_dim,
        "Output should have head_dim elements"
    );
    assert!(
        output.iter().all(|&x| x.is_finite()),
        "All outputs should be finite"
    );
}
|
||||
|
||||
#[test]
fn test_flash_attention_vs_reference() {
    let head_dim = 32;
    let kv_len = 16;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (q, k, v) = generate_test_data(head_dim, kv_len, 12345);

    // The NEON kernel must agree with the scalar reference.
    let fast = flash_attention_neon(&q, &k, &v, scale, false);
    let slow = attention_reference(&q, &k, &v, head_dim, scale);

    assert!(
        vectors_approx_equal(&fast, &slow, 1e-3),
        "NEON and reference outputs should match"
    );
}
|
||||
|
||||
#[test]
fn test_flash_attention_empty_kv() {
    let head_dim = 16;
    let query: Vec<f32> = (0..head_dim).map(|i| i as f32).collect();
    let key: Vec<f32> = Vec::new();
    let value: Vec<f32> = Vec::new();

    let scale = 1.0 / (head_dim as f32).sqrt();
    let output = flash_attention_neon(&query, &key, &value, scale, false);

    // Empty KV must not panic; an empty or zero-filled output vector
    // are both acceptable behaviors.
    assert!(output.is_empty() || output.len() == head_dim);
}
|
||||
|
||||
#[test]
fn test_flash_attention_single_token() {
    // With exactly one KV token the softmax weight is 1.0, so the
    // attention output must equal the single value row.
    let head_dim = 64;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (query, key, value) = generate_test_data(head_dim, 1, 42);

    let output = flash_attention_neon(&query, &key, &value, scale, false);

    assert!(
        vectors_approx_equal(&output, &value, 1e-5),
        "Single token attention should return value directly"
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Flash Attention V2 Block Size Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_flash_attention_v2_small_block() {
    let head_dim = 64;
    let kv_len = 100;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (q, k, v) = generate_test_data(head_dim, kv_len, 111);

    // The tiling block size must not change the numeric result.
    let small = flash_attention_v2(&q, &k, &v, scale, false, BLOCK_SIZE_SMALL);
    let medium = flash_attention_v2(&q, &k, &v, scale, false, BLOCK_SIZE_MEDIUM);

    assert!(
        vectors_approx_equal(&small, &medium, 1e-3),
        "Block sizes should not affect correctness"
    );
}
|
||||
|
||||
#[test]
fn test_flash_attention_v2_all_block_sizes() {
    let head_dim = 128;
    let kv_len = 256;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (q, k, v) = generate_test_data(head_dim, kv_len, 222);

    // Run the kernel at every supported tile size.
    let outputs: Vec<Vec<f32>> = [BLOCK_SIZE_SMALL, BLOCK_SIZE_MEDIUM, BLOCK_SIZE_LARGE]
        .iter()
        .map(|&bs| flash_attention_v2(&q, &k, &v, scale, false, bs))
        .collect();

    // Adjacent configurations must agree within tolerance.
    assert!(vectors_approx_equal(&outputs[0], &outputs[1], 1e-3));
    assert!(vectors_approx_equal(&outputs[1], &outputs[2], 1e-3));
}
|
||||
|
||||
#[test]
fn test_flash_attention_auto_block_selection() {
    let head_dim = 128;
    let scale = 1.0 / (head_dim as f32).sqrt();

    // Smoke test: automatic block selection must handle both a short
    // and a long sequence without panicking.
    for &(kv_len, seed) in &[(32usize, 333u64), (1024, 444)] {
        let (q, k, v) = generate_test_data(head_dim, kv_len, seed);
        let _ = flash_attention_auto(&q, &k, &v, scale, false);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Block Size Selection Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_select_block_size_short_sequence() {
    // Very short sequences should pick the smallest tile.
    let head_dim = 128;
    for seq_len in [32, 64] {
        assert_eq!(select_block_size(seq_len, head_dim), BLOCK_SIZE_SMALL);
    }
}
|
||||
|
||||
#[test]
fn test_select_block_size_medium_sequence() {
    // Mid-length sequences should pick the medium tile.
    let head_dim = 128;
    for seq_len in [128, 256, 512] {
        assert_eq!(select_block_size(seq_len, head_dim), BLOCK_SIZE_MEDIUM);
    }
}
|
||||
|
||||
#[test]
fn test_select_block_size_long_sequence() {
    // A small head_dim leaves cache room for bigger tiles.
    let head_dim = 64;

    let block = select_block_size(2048, head_dim);
    assert!(
        block >= BLOCK_SIZE_MEDIUM,
        "Long sequences should use at least medium blocks"
    );
}
|
||||
|
||||
#[test]
fn test_select_block_size_large_head_dim() {
    // A large head_dim must cap the tile size so blocks still fit in L1.
    let block = select_block_size(2048, 256);
    assert!(block <= BLOCK_SIZE_LARGE);
}
|
||||
|
||||
// ============================================================================
|
||||
// Paged KV Cache Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_paged_kv_cache_creation() {
    // A fresh cache records its geometry and holds no tokens yet.
    let cache = PagedKvCache::new(16, 4, 64);

    assert_eq!(cache.block_size, 16);
    assert_eq!(cache.num_kv_heads, 4);
    assert_eq!(cache.head_dim, 64);
    assert_eq!(cache.num_tokens, 0);
    assert!(cache.key_blocks.is_empty());
    assert!(cache.value_blocks.is_empty());
}
|
||||
|
||||
#[test]
fn test_paged_kv_cache_append() {
    let mut cache = PagedKvCache::new(4, 2, 8);

    // One token = num_kv_heads * head_dim = 16 values per tensor.
    let token_len = 2 * 8;
    cache.append(&vec![1.0; token_len], &vec![2.0; token_len]);

    assert_eq!(cache.num_tokens, 1);
    assert_eq!(cache.key_blocks.len(), 1);
    assert_eq!(cache.value_blocks.len(), 1);
}
|
||||
|
||||
#[test]
fn test_paged_kv_cache_append_multiple() {
    let mut cache = PagedKvCache::new(4, 2, 8);
    let stride = 2 * 8; // values per token

    // Five tokens overflow a 4-token block into a second one.
    for i in 1..=5 {
        let keys = vec![i as f32; stride];
        let values = vec![i as f32 * 2.0; stride];
        cache.append(&keys, &values);
    }

    assert_eq!(cache.num_tokens, 5);
    assert_eq!(cache.key_blocks.len(), 2); // ceil(5 / 4) blocks
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_kv_cache_get_keys() {
|
||||
let mut cache = PagedKvCache::new(4, 1, 8);
|
||||
|
||||
// Append 2 tokens
|
||||
let keys1 = vec![1.0; 8];
|
||||
let values1 = vec![10.0; 8];
|
||||
cache.append(&keys1, &values1);
|
||||
|
||||
let keys2 = vec![2.0; 8];
|
||||
let values2 = vec![20.0; 8];
|
||||
cache.append(&keys2, &values2);
|
||||
|
||||
let retrieved_keys = cache.get_keys();
|
||||
assert_eq!(retrieved_keys.len(), 16); // 2 tokens * 1 head * 8 dim
|
||||
assert!(retrieved_keys[..8].iter().all(|&x| x == 1.0));
|
||||
assert!(retrieved_keys[8..].iter().all(|&x| x == 2.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paged_kv_cache_get_values() {
|
||||
let mut cache = PagedKvCache::new(4, 1, 8);
|
||||
|
||||
let keys = vec![1.0; 8];
|
||||
let values = vec![5.0; 8];
|
||||
cache.append(&keys, &values);
|
||||
|
||||
let retrieved_values = cache.get_values();
|
||||
assert_eq!(retrieved_values.len(), 8);
|
||||
assert!(retrieved_values.iter().all(|&x| x == 5.0));
|
||||
}
|
||||
|
||||
// ============================================================================
// Paged Attention Tests
// ============================================================================

#[test]
fn test_paged_attention_empty_cache() {
    // Attention over an empty cache must not panic and must yield zeros.
    let cache = PagedKvCache::new(16, 1, 16);
    let query = vec![0.5; 16];
    let scale = 0.25;

    let output = paged_attention_neon(&query, &cache, &[], scale);

    assert_eq!(output.len(), 16);
    assert!(output.iter().all(|&x| x == 0.0));
}

#[test]
fn test_paged_attention_with_cache() {
    let mut cache = PagedKvCache::new(16, 1, 16);

    // Populate the cache with 8 identical tokens.
    for _ in 0..8 {
        let keys: Vec<f32> = (0..16).map(|i| (i as f32) * 0.1).collect();
        let values: Vec<f32> = (0..16).map(|i| (i as f32) * 0.2).collect();
        cache.append(&keys, &values);
    }

    let query: Vec<f32> = (0..16).map(|i| (i as f32) * 0.05).collect();
    let scale = 1.0 / 4.0;

    let output = paged_attention_neon(&query, &cache, &[], scale);

    // Output has head_dim elements and stays numerically sane.
    assert_eq!(output.len(), 16);
    assert!(output.iter().all(|&x| x.is_finite()));
}
// ============================================================================
// Multi-Query Attention (MQA) Tests
// ============================================================================

#[test]
fn test_mqa_basic() {
    // MQA: every query head attends against a single shared KV head.
    let config = AttentionConfig {
        num_heads: 8,
        num_kv_heads: 1,
        head_dim: 16,
        causal: false,
        ..Default::default()
    };

    let kv_len = 4;
    let queries: Vec<f32> = (0..config.num_heads * config.head_dim)
        .map(|i| (i as f32) * 0.01)
        .collect();
    let keys: Vec<f32> = (0..kv_len * config.head_dim)
        .map(|i| (i as f32) * 0.01)
        .collect();
    let values: Vec<f32> = (0..kv_len * config.head_dim)
        .map(|i| (i as f32) * 0.02)
        .collect();

    let output = multi_query_attention_neon(&queries, &keys, &values, &config);

    assert_eq!(output.len(), config.num_heads * config.head_dim);
    assert!(output.iter().all(|&x| x.is_finite()));
}

#[test]
fn test_mqa_shared_kv() {
    // With identical queries, every head must produce identical output,
    // since all heads read the same shared K/V.
    let config = AttentionConfig {
        num_heads: 4,
        num_kv_heads: 1,
        head_dim: 8,
        causal: false,
        ..Default::default()
    };

    // Repeat a single query head's values across all heads.
    let query_head: Vec<f32> = vec![1.0; config.head_dim];
    let queries: Vec<f32> = query_head
        .iter()
        .cloned()
        .cycle()
        .take(config.num_heads * config.head_dim)
        .collect();

    let kv_len = 2;
    let keys: Vec<f32> = (0..kv_len * config.head_dim)
        .map(|i| (i as f32) * 0.1)
        .collect();
    let values: Vec<f32> = (0..kv_len * config.head_dim).map(|_| 1.0).collect();

    let output = multi_query_attention_neon(&queries, &keys, &values, &config);

    // Compare every head's slice against head 0.
    let per_head: Vec<&[f32]> = output.chunks(config.head_dim).collect();
    for &head in &per_head[1..] {
        assert!(
            vectors_approx_equal(per_head[0], head, 1e-5),
            "All heads should produce same output with identical queries"
        );
    }
}
// ============================================================================
// Grouped-Query Attention (GQA) Tests
// ============================================================================

#[test]
fn test_gqa_basic() {
    // GQA with a 4:1 query-to-KV head ratio.
    let config = AttentionConfig {
        num_heads: 8,
        num_kv_heads: 2,
        head_dim: 16,
        causal: false,
        ..Default::default()
    };

    assert_eq!(config.gqa_ratio(), 4);

    let kv_len = 4;
    let queries: Vec<f32> = (0..config.num_heads * config.head_dim)
        .map(|i| (i as f32) * 0.01)
        .collect();
    let keys: Vec<f32> = (0..kv_len * config.num_kv_heads * config.head_dim)
        .map(|i| (i as f32) * 0.01)
        .collect();
    let values: Vec<f32> = (0..kv_len * config.num_kv_heads * config.head_dim)
        .map(|i| (i as f32) * 0.01)
        .collect();

    let output = grouped_query_attention_neon(&queries, &keys, &values, &config);

    assert_eq!(output.len(), config.num_heads * config.head_dim);
    assert!(output.iter().all(|&x| x.is_finite()));
}

#[test]
fn test_gqa_head_grouping() {
    // 2:1 grouping: query heads 0,1 share KV head 0; heads 2,3 share KV head 1.
    let config = AttentionConfig {
        num_heads: 4,
        num_kv_heads: 2,
        head_dim: 8,
        causal: false,
        ..Default::default()
    };

    assert_eq!(config.gqa_ratio(), 2);

    // Build K/V where KV head 0 is all 1.0 and KV head 1 is all 2.0, so the
    // attention output reveals which KV head each query head consumed.
    let kv_len = 2;
    let mut keys = vec![0.0; kv_len * config.num_kv_heads * config.head_dim];
    let mut values = vec![0.0; kv_len * config.num_kv_heads * config.head_dim];

    for t in 0..kv_len {
        let base = t * config.num_kv_heads * config.head_dim;
        for i in 0..config.head_dim {
            // KV head 0 slot for this token.
            keys[base + i] = 1.0;
            values[base + i] = 1.0;
            // KV head 1 slot for this token.
            keys[base + config.head_dim + i] = 2.0;
            values[base + config.head_dim + i] = 2.0;
        }
    }

    // Identical queries across all heads.
    let queries: Vec<f32> = vec![0.5; config.num_heads * config.head_dim];

    let output = grouped_query_attention_neon(&queries, &keys, &values, &config);

    // Mean of each head's output: ~1.0 for the first group, ~2.0 for the second.
    let head_means: Vec<f32> = output
        .chunks(config.head_dim)
        .map(|h| h.iter().sum::<f32>() / config.head_dim as f32)
        .collect();

    assert!(
        (head_means[0] - 1.0).abs() < 0.1,
        "Head 0 should use KV head 0"
    );
    assert!(
        (head_means[1] - 1.0).abs() < 0.1,
        "Head 1 should use KV head 0"
    );
    assert!(
        (head_means[2] - 2.0).abs() < 0.1,
        "Head 2 should use KV head 1"
    );
    assert!(
        (head_means[3] - 2.0).abs() < 0.1,
        "Head 3 should use KV head 1"
    );
}
// ============================================================================
// AttentionConfig Tests
// ============================================================================

#[test]
fn test_attention_config_default() {
    // Spot-check every field of the default configuration.
    let config = AttentionConfig::default();

    assert_eq!(config.num_heads, 32);
    assert_eq!(config.num_kv_heads, 8);
    assert_eq!(config.head_dim, 128);
    assert!(config.causal);
    assert_eq!(config.gqa_ratio(), 4);
}

#[test]
fn test_attention_config_effective_scale() {
    // scale == 0.0 means "auto-compute": expect 1/sqrt(head_dim).
    let auto = AttentionConfig {
        head_dim: 64,
        scale: 0.0,
        ..Default::default()
    };
    let expected = 1.0 / (64.0f32).sqrt();
    assert!((auto.effective_scale() - expected).abs() < 1e-6);

    // An explicit non-zero scale is used verbatim.
    let explicit = AttentionConfig {
        head_dim: 64,
        scale: 0.2,
        ..Default::default()
    };
    assert!((explicit.effective_scale() - 0.2).abs() < 1e-6);
}

#[test]
fn test_attention_config_gqa_ratios() {
    // (num_heads, num_kv_heads, expected ratio): MHA 1:1, GQA 4:1 and 8:1,
    // and MQA where all heads share a single KV head.
    let cases = [(32, 32, 1), (32, 8, 4), (32, 4, 8), (32, 1, 32)];

    for &(num_heads, num_kv_heads, expected) in &cases {
        let config = AttentionConfig {
            num_heads,
            num_kv_heads,
            ..Default::default()
        };
        assert_eq!(config.gqa_ratio(), expected);
    }
}
// ============================================================================
// Memory Allocation Tests
// ============================================================================

#[test]
fn test_attention_no_extra_allocation() {
    // Repeated invocations over the same inputs must be bit-for-bit
    // deterministic within tolerance.
    let head_dim = 128;
    let kv_len = 256;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (query, key, value) = generate_test_data(head_dim, kv_len, 555);

    let first = flash_attention_neon(&query, &key, &value, scale, false);
    let second = flash_attention_neon(&query, &key, &value, scale, false);
    let third = flash_attention_neon(&query, &key, &value, scale, false);

    assert!(vectors_approx_equal(&first, &second, 1e-6));
    assert!(vectors_approx_equal(&second, &third, 1e-6));
}

#[test]
fn test_attention_output_size_correct() {
    let head_dim = 64;
    let kv_len = 100;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (query, key, value) = generate_test_data(head_dim, kv_len, 666);

    let output = flash_attention_neon(&query, &key, &value, scale, false);

    assert_eq!(
        output.len(),
        head_dim,
        "Output should exactly match head_dim"
    );
}
// ============================================================================
// Performance Benchmark Tests
// ============================================================================

#[test]
fn test_attention_benchmark_short_sequence() {
    let head_dim = 128;
    let kv_len = 64;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (query, key, value) = generate_test_data(head_dim, kv_len, 777);

    // Warm up before timing.
    for _ in 0..10 {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }

    let iterations = 1000;
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }
    let elapsed = start.elapsed();

    // NOTE(review): wall-clock thresholds can be flaky on loaded CI machines.
    let avg_us = elapsed.as_micros() as f64 / iterations as f64;
    assert!(
        avg_us < 1000.0,
        "Short sequence attention should be fast: {}us",
        avg_us
    );
}

#[test]
fn test_attention_benchmark_long_sequence() {
    let head_dim = 128;
    let kv_len = 2048;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (query, key, value) = generate_test_data(head_dim, kv_len, 888);

    // Warm up before timing.
    for _ in 0..5 {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }

    let iterations = 100;
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }
    let elapsed = start.elapsed();

    let avg_ms = elapsed.as_millis() as f64 / iterations as f64;
    assert!(
        avg_ms < 50.0,
        "Long sequence attention should complete in <50ms: {}ms",
        avg_ms
    );
}

#[test]
fn test_attention_benchmark_block_sizes() {
    let head_dim = 128;
    let kv_len = 512;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let iterations = 100;

    let (query, key, value) = generate_test_data(head_dim, kv_len, 999);

    // Times `iterations` runs at one block size.
    let mut time_block = |block_size| {
        let start = Instant::now();
        for _ in 0..iterations {
            let _ = flash_attention_v2(&query, &key, &value, scale, false, block_size);
        }
        start.elapsed()
    };

    let small_time = time_block(BLOCK_SIZE_SMALL);
    let medium_time = time_block(BLOCK_SIZE_MEDIUM);
    let large_time = time_block(BLOCK_SIZE_LARGE);

    // All block-size configurations should complete in reasonable time.
    assert!(small_time.as_millis() < 5000);
    assert!(medium_time.as_millis() < 5000);
    assert!(large_time.as_millis() < 5000);
}
// ============================================================================
// Numerical Stability Tests
// ============================================================================

#[test]
fn test_attention_large_logits() {
    let head_dim = 32;
    let kv_len = 8;

    // These Q/K values produce very large dot products; the attention output
    // must still be finite (no Inf/NaN from the softmax).
    let query = vec![10.0; head_dim];
    let key = vec![10.0; kv_len * head_dim];
    let value: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32).collect();

    let scale = 1.0 / (head_dim as f32).sqrt();
    let output = flash_attention_neon(&query, &key, &value, scale, false);

    assert!(
        output.iter().all(|&x| x.is_finite()),
        "Should handle large dot products"
    );
}

#[test]
fn test_attention_small_values() {
    let head_dim = 32;
    let kv_len = 8;

    // Tiny magnitudes must not degenerate into NaN/Inf.
    let query = vec![1e-6; head_dim];
    let key = vec![1e-6; kv_len * head_dim];
    let value: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32).collect();

    let scale = 1.0 / (head_dim as f32).sqrt();
    let output = flash_attention_neon(&query, &key, &value, scale, false);

    assert!(
        output.iter().all(|&x| x.is_finite()),
        "Should handle small values"
    );
}

#[test]
fn test_attention_mixed_signs() {
    let head_dim = 32;
    let kv_len = 8;

    // Alternating signs in Q and K exercise cancellation in the dot products.
    let query: Vec<f32> = (0..head_dim)
        .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
        .collect();
    let key: Vec<f32> = (0..kv_len * head_dim)
        .map(|i| if i % 3 == 0 { -0.5 } else { 0.5 })
        .collect();
    let value: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.01).collect();

    let scale = 1.0 / (head_dim as f32).sqrt();
    let output = flash_attention_neon(&query, &key, &value, scale, false);

    assert!(output.iter().all(|&x| x.is_finite()));
}
// ============================================================================
|
||||
// Edge Cases
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_attention_single_head_dim() {
|
||||
let head_dim = 1;
|
||||
let kv_len = 4;
|
||||
|
||||
let query = vec![1.0];
|
||||
let key = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let value = vec![10.0, 20.0, 30.0, 40.0];
|
||||
|
||||
let scale = 1.0;
|
||||
let output = flash_attention_neon(&query, &key, &value, scale, false);
|
||||
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(output[0].is_finite());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_large_head_dim() {
|
||||
let head_dim = 512;
|
||||
let kv_len = 16;
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
let (query, key, value) = generate_test_data(head_dim, kv_len, 1111);
|
||||
|
||||
let output = flash_attention_neon(&query, &key, &value, scale, false);
|
||||
|
||||
assert_eq!(output.len(), head_dim);
|
||||
assert!(output.iter().all(|&x| x.is_finite()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_power_of_two_dims() {
|
||||
// Test common power-of-2 dimensions
|
||||
for head_dim in [32, 64, 128, 256] {
|
||||
let kv_len = 64;
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
let (query, key, value) = generate_test_data(head_dim, kv_len, head_dim as u64);
|
||||
|
||||
let output = flash_attention_neon(&query, &key, &value, scale, false);
|
||||
|
||||
assert_eq!(output.len(), head_dim);
|
||||
assert!(
|
||||
output.iter().all(|&x| x.is_finite()),
|
||||
"Failed for head_dim={}",
|
||||
head_dim
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attention_non_power_of_two_dims() {
|
||||
// Test non-power-of-2 dimensions
|
||||
for head_dim in [17, 33, 65, 100, 127] {
|
||||
let kv_len = 32;
|
||||
let scale = 1.0 / (head_dim as f32).sqrt();
|
||||
|
||||
let (query, key, value) = generate_test_data(head_dim, kv_len, head_dim as u64);
|
||||
|
||||
let output = flash_attention_neon(&query, &key, &value, scale, false);
|
||||
|
||||
assert_eq!(output.len(), head_dim);
|
||||
assert!(
|
||||
output.iter().all(|&x| x.is_finite()),
|
||||
"Failed for head_dim={}",
|
||||
head_dim
|
||||
);
|
||||
}
|
||||
}
|
||||
785
vendor/ruvector/crates/ruvllm/src/tests/generation_tests.rs
vendored
Normal file
785
vendor/ruvector/crates/ruvllm/src/tests/generation_tests.rs
vendored
Normal file
@@ -0,0 +1,785 @@
|
||||
//! Token Generation Tests
|
||||
//!
|
||||
//! Tests for autoregressive token generation, sampling strategies,
|
||||
//! streaming callbacks, KV cache integration, and speculative decoding.
|
||||
|
||||
use crate::speculative::{
|
||||
log_softmax, sample_from_probs, softmax, top_k_filter, top_p_filter, AtomicSpeculativeStats,
|
||||
SpeculationTree, SpeculativeConfig, SpeculativeStats, TreeNode, VerificationResult,
|
||||
};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
use std::time::Duration;
|
||||
|
||||
// ============================================================================
// Softmax and Sampling Tests
// ============================================================================

#[test]
fn test_softmax_produces_valid_distribution() {
    let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let probs = softmax(&logits);

    // A valid distribution: sums to 1 and every entry is positive.
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-5,
        "Softmax sum should be 1.0, got {}",
        total
    );
    assert!(
        probs.iter().all(|&p| p > 0.0),
        "All probabilities should be positive"
    );

    // Softmax is monotone: the ordering of the logits is preserved.
    for pair in probs.windows(2) {
        assert!(
            pair[0] < pair[1],
            "Higher logits should have higher probs"
        );
    }
}

#[test]
fn test_softmax_handles_large_logits() {
    // Numerical stability: logits near 1000 must not overflow.
    let logits = vec![1000.0, 1001.0, 1002.0];
    let probs = softmax(&logits);

    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Should handle large logits: sum = {}",
        total
    );
    assert!(
        probs.iter().all(|p| p.is_finite()),
        "All probs should be finite"
    );
}

#[test]
fn test_softmax_handles_negative_logits() {
    let logits = vec![-5.0, -3.0, -1.0, 0.0, 1.0];
    let probs = softmax(&logits);

    let total: f32 = probs.iter().sum();
    assert!((total - 1.0).abs() < 1e-5, "Should handle negative logits");
    assert!(probs[4] > probs[0], "Larger logit should have higher prob");
}

#[test]
fn test_softmax_empty_input() {
    // Degenerate case: no logits in, no probabilities out.
    let empty: Vec<f32> = Vec::new();
    let probs = softmax(&empty);
    assert!(probs.is_empty(), "Empty input should return empty output");
}

#[test]
fn test_softmax_single_element() {
    let probs = softmax(&[5.0]);
    assert_eq!(probs.len(), 1);
    assert!(
        (probs[0] - 1.0).abs() < 1e-5,
        "Single element should have prob 1.0"
    );
}

#[test]
fn test_log_softmax_relationship() {
    // log_softmax(x) must agree with ln(softmax(x)).
    let logits = vec![1.0, 2.0, 3.0, 4.0];
    let probs = softmax(&logits);
    let log_probs = log_softmax(&logits);

    for (lp, p) in log_probs.iter().zip(probs.iter()) {
        assert!(
            (lp - p.ln()).abs() < 1e-4,
            "log_softmax should match log(softmax)"
        );
    }
}

#[test]
fn test_log_softmax_numerical_stability() {
    // A direct softmax would underflow here; log_softmax must stay finite.
    let log_probs = log_softmax(&[-1000.0, -999.0, -998.0]);

    assert!(
        log_probs.iter().all(|p| p.is_finite()),
        "log_softmax should handle extreme values"
    );
    // Relative ordering must survive the extreme values.
    assert!(log_probs[0] < log_probs[1] && log_probs[1] < log_probs[2]);
}
// ============================================================================
// Top-K Filtering Tests
// ============================================================================

#[test]
fn test_top_k_filter_basic() {
    let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
    top_k_filter(&mut logits, 2);

    // Only the two largest entries (5.0 at index 1, 4.0 at index 3) survive.
    let survivors = logits.iter().filter(|x| x.is_finite()).count();
    assert_eq!(survivors, 2, "Only top-k elements should remain");
    assert!(logits[1].is_finite(), "5.0 should remain");
    assert!(logits[3].is_finite(), "4.0 should remain");
}

#[test]
fn test_top_k_filter_k_greater_than_length() {
    let mut logits = vec![1.0, 2.0, 3.0];
    top_k_filter(&mut logits, 10);

    // k larger than the vector is a no-op.
    let survivors = logits.iter().filter(|x| x.is_finite()).count();
    assert_eq!(survivors, 3, "All should remain when k > length");
}

#[test]
fn test_top_k_filter_k_zero() {
    let mut logits = vec![1.0, 2.0, 3.0];
    top_k_filter(&mut logits, 0);

    // k == 0 leaves everything intact.
    let survivors = logits.iter().filter(|x| x.is_finite()).count();
    assert_eq!(survivors, 3, "All should remain when k = 0");
}

#[test]
fn test_top_k_filter_k_one() {
    let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
    top_k_filter(&mut logits, 1);

    // Greedy case: only the single maximum survives.
    let survivors = logits.iter().filter(|x| x.is_finite()).count();
    assert_eq!(survivors, 1, "Only one element should remain");
    assert!(logits[1].is_finite(), "Maximum (5.0) should remain");
}
// ============================================================================
// Top-P (Nucleus) Filtering Tests
// ============================================================================

#[test]
fn test_top_p_filter_basic() {
    // The first logit dominates the distribution after softmax.
    let mut logits = vec![10.0, 1.0, 0.0, -1.0, -2.0];
    top_p_filter(&mut logits, 0.9);

    assert!(logits[0].is_finite(), "Highest prob token should remain");
}

#[test]
fn test_top_p_filter_p_one() {
    let mut logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let before = logits.clone();
    top_p_filter(&mut logits, 1.0);

    // p >= 1.0 keeps the entire distribution untouched.
    assert_eq!(logits, before, "All should remain when p = 1.0");
}

#[test]
fn test_top_p_filter_p_zero() {
    let mut logits = vec![1.0, 2.0, 3.0];
    top_p_filter(&mut logits, 0.0);

    // Even at p == 0 the top token must survive so sampling stays possible.
    let survivors = logits.iter().filter(|x| x.is_finite()).count();
    assert!(survivors >= 1, "At least one token should remain");
}
// ============================================================================
// Sampling Tests
// ============================================================================

#[test]
fn test_sample_from_probs_deterministic() {
    // All probability mass on index 2: sampling can never pick anything else.
    let probs = vec![0.0, 0.0, 1.0, 0.0];
    let mut rng = StdRng::seed_from_u64(12345);

    for _ in 0..10 {
        let idx = sample_from_probs(&probs, &mut rng);
        assert_eq!(idx, 2, "Should always sample index 2");
    }
}

#[test]
fn test_sample_from_probs_uniform() {
    let probs = vec![0.25, 0.25, 0.25, 0.25];
    let mut rng = StdRng::seed_from_u64(42);
    let mut counts = [0usize; 4];

    for _ in 0..10000 {
        counts[sample_from_probs(&probs, &mut rng)] += 1;
    }

    // With 10k draws each bucket expects ~2500 hits; allow 20% slack.
    for (i, &count) in counts.iter().enumerate() {
        let expected = 2500.0;
        let ratio = count as f64 / expected;
        assert!(
            (0.8..=1.2).contains(&ratio),
            "Index {} should be sampled uniformly, got {} (expected ~{})",
            i,
            count,
            expected
        );
    }
}

#[test]
fn test_sample_from_probs_skewed() {
    // 90% of the probability mass sits on index 0.
    let probs = vec![0.9, 0.05, 0.03, 0.02];
    let mut rng = StdRng::seed_from_u64(42);
    let mut counts = [0usize; 4];

    for _ in 0..1000 {
        counts[sample_from_probs(&probs, &mut rng)] += 1;
    }

    assert!(counts[0] > 800, "Index 0 should be sampled most often");
}
// ============================================================================
|
||||
// Temperature Scaling Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_temperature_scaling_sharpens() {
|
||||
let logits = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let temperature = 0.1; // Low temperature -> sharper distribution
|
||||
|
||||
let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
|
||||
let probs = softmax(&scaled);
|
||||
|
||||
// Highest logit should have much higher probability
|
||||
assert!(
|
||||
probs[3] > 0.99,
|
||||
"Low temperature should concentrate probability on max"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temperature_scaling_flattens() {
|
||||
let logits = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let temperature = 10.0; // High temperature -> flatter distribution
|
||||
|
||||
let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
|
||||
let probs = softmax(&scaled);
|
||||
|
||||
// Distribution should be more uniform
|
||||
let min_prob = probs.iter().cloned().fold(f32::INFINITY, f32::min);
|
||||
let max_prob = probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
assert!(
|
||||
max_prob - min_prob < 0.2,
|
||||
"High temperature should flatten distribution"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_temperature_one_unchanged() {
|
||||
let logits = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let temperature = 1.0;
|
||||
|
||||
let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
|
||||
let probs1 = softmax(&logits);
|
||||
let probs2 = softmax(&scaled);
|
||||
|
||||
for (p1, p2) in probs1.iter().zip(probs2.iter()) {
|
||||
assert!(
|
||||
(p1 - p2).abs() < 1e-6,
|
||||
"Temperature 1.0 should not change distribution"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
// Speculative Decoding Config Tests
// ============================================================================

#[test]
fn test_speculative_config_default() {
    // Spot-check every default field value.
    let config = SpeculativeConfig::default();

    assert_eq!(config.lookahead, 4);
    assert!((config.acceptance_threshold - 0.5).abs() < 0.01);
    assert_eq!(config.draft_temperature, 0.0);
    assert!(!config.tree_speculation);
    assert!(config.adaptive_lookahead);
    assert_eq!(config.min_lookahead, 2);
    assert_eq!(config.max_lookahead, 8);
}

#[test]
fn test_speculative_config_custom() {
    // Struct-update syntax: explicit fields win, the rest use defaults.
    let config = SpeculativeConfig {
        lookahead: 8,
        acceptance_threshold: 0.7,
        draft_temperature: 0.3,
        tree_speculation: true,
        max_tree_depth: 4,
        tree_branching_factor: 3,
        ..Default::default()
    };

    assert_eq!(config.lookahead, 8);
    assert!((config.acceptance_threshold - 0.7).abs() < 0.01);
    assert!(config.tree_speculation);
    assert_eq!(config.max_tree_depth, 4);
    assert_eq!(config.tree_branching_factor, 3);
}
// ============================================================================
|
||||
// Speculative Stats Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_speculative_stats_new() {
|
||||
let stats = SpeculativeStats::new();
|
||||
|
||||
assert_eq!(stats.draft_tokens, 0);
|
||||
assert_eq!(stats.accepted_tokens, 0);
|
||||
assert_eq!(stats.acceptance_rate, 0.0);
|
||||
assert_eq!(stats.speedup, 0.0);
|
||||
assert_eq!(stats.main_forward_passes, 0);
|
||||
}
|
||||
|
||||
#[test]
fn test_speculative_stats_record_round() {
    let mut tracker = SpeculativeStats::new();

    // One round: 4 draft tokens proposed, 3 accepted, 10 ms spent.
    tracker.record_round(4, 3, 10.0);

    assert_eq!(tracker.draft_tokens, 4);
    assert_eq!(tracker.accepted_tokens, 3);
    assert_eq!(tracker.main_forward_passes, 1);
    // 3 accepted + 1 correction token from the main model.
    assert_eq!(tracker.total_tokens_generated, 4);
    assert!((tracker.acceptance_rate - 0.75).abs() < 0.01);
    assert!((tracker.total_speculation_time_ms - 10.0).abs() < 0.01);
}
|
||||
|
||||
#[test]
fn test_speculative_stats_multiple_rounds() {
    let mut tracker = SpeculativeStats::new();

    // Round 1: perfect acceptance (4/4); round 2: half accepted (2/4).
    tracker.record_round(4, 4, 10.0);
    tracker.record_round(4, 2, 15.0);

    assert_eq!(tracker.main_forward_passes, 2);
    assert_eq!(tracker.draft_tokens, 8);
    assert_eq!(tracker.accepted_tokens, 6);
    // Cumulative acceptance rate: 6 accepted out of 8 drafted.
    assert!((tracker.acceptance_rate - 0.75).abs() < 0.01);
    // Exact token accounting is implementation-defined; check a lower bound.
    assert!(
        tracker.total_tokens_generated >= 6,
        "Should generate at least accepted tokens"
    );
}
|
||||
|
||||
#[test]
fn test_speculative_stats_reset() {
    // Recording activity then resetting must return the counters to zero.
    let mut tracker = SpeculativeStats::new();
    tracker.record_round(4, 3, 10.0);

    tracker.reset();

    assert_eq!(tracker.accepted_tokens, 0);
    assert_eq!(tracker.draft_tokens, 0);
    assert_eq!(tracker.acceptance_rate, 0.0);
}
|
||||
|
||||
#[test]
fn test_speculative_stats_speedup_calculation() {
    let mut tracker = SpeculativeStats::new();

    // Two fully-accepted rounds: 8 accepted + 2 corrections = 10 tokens
    // from only 2 main-model passes, i.e. 5 tokens per pass.
    for _ in 0..2 {
        tracker.record_round(4, 4, 10.0);
    }

    assert!(
        tracker.speedup > 4.0,
        "Speedup should reflect tokens per main pass"
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Atomic Speculative Stats Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_atomic_stats_new() {
    // A new atomic tracker snapshots to all-zero counters.
    let snapshot = AtomicSpeculativeStats::new().snapshot();

    assert_eq!(snapshot.accepted_tokens, 0);
    assert_eq!(snapshot.draft_tokens, 0);
}
|
||||
|
||||
#[test]
fn test_atomic_stats_record_round() {
    let tracker = AtomicSpeculativeStats::new();

    // One round: 4 drafted, 3 accepted, 10 ms elapsed.
    tracker.record_round(4, 3, Duration::from_millis(10));
    let view = tracker.snapshot();

    assert_eq!(view.draft_tokens, 4);
    assert_eq!(view.accepted_tokens, 3);
    assert!((view.acceptance_rate - 0.75).abs() < 0.01);
}
|
||||
|
||||
#[test]
fn test_atomic_stats_thread_safe() {
    use std::sync::Arc;
    use std::thread;

    let shared = Arc::new(AtomicSpeculativeStats::new());

    // Hammer the tracker from many threads; no updates may be lost.
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let stats = Arc::clone(&shared);
            thread::spawn(move || {
                for _ in 0..100 {
                    stats.record_round(4, 3, Duration::from_millis(1));
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let view = shared.snapshot();
    // 10 threads * 100 rounds * 4 drafts (resp. 3 accepted) each.
    assert_eq!(view.draft_tokens, 4000);
    assert_eq!(view.accepted_tokens, 3000);
}
|
||||
|
||||
#[test]
fn test_atomic_stats_reset() {
    // reset() must zero the counters even after activity was recorded.
    let tracker = AtomicSpeculativeStats::new();
    tracker.record_round(4, 3, Duration::from_millis(10));

    tracker.reset();

    assert_eq!(tracker.snapshot().draft_tokens, 0);
}
|
||||
|
||||
// ============================================================================
|
||||
// Tree Node Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_tree_node_new() {
    // A new node stores the token, its probability (with a cached log-prob),
    // the depth it was created at, and has no children yet.
    let node = TreeNode::new(42, 0.8, 0);

    assert_eq!(node.depth, 0);
    assert_eq!(node.token, 42);
    assert!(node.children.is_empty());
    assert!((node.prob - 0.8).abs() < 0.01);
    assert!((node.logprob - 0.8f32.ln()).abs() < 0.01);
}
|
||||
|
||||
#[test]
fn test_tree_node_add_child() {
    let mut parent = TreeNode::new(0, 1.0, 0);

    // Children are created one level deeper than their parent.
    let first = parent.add_child(1, 0.6);
    assert_eq!(first.depth, 1);
    assert_eq!(first.token, 1);

    let second = parent.add_child(2, 0.4);
    assert_eq!(second.token, 2);

    assert_eq!(parent.children.len(), 2);
}
|
||||
|
||||
#[test]
fn test_tree_node_get_paths() {
    // Tree under test:
    //       0
    //      / \
    //     1   2
    //    /
    //   3
    let mut root = TreeNode::new(0, 1.0, 0);
    root.add_child(1, 0.6).add_child(3, 0.5);
    root.add_child(2, 0.4);

    let paths = root.get_paths();

    // Exactly one root-to-leaf path per leaf: [0, 1, 3] and [0, 2].
    assert_eq!(paths.len(), 2);
    assert!(paths.iter().any(|p| p == &vec![0, 1, 3]));
    assert!(paths.iter().any(|p| p == &vec![0, 2]));
}
|
||||
|
||||
#[test]
fn test_tree_node_best_path() {
    let mut root = TreeNode::new(0, 1.0, 0);

    // The higher-probability branch (1 -> 3) should win over branch 2.
    root.add_child(1, 0.6).add_child(3, 0.5);
    root.add_child(2, 0.4);

    assert_eq!(root.best_path(), vec![0, 1, 3]);
}
|
||||
|
||||
// ============================================================================
|
||||
// Speculation Tree Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_speculation_tree_new() {
    // A fresh tree holds only its root node and remembers its limits.
    let tree = SpeculationTree::new(3, 2);

    assert_eq!(tree.node_count, 1);
    assert_eq!(tree.max_depth, 3);
    assert_eq!(tree.branching_factor, 2);
}
|
||||
|
||||
#[test]
fn test_speculation_tree_clear() {
    let mut tree = SpeculationTree::new(3, 2);

    // Grow the tree by one node, then wipe it.
    tree.root.add_child(1, 0.5);
    tree.node_count += 1;
    tree.clear();

    // clear() keeps the root but drops everything else.
    assert!(tree.root.children.is_empty());
    assert_eq!(tree.node_count, 1);
}
|
||||
|
||||
#[test]
fn test_speculation_tree_best_path_empty() {
    // With no children under the root there is nothing to speculate on.
    let tree = SpeculationTree::new(3, 2);

    assert!(
        tree.best_path().is_empty(),
        "Empty tree should have empty best path"
    );
}
|
||||
|
||||
#[test]
fn test_speculation_tree_best_path_linear() {
    let mut tree = SpeculationTree::new(4, 2);

    // Chain root -> 1 -> 2 -> 3 and account for the three new nodes.
    tree.root
        .add_child(1, 0.8)
        .add_child(2, 0.7)
        .add_child(3, 0.6);
    tree.node_count += 3;

    assert_eq!(tree.best_path(), vec![1, 2, 3]);
}
|
||||
|
||||
// ============================================================================
|
||||
// Verification Result Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_verification_result_all_accepted() {
    // Every drafted token matched the main model's output.
    let outcome = VerificationResult {
        accepted_count: 4,
        next_token: 100,
        accepted_logprobs: vec![-0.1, -0.2, -0.1, -0.15],
        next_logprob: -0.3,
        all_accepted: true,
    };

    assert!(outcome.all_accepted);
    assert_eq!(outcome.next_token, 100);
    assert_eq!(outcome.accepted_count, 4);
}
|
||||
|
||||
#[test]
fn test_verification_result_partial_accept() {
    // Only a prefix of the draft survived; next_token is the correction.
    let outcome = VerificationResult {
        accepted_count: 2,
        next_token: 50,
        accepted_logprobs: vec![-0.1, -0.2],
        next_logprob: -0.5,
        all_accepted: false,
    };

    assert!(!outcome.all_accepted);
    assert_eq!(outcome.accepted_count, 2);
}
|
||||
|
||||
#[test]
fn test_verification_result_none_accepted() {
    // Worst case: the very first draft token was rejected outright.
    let outcome = VerificationResult {
        accepted_count: 0,
        next_token: 25,
        accepted_logprobs: vec![],
        next_logprob: -0.4,
        all_accepted: false,
    };

    assert!(!outcome.all_accepted);
    assert_eq!(outcome.accepted_count, 0);
    assert!(outcome.accepted_logprobs.is_empty());
}
|
||||
|
||||
// ============================================================================
|
||||
// Integration Sampling Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_full_sampling_pipeline() {
    // End-to-end: logits -> softmax -> repeated sampling -> frequency check.
    let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let probs = softmax(&logits);

    // The softmax output must form a valid probability distribution.
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Softmax should sum to 1.0, got {}",
        total
    );
    assert!(
        probs.iter().all(|&p| p > 0.0),
        "All probabilities should be positive"
    );

    // Draw 1000 samples with a deterministic RNG and tally hits per index.
    let mut rng = StdRng::seed_from_u64(42);
    let mut hits = vec![0usize; 5];
    for _ in 0..1000 {
        let choice = sample_from_probs(&probs, &mut rng);
        if let Some(slot) = hits.get_mut(choice) {
            *slot += 1;
        }
    }

    let drawn: usize = hits.iter().sum();
    assert_eq!(drawn, 1000, "Should have 1000 total samples");

    // Statistical sanity: the largest logit (index 4) must be drawn more
    // often than the smallest (index 0) over 1000 samples.
    assert!(
        hits[4] > hits[0],
        "Higher logit should be sampled more: idx4={}, idx0={}",
        hits[4],
        hits[0]
    );
}
|
||||
|
||||
#[test]
fn test_greedy_decoding_simulation() {
    // Greedy decoding is simply argmax over the logits (temperature -> 0).
    let logits = [1.0f32, 3.0, 2.0, 5.0, 4.0];

    let mut best = 0usize;
    for (idx, &value) in logits.iter().enumerate() {
        if value > logits[best] {
            best = idx;
        }
    }

    assert_eq!(best, 3, "Greedy should select index 3 (value 5.0)");
}
|
||||
|
||||
#[test]
fn test_beam_search_simulation() {
    // One beam-search step: keep the `beam_width` highest-scoring candidates.
    let beam_width = 3;
    let logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];

    // Pair each logit with its index, then rank descending by score.
    let mut ranked: Vec<(usize, f32)> = logits.into_iter().enumerate().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap());

    let survivors: Vec<usize> = ranked.into_iter().take(beam_width).map(|(i, _)| i).collect();

    assert_eq!(survivors, vec![1, 3, 2], "Top-3 should be indices 1, 3, 2");
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Cases and Error Handling
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_softmax_with_inf() {
    // A -inf logit must receive (numerically) zero probability without
    // destabilizing the rest of the distribution.
    let logits = vec![f32::NEG_INFINITY, 1.0, 2.0];
    let probs = softmax(&logits);

    // exp(-inf) == 0, so the first slot should carry essentially no mass.
    // The previous condition `probs[0] < 1e-10 || probs[0].abs() < 1e-10`
    // was redundant (the clauses overlap) and its first clause would have
    // wrongly accepted a negative "probability"; the abs check alone is
    // both simpler and stricter.
    assert!(
        probs[0].abs() < 1e-10,
        "NEG_INFINITY should give ~0 probability"
    );

    // The remaining mass must still normalize to 1.
    let sum: f32 = probs.iter().sum();
    assert!((sum - 1.0).abs() < 1e-4, "Sum should be 1.0");
}
|
||||
|
||||
#[test]
fn test_sample_numerical_edge_case() {
    // An extremely skewed distribution must not panic or yield bad indices.
    let probs = vec![0.9999999, 0.0000001];
    let mut rng = StdRng::seed_from_u64(42);

    for _ in 0..100 {
        let picked = sample_from_probs(&probs, &mut rng);
        assert!(picked < 2, "Index should be valid");
    }
}
|
||||
|
||||
#[test]
fn test_top_k_with_ties() {
    // Three values tie for the maximum; top-3 filtering must keep all of them.
    let mut logits = vec![5.0, 5.0, 5.0, 1.0, 2.0];
    top_k_filter(&mut logits, 3);

    let survivors = logits.iter().filter(|v| v.is_finite()).count();
    assert!(
        survivors >= 3,
        "Should keep at least k elements when ties exist"
    );
}
|
||||
736
vendor/ruvector/crates/ruvllm/src/tests/gguf_tests.rs
vendored
Normal file
736
vendor/ruvector/crates/ruvllm/src/tests/gguf_tests.rs
vendored
Normal file
@@ -0,0 +1,736 @@
|
||||
//! GGUF Loading Tests
|
||||
//!
|
||||
//! Tests for GGUF header/metadata parsing, tensor loading, quantization
|
||||
//! format handling, architecture detection, memory mapping, and error handling.
|
||||
|
||||
use crate::gguf::parser::GgufValueType;
|
||||
use crate::gguf::{
|
||||
parse_header, parse_metadata, GgufHeader, GgufQuantType, GgufValue, GGUF_MAGIC, GGUF_VERSION,
|
||||
};
|
||||
use std::io::Cursor;
|
||||
|
||||
// ============================================================================
|
||||
// Header Parsing Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_parse_valid_header() {
    // Serialize a minimal well-formed header and read it back.
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
    bytes.extend_from_slice(&GGUF_VERSION.to_le_bytes());
    bytes.extend_from_slice(&10u64.to_le_bytes()); // tensor_count
    bytes.extend_from_slice(&5u64.to_le_bytes()); // metadata_kv_count

    let mut cursor = Cursor::new(bytes);
    let header = parse_header(&mut cursor).unwrap();

    assert_eq!(header.magic, GGUF_MAGIC);
    assert_eq!(header.version, GGUF_VERSION);
    assert_eq!(header.tensor_count, 10);
    assert_eq!(header.metadata_kv_count, 5);
}
|
||||
|
||||
#[test]
fn test_gguf_magic_is_correct() {
    // The magic constant is the ASCII string "GGUF" read as a little-endian u32.
    assert_eq!(GGUF_MAGIC, 0x46554747u32);
    assert_eq!(&GGUF_MAGIC.to_le_bytes(), b"GGUF");
}
|
||||
|
||||
#[test]
fn test_parse_header_truncated() {
    // Provide only the 4 magic bytes; the rest of the header is missing.
    let mut cursor = Cursor::new(vec![0x47, 0x47, 0x55, 0x46]);

    assert!(
        parse_header(&mut cursor).is_err(),
        "Truncated header should fail"
    );
}
|
||||
|
||||
#[test]
fn test_parse_header_empty() {
    // Zero bytes of input must be rejected cleanly.
    let mut cursor = Cursor::new(Vec::<u8>::new());

    assert!(parse_header(&mut cursor).is_err(), "Empty input should fail");
}
|
||||
|
||||
// ============================================================================
|
||||
// GgufValue Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_gguf_value_string() {
    // A String value converts only via as_str(); every other accessor is None.
    let value = GgufValue::String("test_value".to_string());

    assert_eq!(value.as_str(), Some("test_value"));
    assert!(value.as_array().is_none());
    assert_eq!(value.as_bool(), None);
    assert_eq!(value.as_u64(), None);
    assert_eq!(value.as_i64(), None);
    assert_eq!(value.as_f32(), None);
}
|
||||
|
||||
#[test]
fn test_gguf_value_integer_conversions() {
    // U32 widens losslessly into every numeric accessor.
    let v = GgufValue::U32(42);
    assert_eq!(v.as_u64(), Some(42));
    assert_eq!(v.as_i64(), Some(42));
    assert_eq!(v.as_f32(), Some(42.0));
    assert_eq!(v.as_str(), None);

    // A negative I32 has no unsigned representation.
    let v = GgufValue::I32(-5);
    assert_eq!(v.as_i64(), Some(-5));
    assert_eq!(v.as_u64(), None);

    // u64::MAX overflows i64.
    let v = GgufValue::U64(u64::MAX);
    assert_eq!(v.as_u64(), Some(u64::MAX));
    assert_eq!(v.as_i64(), None);

    // A negative I64 likewise cannot become u64...
    let v = GgufValue::I64(-100);
    assert_eq!(v.as_i64(), Some(-100));
    assert_eq!(v.as_u64(), None);

    // ...but a non-negative I64 converts fine.
    let v = GgufValue::I64(100);
    assert_eq!(v.as_i64(), Some(100));
    assert_eq!(v.as_u64(), Some(100));
}
|
||||
|
||||
#[test]
fn test_gguf_value_float_conversions() {
    // F32 is readable through both float accessors.
    let v = GgufValue::F32(3.14);
    assert!((v.as_f32().unwrap() - 3.14).abs() < 0.001);
    assert!((v.as_f64().unwrap() - 3.14).abs() < 0.001);
    assert_eq!(v.as_str(), None);

    // F64 narrows to f32 with only a small loss of precision.
    let v = GgufValue::F64(2.71828);
    assert!((v.as_f64().unwrap() - 2.71828).abs() < 0.00001);
    assert!((v.as_f32().unwrap() - 2.71828).abs() < 0.001);
}
|
||||
|
||||
#[test]
fn test_gguf_value_bool() {
    // Bool values expose themselves only through as_bool().
    assert_eq!(GgufValue::Bool(true).as_bool(), Some(true));
    assert_eq!(GgufValue::Bool(false).as_bool(), Some(false));
    assert_eq!(GgufValue::Bool(true).as_str(), None);

    // U8 doubles as a boolean: 1 maps to true, 0 to false.
    assert_eq!(GgufValue::U8(1).as_bool(), Some(true));
    assert_eq!(GgufValue::U8(0).as_bool(), Some(false));
}
|
||||
|
||||
#[test]
fn test_gguf_value_array() {
    // Arrays preserve element order, and each element keeps its accessors.
    let value = GgufValue::Array(vec![
        GgufValue::U32(1),
        GgufValue::U32(2),
        GgufValue::U32(3),
    ]);

    let items = value.as_array().unwrap();
    assert_eq!(items.len(), 3);
    for (pos, item) in items.iter().enumerate() {
        assert_eq!(item.as_u64(), Some(pos as u64 + 1));
    }
}
|
||||
|
||||
#[test]
fn test_gguf_value_small_integers() {
    // U8 max fits in u64.
    assert_eq!(GgufValue::U8(255).as_u64(), Some(255));

    // I8 min is representable as i64 but not as u64.
    assert_eq!(GgufValue::I8(-128).as_i64(), Some(-128));
    assert_eq!(GgufValue::I8(-128).as_u64(), None);

    // U16 max fits in u64.
    assert_eq!(GgufValue::U16(65535).as_u64(), Some(65535));

    // I16 min is representable as i64.
    assert_eq!(GgufValue::I16(-32768).as_i64(), Some(-32768));
}
|
||||
|
||||
// ============================================================================
|
||||
// GgufValueType Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_value_type_conversion() {
    // Each wire-format discriminant (0..=12) maps to exactly one variant.
    let cases = [
        (0, GgufValueType::U8),
        (1, GgufValueType::I8),
        (2, GgufValueType::U16),
        (3, GgufValueType::I16),
        (4, GgufValueType::U32),
        (5, GgufValueType::I32),
        (6, GgufValueType::F32),
        (7, GgufValueType::Bool),
        (8, GgufValueType::String),
        (9, GgufValueType::Array),
        (10, GgufValueType::U64),
        (11, GgufValueType::I64),
        (12, GgufValueType::F64),
    ];

    for (code, variant) in cases {
        assert_eq!(GgufValueType::try_from(code).unwrap(), variant);
    }
}
|
||||
|
||||
#[test]
fn test_value_type_invalid() {
    // Discriminants past F64 (12) are not valid value types.
    for bad in [13, 100, 255] {
        assert!(GgufValueType::try_from(bad).is_err());
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Quantization Type Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_quant_type_from_u32() {
    // Representative wire codes: F32 (0), F16 (1), Q4_0 (2), Q4_1 (3), Q8_0 (8).
    for code in [0u32, 1, 2, 3, 8] {
        assert!(GgufQuantType::try_from(code).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_quant_type_block_size() {
    // Unquantized scalar types: one element per "block".
    assert_eq!(GgufQuantType::F32.block_size(), 1);
    assert_eq!(GgufQuantType::F16.block_size(), 1);

    // Classic quantized formats pack 32 weights per block.
    for qt in [GgufQuantType::Q4_0, GgufQuantType::Q4_1, GgufQuantType::Q8_0] {
        assert_eq!(qt.block_size(), 32);
    }

    // K-quants use 256-weight super-blocks.
    for qt in [
        GgufQuantType::Q2_K,
        GgufQuantType::Q3_K,
        GgufQuantType::Q4_K,
        GgufQuantType::Q5_K,
        GgufQuantType::Q6_K,
    ] {
        assert_eq!(qt.block_size(), 256);
    }
}
|
||||
|
||||
#[test]
fn test_quant_type_type_size() {
    // Bytes occupied by one block of each format.
    // F32: a single 4-byte element.
    assert_eq!(GgufQuantType::F32.type_size(), 4);
    // F16: a single 2-byte element.
    assert_eq!(GgufQuantType::F16.type_size(), 2);
    // Q4_0: scale (2 bytes) + 32 four-bit weights (16 bytes) = 18.
    assert_eq!(GgufQuantType::Q4_0.type_size(), 18);
    // Q4_1: scale (2) + min (2) + 32 four-bit weights (16) = 20.
    assert_eq!(GgufQuantType::Q4_1.type_size(), 20);
    // Q8_0: scale (2) + 32 one-byte weights (32) = 34.
    assert_eq!(GgufQuantType::Q8_0.type_size(), 34);
}
|
||||
|
||||
#[test]
fn test_quant_type_is_quantized() {
    // Full/half precision floats are not quantized formats.
    assert!(!GgufQuantType::F32.is_quantized());
    assert!(!GgufQuantType::F16.is_quantized());

    // Everything in the Q* family is.
    for qt in [
        GgufQuantType::Q4_0,
        GgufQuantType::Q4_1,
        GgufQuantType::Q8_0,
        GgufQuantType::Q4_K,
        GgufQuantType::Q2_K,
    ] {
        assert!(qt.is_quantized());
    }
}
|
||||
|
||||
#[test]
fn test_quant_type_bits_per_weight() {
    // bits_per_weight() reports the effective storage cost per weight (f32).
    assert!((GgufQuantType::F32.bits_per_weight() - 32.0).abs() < 0.1);
    assert!((GgufQuantType::F16.bits_per_weight() - 16.0).abs() < 0.1);

    // Q8_0: 34 bytes * 8 bits / 32 weights = 8.5 bits.
    assert!((GgufQuantType::Q8_0.bits_per_weight() - 8.5).abs() < 0.1);

    // Q4_0: 18 bytes * 8 bits / 32 weights = 4.5 bits. Exercise the API
    // directly — the original only re-derived this value from
    // type_size()/block_size(), which would pass even if bits_per_weight()
    // itself were wrong for Q4_0.
    assert!((GgufQuantType::Q4_0.bits_per_weight() - 4.5).abs() < 0.1);

    // The definition must also agree with type_size()/block_size().
    let derived =
        (GgufQuantType::Q4_0.type_size() * 8) as f32 / GgufQuantType::Q4_0.block_size() as f32;
    assert!((derived - 4.5).abs() < 0.1);
}
|
||||
|
||||
// ============================================================================
|
||||
// Architecture Detection Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_architecture_metadata_key() {
    // Well-known metadata keys follow a "<scope>.<field>" naming scheme.
    let arch_keys = [
        "general.architecture",
        "llama.context_length",
        "llama.embedding_length",
        "llama.attention.head_count",
        "llama.attention.head_count_kv",
        "llama.block_count",
        "llama.rope.freq_base",
        "mistral.context_length",
        "phi.context_length",
    ];

    // Sanity-check the shape of each key, not its semantics.
    for key in arch_keys {
        assert!(!key.is_empty());
        assert!(key.contains('.') || key.starts_with("general"));
    }
}
|
||||
|
||||
#[test]
fn test_architecture_detection_patterns() {
    // Model-name variants should all normalize onto a base architecture family.
    let arch_patterns = [
        ("llama", "llama"),
        ("mistral", "mistral"),
        ("phi", "phi"),
        ("phi2", "phi"),
        ("phi3", "phi"),
        ("qwen", "qwen"),
        ("qwen2", "qwen"),
        ("gemma", "gemma"),
    ];

    for (input, expected_prefix) in &arch_patterns {
        let normalized = input.to_lowercase();
        // `contains` subsumes the previous `starts_with(..) || contains(..)`
        // check: any string that starts with a pattern also contains it.
        assert!(
            normalized.contains(expected_prefix),
            "{} should match {} pattern",
            input,
            expected_prefix
        );
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Metadata Parsing Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Serialize a single GGUF metadata key/value entry:
/// `key_len (u64 LE) | key bytes | value_type (u32 LE) | raw value bytes`.
fn build_metadata_entry(key: &str, value_type: u32, value_bytes: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(8 + key.len() + 4 + value_bytes.len());
    out.extend_from_slice(&(key.len() as u64).to_le_bytes());
    out.extend_from_slice(key.as_bytes());
    out.extend_from_slice(&value_type.to_le_bytes());
    out.extend_from_slice(value_bytes);
    out
}
|
||||
|
||||
#[test]
fn test_parse_metadata_u32() {
    // Round-trip a single u32 entry (type code 4) through the parser.
    let key = "test.value";
    let payload = build_metadata_entry(key, 4, &12345u32.to_le_bytes());

    let mut cursor = Cursor::new(payload);
    let metadata = parse_metadata(&mut cursor, 1).unwrap();

    assert!(metadata.contains_key(key));
    assert_eq!(metadata.get(key).unwrap().as_u64(), Some(12345));
}
|
||||
|
||||
#[test]
fn test_parse_metadata_f32() {
    // Round-trip a single f32 entry (type code 6) through the parser.
    let key = "test.float";
    let payload = build_metadata_entry(key, 6, &3.14159f32.to_le_bytes());

    let mut cursor = Cursor::new(payload);
    let metadata = parse_metadata(&mut cursor, 1).unwrap();

    let recovered = metadata.get(key).unwrap().as_f32().unwrap();
    assert!((recovered - 3.14159).abs() < 0.0001);
}
|
||||
|
||||
#[test]
fn test_parse_metadata_string() {
    // Strings are encoded as a u64 length prefix followed by raw bytes.
    let key = "test.name";
    let text = "hello_world";

    let mut encoded = Vec::new();
    encoded.extend_from_slice(&(text.len() as u64).to_le_bytes());
    encoded.extend_from_slice(text.as_bytes());

    let payload = build_metadata_entry(key, 8, &encoded);
    let mut cursor = Cursor::new(payload);
    let metadata = parse_metadata(&mut cursor, 1).unwrap();

    assert_eq!(metadata.get(key).unwrap().as_str(), Some("hello_world"));
}
|
||||
|
||||
#[test]
fn test_parse_metadata_bool() {
    // Booleans (type code 7) are encoded as a single byte; 1 reads as true.
    let key = "test.enabled";
    let payload = build_metadata_entry(key, 7, &[1u8]);

    let mut cursor = Cursor::new(payload);
    let metadata = parse_metadata(&mut cursor, 1).unwrap();

    assert_eq!(metadata.get(key).unwrap().as_bool(), Some(true));
}
|
||||
|
||||
#[test]
fn test_parse_metadata_multiple_entries() {
    // Two back-to-back entries of different types must both be parsed.
    let mut payload = build_metadata_entry("key1", 4, &42u32.to_le_bytes());
    payload.extend(build_metadata_entry("key2", 6, &1.5f32.to_le_bytes()));

    let mut cursor = Cursor::new(payload);
    let metadata = parse_metadata(&mut cursor, 2).unwrap();

    assert_eq!(metadata.len(), 2);
    assert_eq!(metadata.get("key1").unwrap().as_u64(), Some(42));
    assert!((metadata.get("key2").unwrap().as_f32().unwrap() - 1.5).abs() < 0.001);
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handling Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_parse_metadata_truncated_key() {
    // The declared key length (100) exceeds the bytes actually present (4).
    let mut payload = Vec::new();
    payload.extend_from_slice(&100u64.to_le_bytes());
    payload.extend_from_slice(b"test");

    let mut cursor = Cursor::new(payload);
    let outcome = parse_metadata(&mut cursor, 1);

    assert!(outcome.is_err(), "Truncated key should fail");
}
|
||||
|
||||
#[test]
fn test_parse_metadata_invalid_value_type() {
    // 255 is not a defined GGUF value-type discriminant.
    let mut payload = Vec::new();
    payload.extend_from_slice(&4u64.to_le_bytes());
    payload.extend_from_slice(b"test");
    payload.extend_from_slice(&255u32.to_le_bytes());

    let mut cursor = Cursor::new(payload);
    let outcome = parse_metadata(&mut cursor, 1);

    assert!(outcome.is_err(), "Invalid value type should fail");
}
|
||||
|
||||
#[test]
fn test_string_too_long_protection() {
    // A hostile file can claim an enormous string length without supplying
    // the bytes; the parser must reject it rather than allocate blindly.
    let key = "malicious.string";
    let claimed_len = 10_000_000u64; // 10MB, far beyond the data provided

    let mut payload = Vec::new();
    payload.extend_from_slice(&(key.len() as u64).to_le_bytes());
    payload.extend_from_slice(key.as_bytes());
    payload.extend_from_slice(&8u32.to_le_bytes()); // String type code
    payload.extend_from_slice(&claimed_len.to_le_bytes());
    // Deliberately omit the string body itself.

    let mut cursor = Cursor::new(payload);
    let outcome = parse_metadata(&mut cursor, 1);

    assert!(outcome.is_err(), "Unreasonably long string should fail");
}
|
||||
|
||||
// ============================================================================
|
||||
// TensorInfo Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_tensor_info_byte_size() {
    use crate::gguf::tensors::TensorInfo;

    // Helper: a 1-D tensor of 1024 elements with the given storage type.
    let tensor = |dtype| TensorInfo {
        name: "test.weight".to_string(),
        shape: vec![1024],
        dtype,
        offset: 0,
    };

    // F32: 4 bytes per element.
    assert_eq!(tensor(GgufQuantType::F32).byte_size(), 1024 * 4);
    // F16: 2 bytes per element.
    assert_eq!(tensor(GgufQuantType::F16).byte_size(), 1024 * 2);
    // Q4_0: (1024 / 32) blocks at 18 bytes each = 576 bytes.
    assert_eq!(tensor(GgufQuantType::Q4_0).byte_size(), (1024 / 32) * 18);
}
|
||||
|
||||
#[test]
fn test_tensor_info_multidimensional() {
    use crate::gguf::tensors::TensorInfo;

    // A 512 x 256 F32 matrix.
    let info = TensorInfo {
        name: "model.layers.0.attention.wq.weight".to_string(),
        shape: vec![512, 256],
        dtype: GgufQuantType::F32,
        offset: 4096,
    };

    // Element count is the product of all dimensions; F32 is 4 bytes each.
    let elements: usize = info.shape.iter().product();
    assert_eq!(elements, 131072);
    assert_eq!(info.byte_size(), 131072 * 4);
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Mapping Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_alignment_calculation() {
    // Round `offset` up to the next multiple of `alignment`
    // (classic add-then-truncate formula).
    fn align_offset(offset: u64, alignment: u64) -> u64 {
        ((offset + alignment - 1) / alignment) * alignment
    }

    // (offset, alignment, expected) triples covering below / at / above
    // an alignment boundary.
    let cases = [
        (0u64, 32u64, 0u64),
        (1, 32, 32),
        (31, 32, 32),
        (32, 32, 32),
        (33, 32, 64),
        (100, 64, 128),
    ];
    for (offset, alignment, expected) in cases {
        assert_eq!(align_offset(offset, alignment), expected);
    }
}
|
||||
|
||||
#[test]
fn test_default_alignment_constant() {
    use crate::gguf::DEFAULT_ALIGNMENT;

    // GGUF tensor data defaults to 32-byte alignment.
    let alignment = DEFAULT_ALIGNMENT;
    assert_eq!(alignment, 32);
}
|
||||
|
||||
// ============================================================================
|
||||
// Quantization Format Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_all_quantization_types_defined() {
    // Every expected quantization format must exist and report a
    // positive block geometry.
    let all_types = [
        GgufQuantType::F32,
        GgufQuantType::F16,
        GgufQuantType::Q4_0,
        GgufQuantType::Q4_1,
        GgufQuantType::Q5_0,
        GgufQuantType::Q5_1,
        GgufQuantType::Q8_0,
        GgufQuantType::Q8_1,
        GgufQuantType::Q2_K,
        GgufQuantType::Q3_K,
        GgufQuantType::Q4_K,
        GgufQuantType::Q5_K,
        GgufQuantType::Q6_K,
    ];

    for qt in all_types.iter() {
        assert!(
            qt.block_size() > 0,
            "{:?} should have positive block size",
            qt
        );
        assert!(
            qt.type_size() > 0,
            "{:?} should have positive type size",
            qt
        );
    }
}

#[test]
fn test_quantization_type_display() {
    // Debug formatting must yield something recognizable for the type.
    let rendered = format!("{:?}", GgufQuantType::Q4_K);
    assert!(rendered.contains("Q4_K") || rendered.contains("4"));
}

#[test]
fn test_k_quant_larger_block_size() {
    // K-quantization uses larger blocks (256) vs legacy (32)
    assert_eq!(GgufQuantType::Q4_0.block_size(), 32);
    assert_eq!(GgufQuantType::Q4_K.block_size(), 256);

    // K-quant should have more data per block due to super-blocks
    assert!(GgufQuantType::Q4_K.type_size() > GgufQuantType::Q4_0.type_size());
}
|
||||
|
||||
// ============================================================================
|
||||
// Model Config Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_model_config_default() {
    use crate::gguf::ModelConfig;

    // A default config carries no metadata: every field is None.
    let config = ModelConfig::default();

    assert!(config.architecture.is_none());
    assert!(config.context_length.is_none());
    assert!(config.embedding_length.is_none());
    assert!(config.head_count.is_none());
    assert!(config.head_count_kv.is_none());
    assert!(config.layer_count.is_none());
    assert!(config.vocab_size.is_none());
    assert!(config.rope_freq_base.is_none());
    assert!(config.feed_forward_length.is_none());
}

#[test]
fn test_model_config_populated() {
    use crate::gguf::ModelConfig;

    // A fully populated Llama-7B-shaped configuration.
    let config = ModelConfig {
        architecture: Some("llama".to_string()),
        context_length: Some(4096),
        embedding_length: Some(4096),
        head_count: Some(32),
        head_count_kv: Some(8),
        layer_count: Some(32),
        vocab_size: Some(32000),
        rope_freq_base: Some(10000.0),
        feed_forward_length: Some(11008),
    };

    assert_eq!(config.architecture.as_deref(), Some("llama"));
    assert_eq!(config.context_length, Some(4096));
    assert_eq!(config.head_count, Some(32));
    assert_eq!(config.head_count_kv, Some(8));

    // GQA ratio: query heads per KV head (32 / 8).
    let gqa_ratio = config.head_count.unwrap() / config.head_count_kv.unwrap();
    assert_eq!(gqa_ratio, 4);
}
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests (Without Real Files)
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_complete_header_metadata_flow() {
    // Build a minimal but complete GGUF-like byte stream: a header
    // followed by one string metadata entry (general.architecture = "llama").
    let mut data = Vec::new();

    // Header: magic, version, tensor count, metadata count.
    data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
    data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
    data.extend_from_slice(&0u64.to_le_bytes()); // No tensors
    data.extend_from_slice(&1u64.to_le_bytes()); // 1 metadata entry

    // Metadata entry: length-prefixed key, type tag, length-prefixed value.
    let (key, value) = ("general.architecture", "llama");
    data.extend_from_slice(&(key.len() as u64).to_le_bytes());
    data.extend_from_slice(key.as_bytes());
    data.extend_from_slice(&8u32.to_le_bytes()); // String type
    data.extend_from_slice(&(value.len() as u64).to_le_bytes());
    data.extend_from_slice(value.as_bytes());

    let mut cursor = Cursor::new(data);

    // The header must parse first...
    let header = parse_header(&mut cursor).unwrap();
    assert_eq!(header.magic, GGUF_MAGIC);
    assert_eq!(header.metadata_kv_count, 1);

    // ...then metadata, continuing from where the header parse stopped.
    let metadata = parse_metadata(&mut cursor, header.metadata_kv_count).unwrap();
    assert_eq!(
        metadata.get("general.architecture").unwrap().as_str(),
        Some("llama")
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Cases
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_empty_string_value() {
    // A zero-length string value consists of only its u64 length prefix.
    // (The unused `let value = "";` binding the test previously carried
    // has been removed — only the length is ever written.)
    let key = "test.empty";

    let mut value_bytes = Vec::new();
    value_bytes.extend_from_slice(&0u64.to_le_bytes()); // length = 0

    let data = build_metadata_entry(key, 8, &value_bytes);
    let mut cursor = Cursor::new(data);

    let metadata = parse_metadata(&mut cursor, 1).unwrap();

    // The empty string must round-trip intact.
    assert_eq!(metadata.get(key).unwrap().as_str(), Some(""));
}

#[test]
fn test_zero_tensor_count() {
    // A header advertising zero tensors and zero metadata is still valid.
    let mut data = Vec::new();
    data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
    data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
    data.extend_from_slice(&0u64.to_le_bytes()); // Zero tensors
    data.extend_from_slice(&0u64.to_le_bytes()); // Zero metadata

    let mut cursor = Cursor::new(data);
    let header = parse_header(&mut cursor).unwrap();

    assert_eq!(header.tensor_count, 0);
    assert_eq!(header.metadata_kv_count, 0);
}

#[test]
fn test_large_tensor_count() {
    // Should parse headers with large counts (though reading would require actual data)
    let mut data = Vec::new();
    data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
    data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
    data.extend_from_slice(&1000u64.to_le_bytes()); // 1000 tensors
    data.extend_from_slice(&500u64.to_le_bytes()); // 500 metadata entries

    let mut cursor = Cursor::new(data);
    let header = parse_header(&mut cursor).unwrap();

    assert_eq!(header.tensor_count, 1000);
    assert_eq!(header.metadata_kv_count, 500);
}
|
||||
19
vendor/ruvector/crates/ruvllm/src/tests/mod.rs
vendored
Normal file
19
vendor/ruvector/crates/ruvllm/src/tests/mod.rs
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
//! Comprehensive test suite for RuvLLM
|
||||
//!
|
||||
//! This module organizes all unit tests for the RuvLLM crate.
|
||||
|
||||
mod activation_tests;
|
||||
mod attention_tests;
|
||||
mod generation_tests;
|
||||
mod gguf_tests;
|
||||
mod witness_log_tests;
|
||||
|
||||
// Basic lib configuration tests (moved from lib.rs)
|
||||
use crate::RuvLLMConfig;
|
||||
|
||||
#[test]
fn test_config_default() {
    // Library defaults: 1000 concurrent sessions, 768-dim embeddings.
    let config = RuvLLMConfig::default();

    assert_eq!(config.max_sessions, 1000);
    assert_eq!(config.embedding_dim, 768);
}
|
||||
713
vendor/ruvector/crates/ruvllm/src/tests/witness_log_tests.rs
vendored
Normal file
713
vendor/ruvector/crates/ruvllm/src/tests/witness_log_tests.rs
vendored
Normal file
@@ -0,0 +1,713 @@
|
||||
//! Witness Log Tests
|
||||
//!
|
||||
//! Tests for async write batching, flush on shutdown, backpressure handling,
|
||||
//! and the overall witness logging system.
|
||||
|
||||
use crate::types::ModelSize;
|
||||
use crate::witness_log::{
|
||||
AsyncWriteConfig, LatencyBreakdown, RoutingDecision, WitnessEntry, WitnessLog,
|
||||
};
|
||||
use std::time::Instant;
|
||||
|
||||
// ============================================================================
|
||||
// LatencyBreakdown Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_latency_breakdown_default() {
|
||||
let latency = LatencyBreakdown::default();
|
||||
|
||||
assert_eq!(latency.embedding_ms, 0.0);
|
||||
assert_eq!(latency.retrieval_ms, 0.0);
|
||||
assert_eq!(latency.routing_ms, 0.0);
|
||||
assert_eq!(latency.attention_ms, 0.0);
|
||||
assert_eq!(latency.generation_ms, 0.0);
|
||||
assert_eq!(latency.total_ms, 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_latency_breakdown_compute_total() {
|
||||
let mut latency = LatencyBreakdown {
|
||||
embedding_ms: 10.0,
|
||||
retrieval_ms: 5.0,
|
||||
routing_ms: 2.0,
|
||||
attention_ms: 50.0,
|
||||
generation_ms: 100.0,
|
||||
total_ms: 0.0,
|
||||
};
|
||||
|
||||
latency.compute_total();
|
||||
|
||||
assert_eq!(latency.total_ms, 167.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_latency_breakdown_exceeds_threshold() {
|
||||
let latency = LatencyBreakdown {
|
||||
embedding_ms: 10.0,
|
||||
retrieval_ms: 5.0,
|
||||
routing_ms: 2.0,
|
||||
attention_ms: 50.0,
|
||||
generation_ms: 100.0,
|
||||
total_ms: 167.0,
|
||||
};
|
||||
|
||||
assert!(latency.exceeds_threshold(100.0));
|
||||
assert!(!latency.exceeds_threshold(200.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_latency_breakdown_slowest_component() {
|
||||
let latency = LatencyBreakdown {
|
||||
embedding_ms: 10.0,
|
||||
retrieval_ms: 5.0,
|
||||
routing_ms: 2.0,
|
||||
attention_ms: 50.0,
|
||||
generation_ms: 100.0,
|
||||
total_ms: 167.0,
|
||||
};
|
||||
|
||||
let (name, value) = latency.slowest_component();
|
||||
assert_eq!(name, "generation");
|
||||
assert_eq!(value, 100.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_latency_breakdown_slowest_component_attention() {
|
||||
let latency = LatencyBreakdown {
|
||||
embedding_ms: 10.0,
|
||||
retrieval_ms: 5.0,
|
||||
routing_ms: 2.0,
|
||||
attention_ms: 200.0,
|
||||
generation_ms: 100.0,
|
||||
total_ms: 317.0,
|
||||
};
|
||||
|
||||
let (name, _) = latency.slowest_component();
|
||||
assert_eq!(name, "attention");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_latency_breakdown_all_zeros() {
|
||||
let latency = LatencyBreakdown::default();
|
||||
let (_, value) = latency.slowest_component();
|
||||
assert_eq!(value, 0.0);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// RoutingDecision Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_routing_decision_default() {
    // Defaults: smallest model, empty context, balanced sampling knobs,
    // and a uniform prior over the four model sizes.
    let decision = RoutingDecision::default();

    assert_eq!(decision.model, ModelSize::Small);
    assert_eq!(decision.context_size, 0);
    assert!((decision.temperature - 0.7).abs() < 0.01);
    assert!((decision.top_p - 0.9).abs() < 0.01);
    assert!((decision.confidence - 0.5).abs() < 0.01);
    assert_eq!(decision.model_probs, [0.25, 0.25, 0.25, 0.25]);
}

#[test]
fn test_routing_decision_custom() {
    let decision = RoutingDecision {
        model: ModelSize::Large,
        context_size: 4096,
        temperature: 0.3,
        top_p: 0.95,
        confidence: 0.85,
        model_probs: [0.1, 0.1, 0.2, 0.6],
    };

    assert_eq!(decision.model, ModelSize::Large);
    assert_eq!(decision.context_size, 4096);
    assert!((decision.confidence - 0.85).abs() < 0.01);

    // Probabilities should sum to 1.0
    let total: f32 = decision.model_probs.iter().sum();
    assert!((total - 1.0).abs() < 0.01);
}

#[test]
fn test_routing_decision_serialization() {
    let decision = RoutingDecision {
        model: ModelSize::Medium,
        context_size: 2048,
        temperature: 0.5,
        top_p: 0.85,
        confidence: 0.7,
        model_probs: [0.2, 0.3, 0.3, 0.2],
    };

    // Serialize, spot-check the field names appear...
    let json = serde_json::to_string(&decision).unwrap();
    assert!(json.contains("context_size"));

    // ...then round-trip and verify selected fields survive.
    let restored: RoutingDecision = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.context_size, 2048);
    assert!((restored.temperature - 0.5).abs() < 0.01);
}
|
||||
|
||||
// ============================================================================
|
||||
// WitnessEntry Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_witness_entry_new() {
|
||||
let entry = WitnessEntry::new(
|
||||
"session-123".to_string(),
|
||||
vec![0.1; 768],
|
||||
RoutingDecision::default(),
|
||||
);
|
||||
|
||||
assert!(!entry.request_id.is_nil());
|
||||
assert_eq!(entry.session_id, "session-123");
|
||||
assert_eq!(entry.query_embedding.len(), 768);
|
||||
assert_eq!(entry.model_used, ModelSize::Small);
|
||||
assert_eq!(entry.quality_score, 0.0);
|
||||
assert!(entry.is_success());
|
||||
assert!(entry.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_entry_with_quality() {
|
||||
let entry = WitnessEntry::new(
|
||||
"session-456".to_string(),
|
||||
vec![0.5; 768],
|
||||
RoutingDecision::default(),
|
||||
)
|
||||
.with_quality(0.85);
|
||||
|
||||
assert!((entry.quality_score - 0.85).abs() < 0.01);
|
||||
assert!(entry.meets_quality_threshold(0.8));
|
||||
assert!(!entry.meets_quality_threshold(0.9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_entry_with_latency() {
|
||||
let latency = LatencyBreakdown {
|
||||
embedding_ms: 5.0,
|
||||
retrieval_ms: 10.0,
|
||||
routing_ms: 1.0,
|
||||
attention_ms: 30.0,
|
||||
generation_ms: 50.0,
|
||||
total_ms: 96.0,
|
||||
};
|
||||
|
||||
let entry = WitnessEntry::new(
|
||||
"session-789".to_string(),
|
||||
vec![0.0; 768],
|
||||
RoutingDecision::default(),
|
||||
)
|
||||
.with_latency(latency);
|
||||
|
||||
assert_eq!(entry.latency.total_ms, 96.0);
|
||||
assert_eq!(entry.latency.generation_ms, 50.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_entry_with_error() {
|
||||
use crate::types::ErrorInfo;
|
||||
|
||||
let error = ErrorInfo {
|
||||
code: "TIMEOUT".to_string(),
|
||||
message: "Request timed out".to_string(),
|
||||
stack_trace: None,
|
||||
recovery_attempted: false,
|
||||
};
|
||||
|
||||
let entry = WitnessEntry::new(
|
||||
"session-error".to_string(),
|
||||
vec![0.0; 768],
|
||||
RoutingDecision::default(),
|
||||
)
|
||||
.with_error(error);
|
||||
|
||||
assert!(!entry.is_success());
|
||||
assert!(entry.error.is_some());
|
||||
assert_eq!(entry.error.as_ref().unwrap().code, "TIMEOUT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_entry_quality_threshold_edge_cases() {
|
||||
let entry_zero = WitnessEntry::new(
|
||||
"session".to_string(),
|
||||
vec![0.0; 768],
|
||||
RoutingDecision::default(),
|
||||
)
|
||||
.with_quality(0.0);
|
||||
|
||||
assert!(entry_zero.meets_quality_threshold(0.0));
|
||||
assert!(!entry_zero.meets_quality_threshold(0.1));
|
||||
|
||||
let entry_one = WitnessEntry::new(
|
||||
"session".to_string(),
|
||||
vec![0.0; 768],
|
||||
RoutingDecision::default(),
|
||||
)
|
||||
.with_quality(1.0);
|
||||
|
||||
assert!(entry_one.meets_quality_threshold(1.0));
|
||||
assert!(entry_one.meets_quality_threshold(0.99));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_entry_timestamp() {
|
||||
let before = chrono::Utc::now();
|
||||
let entry = WitnessEntry::new(
|
||||
"session".to_string(),
|
||||
vec![0.0; 768],
|
||||
RoutingDecision::default(),
|
||||
);
|
||||
let after = chrono::Utc::now();
|
||||
|
||||
assert!(entry.timestamp >= before);
|
||||
assert!(entry.timestamp <= after);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_witness_entry_unique_ids() {
|
||||
let entry1 = WitnessEntry::new("s1".to_string(), vec![0.0; 768], RoutingDecision::default());
|
||||
let entry2 = WitnessEntry::new("s1".to_string(), vec![0.0; 768], RoutingDecision::default());
|
||||
|
||||
// Each entry should have unique request_id
|
||||
assert_ne!(entry1.request_id, entry2.request_id);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// AsyncWriteConfig Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_async_write_config_default() {
    // Defaults: 100-entry batches, 1s max wait and flush interval,
    // a 10k-deep queue, and no fsync on critical entries.
    let config = AsyncWriteConfig::default();

    assert_eq!(config.max_batch_size, 100);
    assert_eq!(config.max_wait_ms, 1000);
    assert_eq!(config.max_queue_depth, 10000);
    assert!(!config.fsync_critical);
    assert_eq!(config.flush_interval_ms, 1000);
}

#[test]
fn test_async_write_config_custom() {
    // Every field must be overridable.
    let config = AsyncWriteConfig {
        max_batch_size: 50,
        max_wait_ms: 500,
        max_queue_depth: 5000,
        fsync_critical: true,
        flush_interval_ms: 250,
    };

    assert_eq!(config.max_batch_size, 50);
    assert!(config.fsync_critical);
}
|
||||
|
||||
// ============================================================================
|
||||
// WritebackQueue Behavior Tests (Indirect via WitnessLog)
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_writeback_batching_behavior() {
    // Simulate the writer's batching loop: flush whenever the batch
    // reaches `max_batch_size`, leaving any remainder queued.
    let max_batch_size = 10;
    let mut batch: Vec<WitnessEntry> = Vec::new();

    for i in 0..15 {
        batch.push(WitnessEntry::new(
            format!("session-{}", i),
            vec![i as f32 / 100.0; 768],
            RoutingDecision::default(),
        ));

        // Check if batch should be flushed
        if batch.len() >= max_batch_size {
            assert_eq!(batch.len(), 10);
            batch.clear();
        }
    }

    // 15 entries with one flush of 10 leaves 5 pending.
    assert_eq!(batch.len(), 5);
}

#[test]
fn test_backpressure_behavior() {
    // Simulate backpressure when queue is full
    let max_queue_depth = 100;
    let mut queue_len = 0;
    let mut dropped = 0;

    for _ in 0..150 {
        if queue_len < max_queue_depth {
            queue_len += 1;
        } else {
            dropped += 1;
        }
    }

    assert_eq!(queue_len, 100);
    assert_eq!(dropped, 50);
}

#[test]
fn test_time_based_flush_simulation() {
    use std::time::Duration;

    // Model the time-based flush decision with logical durations rather
    // than real sleeps. The previous version slept 50ms and asserted
    // `elapsed() < 100ms`, but `thread::sleep` only guarantees a *minimum*
    // sleep — on a loaded CI host a 50ms sleep can overshoot the 100ms
    // deadline and fail the test spuriously. Pure Duration arithmetic
    // tests the same flush condition deterministically.
    let max_wait = Duration::from_millis(100);

    let mut waited = Duration::from_millis(50);
    assert!(waited < max_wait, "Not yet time to flush");

    waited += Duration::from_millis(60);
    assert!(waited >= max_wait, "Should flush by now");
}
|
||||
|
||||
// ============================================================================
|
||||
// WitnessLog Stats Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_witness_log_stats_structure() {
    use crate::witness_log::WitnessLogStats;

    let stats = WitnessLogStats {
        total_entries: 1000,
        success_count: 950,
        error_count: 50,
        success_rate: 0.95,
        pending_writes: 25,
        dropped_entries: 0,
        background_running: false,
    };

    // Counts must be internally consistent with the derived rate.
    assert_eq!(stats.total_entries, 1000);
    assert_eq!(stats.success_count + stats.error_count, 1000);
    assert!((stats.success_rate - 0.95).abs() < 0.01);
}

#[test]
fn test_witness_log_stats_default() {
    use crate::witness_log::WitnessLogStats;

    // A default stats snapshot is all-zero with no background writer.
    let stats = WitnessLogStats::default();

    assert_eq!(stats.total_entries, 0);
    assert_eq!(stats.success_count, 0);
    assert_eq!(stats.error_count, 0);
    assert_eq!(stats.success_rate, 0.0);
    assert_eq!(stats.pending_writes, 0);
    assert_eq!(stats.dropped_entries, 0);
    assert!(!stats.background_running);
}

#[test]
fn test_witness_log_stats_serialization() {
    use crate::witness_log::WitnessLogStats;

    let stats = WitnessLogStats {
        total_entries: 100,
        success_count: 95,
        error_count: 5,
        success_rate: 0.95,
        pending_writes: 10,
        dropped_entries: 0,
        background_running: false,
    };

    // JSON must expose the field names and survive a round trip.
    let json = serde_json::to_string(&stats).unwrap();
    assert!(json.contains("total_entries"));
    assert!(json.contains("success_rate"));

    let restored: WitnessLogStats = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.total_entries, 100);
}
|
||||
|
||||
// ============================================================================
|
||||
// Concurrent Access Simulation Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_concurrent_entry_creation() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Ten threads each build 100 entries; the shared counter must see all.
    let created = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..10)
        .map(|_| {
            let created = Arc::clone(&created);
            thread::spawn(move || {
                for _ in 0..100 {
                    let _ = WitnessEntry::new(
                        "session".to_string(),
                        vec![0.0; 768],
                        RoutingDecision::default(),
                    );
                    created.fetch_add(1, Ordering::Relaxed);
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    assert_eq!(created.load(Ordering::Relaxed), 1000);
}

#[test]
fn test_unique_ids_concurrent() {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};
    use std::thread;

    // Request ids must stay globally unique even under concurrent creation;
    // a HashSet deduplicates, so its final size proves uniqueness.
    let seen = Arc::new(Mutex::new(HashSet::new()));

    let handles: Vec<_> = (0..10)
        .map(|_| {
            let seen = Arc::clone(&seen);
            thread::spawn(move || {
                for _ in 0..100 {
                    let entry = WitnessEntry::new(
                        "session".to_string(),
                        vec![0.0; 768],
                        RoutingDecision::default(),
                    );
                    seen.lock().unwrap().insert(entry.request_id);
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    assert_eq!(seen.lock().unwrap().len(), 1000, "All IDs should be unique");
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handling Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_witness_entry_error_chain() {
    use crate::types::ErrorInfo;

    // Build the pieces up front, then verify every builder method
    // composes on a single entry.
    let breakdown = LatencyBreakdown {
        embedding_ms: 10.0,
        retrieval_ms: 5.0,
        routing_ms: 2.0,
        attention_ms: 30.0,
        generation_ms: 50.0,
        total_ms: 97.0,
    };
    let failure = ErrorInfo {
        code: "GEN_FAILED".to_string(),
        message: "Generation failed".to_string(),
        stack_trace: None,
        recovery_attempted: false,
    };

    let entry = WitnessEntry::new(
        "session".to_string(),
        vec![0.0; 768],
        RoutingDecision::default(),
    )
    .with_quality(0.5)
    .with_latency(breakdown)
    .with_error(failure);

    // Quality and latency stick; the error marks the entry failed.
    assert!((entry.quality_score - 0.5).abs() < 0.01);
    assert_eq!(entry.latency.total_ms, 97.0);
    assert!(!entry.is_success());
    assert_eq!(entry.error.as_ref().unwrap().code, "GEN_FAILED");
}
|
||||
|
||||
// ============================================================================
|
||||
// Tag Filtering Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_witness_entry_tags() {
    let mut entry = WitnessEntry::new(
        "session".to_string(),
        vec![0.0; 768],
        RoutingDecision::default(),
    );

    // Tags are a plain list of strings; push and membership behave normally.
    for tag in ["production", "high-priority", "api-v2"] {
        entry.tags.push(tag.to_string());
    }

    assert_eq!(entry.tags.len(), 3);
    assert!(entry.tags.contains(&"production".to_string()));
}

#[test]
fn test_witness_entry_filter_by_tag() {
    // Tag even-indexed entries "even" and the rest "odd", then filter.
    let entries: Vec<WitnessEntry> = (0..10)
        .map(|i| {
            let mut entry = WitnessEntry::new(
                format!("session-{}", i),
                vec![0.0; 768],
                RoutingDecision::default(),
            );
            let tag = if i % 2 == 0 { "even" } else { "odd" };
            entry.tags.push(tag.to_string());
            entry
        })
        .collect();

    let even_count = entries
        .iter()
        .filter(|e| e.tags.contains(&"even".to_string()))
        .count();

    assert_eq!(even_count, 5);
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Measurement Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_entry_creation_performance() {
    let iterations = 10000;

    let start = Instant::now();
    for _ in 0..iterations {
        let _ = WitnessEntry::new(
            "session".to_string(),
            vec![0.0; 768],
            RoutingDecision::default(),
        );
    }
    let duration = start.elapsed();

    let avg_us = duration.as_micros() as f64 / iterations as f64;
    // Only enforce the wall-clock budget on optimized builds: debug builds
    // (and heavily loaded CI hosts) legitimately exceed it, which made this
    // assertion flaky. The loop still runs in debug so regressions that
    // panic or hang are caught either way.
    if !cfg!(debug_assertions) {
        assert!(
            avg_us < 100.0,
            "Entry creation should be fast: {}us",
            avg_us
        );
    }
}

#[test]
fn test_latency_breakdown_performance() {
    let iterations = 100000;

    let start = Instant::now();
    for _ in 0..iterations {
        let mut latency = LatencyBreakdown {
            embedding_ms: 10.0,
            retrieval_ms: 5.0,
            routing_ms: 2.0,
            attention_ms: 50.0,
            generation_ms: 100.0,
            total_ms: 0.0,
        };
        latency.compute_total();
        let _ = latency.slowest_component();
    }
    let duration = start.elapsed();

    let avg_ns = duration.as_nanos() as f64 / iterations as f64;
    // Same rationale as above: the nanosecond budget is only meaningful
    // with optimizations enabled.
    if !cfg!(debug_assertions) {
        assert!(
            avg_ns < 1000.0,
            "Latency operations should be fast: {}ns",
            avg_ns
        );
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Cases
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_empty_embedding() {
    // A zero-length embedding is accepted and preserved as-is.
    let entry = WitnessEntry::new(
        "session".to_string(),
        vec![],
        RoutingDecision::default(),
    );
    assert!(entry.query_embedding.is_empty());
}

#[test]
fn test_large_embedding() {
    // A 4K-dimension embedding round-trips without truncation.
    let wide = vec![0.1; 4096];
    let entry = WitnessEntry::new("session".to_string(), wide, RoutingDecision::default());
    assert_eq!(entry.query_embedding.len(), 4096);
}

#[test]
fn test_empty_session_id() {
    // An empty session id is tolerated.
    let entry = WitnessEntry::new(String::new(), vec![0.0; 768], RoutingDecision::default());
    assert!(entry.session_id.is_empty());
}

#[test]
fn test_long_session_id() {
    // A 1000-character session id is stored unmodified.
    let long_id = "x".repeat(1000);
    let entry = WitnessEntry::new(long_id, vec![0.0; 768], RoutingDecision::default());
    assert_eq!(entry.session_id.len(), 1000);
}

#[test]
fn test_extreme_latency_values() {
    // A near-max component must stay finite (no overflow to infinity).
    let latency = LatencyBreakdown {
        embedding_ms: f32::MAX / 10.0,
        retrieval_ms: 0.0,
        routing_ms: 0.0,
        attention_ms: 0.0,
        generation_ms: 0.0,
        total_ms: 0.0,
    };
    assert!(latency.embedding_ms.is_finite());
}

#[test]
fn test_zero_confidence_routing() {
    // Confidence may legitimately bottom out at 0.0.
    let decision = RoutingDecision {
        model: ModelSize::Tiny,
        confidence: 0.0,
        ..Default::default()
    };
    assert_eq!(decision.confidence, 0.0);
}

#[test]
fn test_max_confidence_routing() {
    // ...and saturate at 1.0.
    let decision = RoutingDecision {
        model: ModelSize::Large,
        confidence: 1.0,
        ..Default::default()
    };
    assert_eq!(decision.confidence, 1.0);
}
|
||||
Reference in New Issue
Block a user