Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,644 @@
//! Activation Function Tests
//!
//! Tests for NEON vs scalar implementations of activation functions:
//! SiLU, GELU, ReLU, and Softmax, including correctness and benchmarks.
use std::time::Instant;
// ============================================================================
// SiLU (Swish) Activation Tests
// ============================================================================
/// Reference SiLU (Swish): x * sigmoid(x), computed as x / (1 + e^(-x)).
fn silu_reference(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    x / denom
}
/// Applies [`silu_reference`] element-wise to a slice.
fn silu_vec_reference(input: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(silu_reference(x));
    }
    out
}
#[test]
fn test_silu_basic_values() {
    // Known sample points: zero plus symmetric positive/negative pairs.
    for x in [0.0f32, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5].iter().copied() {
        let y = silu_reference(x);
        // SiLU(0) = 0 exactly.
        if x == 0.0 {
            assert!(y.abs() < 1e-6, "SiLU(0) should be 0");
        }
        // Finite inputs must give finite outputs.
        assert!(y.is_finite(), "SiLU({}) should be finite", x);
        // sigmoid(x) < 1, so SiLU(x) < x whenever x > 0.
        if x > 0.0 {
            assert!(y < x, "SiLU({}) should be less than {}", x, x);
        }
    }
}
#[test]
fn test_silu_vector() {
    let input = vec![0.0, 0.5, 1.0, 1.5, 2.0, -0.5, -1.0, -1.5];
    let output = silu_vec_reference(&input);
    assert_eq!(output.len(), input.len());
    // Each vectorized element must match the scalar reference.
    for i in 0..input.len() {
        let expected = silu_reference(input[i]);
        let got = output[i];
        assert!(
            (got - expected).abs() < 1e-6,
            "SiLU mismatch at index {}: got {}, expected {}",
            i,
            got,
            expected
        );
    }
}
#[test]
fn test_silu_symmetry() {
    // SiLU is not an odd function: silu(-x) != -silu(x).
    // Instead silu(-x) = -x * sigmoid(-x) = -x / (1 + e^x).
    let x = 1.5;
    let positive_branch = silu_reference(x);
    let negative_branch = silu_reference(-x);
    // The magnitudes must differ noticeably.
    let magnitude_gap = (positive_branch.abs() - negative_branch.abs()).abs();
    assert!(magnitude_gap > 0.1);
}
#[test]
fn test_silu_large_values() {
    // Numerical stability at the extremes of the input range.
    let big = 100.0f32;
    let pos = silu_reference(big);
    let neg = silu_reference(-big);
    // SiLU(x) -> x as x -> +infinity.
    assert!((pos - big).abs() < 1e-4);
    // SiLU(x) -> 0 as x -> -infinity.
    assert!(neg.abs() < 1e-4);
}
// ============================================================================
// GELU Activation Tests
// ============================================================================
/// Reference GELU using the common tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_reference(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const CUBIC_COEFF: f32 = 0.044715;
    let arg = SQRT_2_OVER_PI * (x + CUBIC_COEFF * x * x * x);
    0.5 * x * (1.0 + arg.tanh())
}
/// Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))),
/// where Phi is the standard normal CDF.
fn gelu_exact(x: f32) -> f32 {
    0.5 * x * (1.0 + erf_approx(x / std::f32::consts::SQRT_2))
}
/// erf via the Abramowitz & Stegun 7.1.26 rational approximation
/// (max absolute error ~1.5e-7).
fn erf_approx(x: f32) -> f32 {
    // The approximation is defined for x >= 0; use erf(-x) = -erf(x).
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    // Polynomial coefficients from A&S 7.1.26.
    let a1 = 0.254829592_f32;
    let a2 = -0.284496736_f32;
    let a3 = 1.421413741_f32;
    let a4 = -1.453152027_f32;
    let a5 = 1.061405429_f32;
    let p = 0.3275911_f32;
    let t = 1.0 / (1.0 + p * x);
    let poly = (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t;
    sign * (1.0 - poly * (-x * x).exp())
}
/// Applies the tanh-approximation GELU element-wise to a slice.
fn gelu_vec_reference(input: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(gelu_reference(x));
    }
    out
}
#[test]
fn test_gelu_basic_values() {
    // GELU(0) = 0 exactly.
    assert!(gelu_reference(0.0).abs() < 1e-6);
    // GELU approaches the identity for large positive inputs.
    let big = 5.0;
    assert!((gelu_reference(big) - big).abs() < 0.1);
    // GELU approaches zero for large negative inputs.
    assert!(gelu_reference(-5.0).abs() < 0.1);
}
#[test]
fn test_gelu_approx_vs_exact() {
    // The tanh approximation should track the erf-based GELU closely.
    for x in [-2.0f32, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0].iter().copied() {
        let approx = gelu_reference(x);
        let exact = gelu_exact(x);
        // Relative error below 1%; denominator clamped to avoid /0 at x=0.
        let rel_err = (approx - exact).abs() / exact.abs().max(1e-6);
        assert!(
            rel_err < 0.01,
            "GELU approximation error too large at x={}: approx={}, exact={}",
            x,
            approx,
            exact
        );
    }
}
#[test]
fn test_gelu_vector() {
    let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0, 3.0, -3.0, 0.5];
    let output = gelu_vec_reference(&input);
    assert_eq!(output.len(), input.len());
    // No NaNs or infinities for finite inputs.
    for i in 0..output.len() {
        assert!(output[i].is_finite(), "GELU output {} should be finite", i);
    }
}
#[test]
fn test_gelu_monotonicity() {
    // GELU dips slightly near x ~ -0.2 but is increasing well above zero;
    // only check the clearly-monotone region x > 0.5.
    let xs: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
    let ys = gelu_vec_reference(&xs);
    let mut i = 1;
    while i < ys.len() {
        if xs[i] > 0.5 {
            assert!(
                ys[i] >= ys[i - 1] - 1e-6,
                "GELU should be increasing for positive values"
            );
        }
        i += 1;
    }
}
// ============================================================================
// ReLU Activation Tests
// ============================================================================
/// Reference ReLU: max(x, 0).
fn relu_reference(x: f32) -> f32 {
    x.max(0.0)
}
/// Applies [`relu_reference`] element-wise to a slice.
fn relu_vec_reference(input: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(relu_reference(x));
    }
    out
}
#[test]
fn test_relu_basic() {
    // (input, expected) pairs covering positive, zero, and negative cases.
    let cases = [
        (5.0f32, 5.0f32),
        (0.0, 0.0),
        (-5.0, 0.0),
        (-0.001, 0.0),
        (0.001, 0.001),
    ];
    for &(input, expected) in cases.iter() {
        assert_eq!(relu_reference(input), expected);
    }
}
#[test]
fn test_relu_vector() {
    let input = vec![-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0];
    let expected = vec![0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0];
    assert_eq!(relu_vec_reference(&input), expected);
}
#[test]
fn test_relu_is_idempotent() {
    // Applying ReLU twice must equal applying it once.
    let input = vec![-5.0, -1.0, 0.0, 1.0, 5.0];
    let first_pass = relu_vec_reference(&input);
    let second_pass = relu_vec_reference(&first_pass);
    assert_eq!(first_pass, second_pass);
}
#[test]
fn test_relu_special_values() {
    // +inf passes through; -inf clamps to zero.
    assert!(relu_reference(f32::INFINITY).is_infinite());
    assert_eq!(relu_reference(f32::NEG_INFINITY), 0.0);
    // NaN handling can vary; either NaN or 0.0 is acceptable
    let on_nan = relu_reference(f32::NAN);
    assert!(on_nan.is_nan() || on_nan == 0.0);
}
// ============================================================================
// Softmax Tests
// ============================================================================
/// Numerically stable softmax: subtracts the max logit before
/// exponentiating so very large inputs cannot overflow `exp`.
///
/// Returns an empty vector for empty input (sum of zero exponentials
/// never occurs because the map runs over the same empty slice).
fn softmax_reference(logits: &[f32]) -> Vec<f32> {
    let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Compute each exponential exactly once; the original recomputed every
    // exp() a second time during the normalization pass.
    let exp_scores: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
    let exp_sum: f32 = exp_scores.iter().sum();
    exp_scores.into_iter().map(|e| e / exp_sum).collect()
}
#[test]
fn test_softmax_sum_to_one() {
    let probs = softmax_reference(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-6,
        "Softmax should sum to 1.0, got {}",
        total
    );
}
#[test]
fn test_softmax_all_positive() {
    // exp() is strictly positive, so every probability must be too.
    let probs = softmax_reference(&[-10.0, -5.0, 0.0, 5.0, 10.0]);
    assert!(
        probs.iter().all(|&p| p > 0.0),
        "All softmax outputs should be positive"
    );
}
#[test]
fn test_softmax_ordering() {
    // Strictly increasing logits must map to strictly increasing probs.
    let probs = softmax_reference(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    for pair in probs.windows(2) {
        assert!(pair[0] < pair[1], "Higher logit should have higher prob");
    }
}
#[test]
fn test_softmax_numerical_stability() {
    // Without max subtraction these logits would overflow exp().
    let probs = softmax_reference(&[1000.0, 1001.0, 1002.0]);
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Softmax should be stable with large inputs"
    );
    assert!(
        probs.iter().all(|p| p.is_finite()),
        "All probs should be finite"
    );
}
#[test]
fn test_softmax_uniform() {
    // Identical logits must yield the uniform distribution.
    let probs = softmax_reference(&[5.0, 5.0, 5.0, 5.0]);
    for p in &probs {
        assert!(
            (p - 0.25).abs() < 1e-6,
            "Equal logits should give uniform probs"
        );
    }
}
#[test]
fn test_softmax_temperature_effect() {
    let logits = [1.0f32, 2.0, 3.0];
    // Dividing logits by a temperature sharpens (t < 1) or flattens (t > 1).
    let rescale = |temp: f32| -> Vec<f32> { logits.iter().map(|&x| x / temp).collect() };
    let base = softmax_reference(&logits);
    let sharp = softmax_reference(&rescale(0.5));
    let flat = softmax_reference(&rescale(2.0));
    // Lower temperature concentrates probability on the max logit.
    assert!(sharp[2] > base[2], "Lower temp should increase max prob");
    // Higher temperature lifts the min-logit probability.
    assert!(flat[0] > base[0], "Higher temp should increase min prob");
}
// ============================================================================
// Leaky ReLU Tests
// ============================================================================
/// Leaky ReLU: x for x > 0, alpha * x otherwise.
fn leaky_relu_reference(x: f32, alpha: f32) -> f32 {
    match x > 0.0 {
        true => x,
        false => alpha * x,
    }
}
/// Applies [`leaky_relu_reference`] element-wise with a fixed alpha.
fn leaky_relu_vec_reference(input: &[f32], alpha: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(input.len());
    for &x in input {
        out.push(leaky_relu_reference(x, alpha));
    }
    out
}
#[test]
fn test_leaky_relu_basic() {
    let alpha = 0.01;
    assert_eq!(leaky_relu_reference(5.0, alpha), 5.0);
    assert_eq!(leaky_relu_reference(0.0, alpha), 0.0);
    // Use tolerance for floating-point comparison
    let negative_out = leaky_relu_reference(-5.0, alpha);
    assert!((negative_out + 0.05).abs() < 1e-6);
}
#[test]
fn test_leaky_relu_reduces_to_relu() {
    // With alpha = 0 the leak vanishes and Leaky ReLU equals ReLU.
    let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
    assert_eq!(
        leaky_relu_vec_reference(&input, 0.0),
        relu_vec_reference(&input),
        "Leaky ReLU with alpha=0 should equal ReLU"
    );
}
#[test]
fn test_leaky_relu_continuity() {
    let alpha = 0.1;
    let epsilon = 1e-6;
    // Sample just left of, just right of, and exactly at the kink x = 0.
    let left = leaky_relu_reference(-epsilon, alpha);
    let right = leaky_relu_reference(epsilon, alpha);
    let center = leaky_relu_reference(0.0, alpha);
    assert!(
        (left - center).abs() < 1e-4,
        "Should be continuous from left"
    );
    assert!(
        (right - center).abs() < 1e-4,
        "Should be continuous from right"
    );
}
// ============================================================================
// Performance Comparison Tests (NEON vs Scalar)
// ============================================================================
#[test]
fn test_activation_performance_comparison() {
    // Smoke-benchmark of the scalar reference activations over a 10k-element
    // input. Only generous upper bounds are asserted so the test stays
    // robust on slow CI machines; precise numbers are for manual inspection.
    // Create test data spanning roughly [-5.0, 5.0]
    let size = 10000;
    let input: Vec<f32> = (0..size).map(|i| (i as f32 - 5000.0) / 1000.0).collect();
    // Warm up (touches caches so the timed loops are comparable)
    let _ = relu_vec_reference(&input);
    let _ = silu_vec_reference(&input);
    let _ = gelu_vec_reference(&input);
    // Benchmark ReLU
    let start = Instant::now();
    for _ in 0..100 {
        let _ = relu_vec_reference(&input);
    }
    let relu_time = start.elapsed();
    // Benchmark SiLU
    let start = Instant::now();
    for _ in 0..100 {
        let _ = silu_vec_reference(&input);
    }
    let silu_time = start.elapsed();
    // Benchmark GELU
    let start = Instant::now();
    for _ in 0..100 {
        let _ = gelu_vec_reference(&input);
    }
    let gelu_time = start.elapsed();
    // Benchmark Softmax on a smaller 1000-element slice to keep runtime low
    let softmax_input: Vec<f32> = input[0..1000].to_vec();
    let start = Instant::now();
    for _ in 0..100 {
        let _ = softmax_reference(&softmax_input);
    }
    let softmax_time = start.elapsed();
    // Print timing results (for manual inspection)
    // These assertions just verify the functions complete in reasonable time
    assert!(relu_time.as_millis() < 1000, "ReLU should complete quickly");
    assert!(
        silu_time.as_millis() < 2000,
        "SiLU should complete in reasonable time"
    );
    assert!(
        gelu_time.as_millis() < 2000,
        "GELU should complete in reasonable time"
    );
    assert!(
        softmax_time.as_millis() < 1000,
        "Softmax should complete quickly"
    );
}
// ============================================================================
// NEON vs Scalar Correctness Tests
// ============================================================================
#[test]
fn test_neon_softmax_vs_scalar() {
    // The reference softmax must produce a valid probability distribution.
    let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let scalar_result = softmax_reference(&logits);
    let total: f32 = scalar_result.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Softmax sum should be 1.0, got {}",
        total
    );
    // Every probability strictly inside (0, 1).
    assert!(scalar_result.iter().all(|&p| p > 0.0 && p < 1.0));
    // Strictly increasing logits give strictly increasing probabilities.
    for pair in scalar_result.windows(2) {
        assert!(pair[0] < pair[1]);
    }
}
#[test]
fn test_neon_softmax_large_array() {
    // 256 logits spanning roughly [-12.8, 12.7].
    let logits: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 10.0).collect();
    let scalar_result = softmax_reference(&logits);
    let scalar_sum: f32 = scalar_result.iter().sum();
    assert!(
        (scalar_sum - 1.0).abs() < 1e-4,
        "Scalar softmax sum should be 1.0, got {}",
        scalar_sum
    );
    // Valid, finite probabilities throughout.
    assert!(scalar_result
        .iter()
        .all(|&p| p >= 0.0 && p <= 1.0 && p.is_finite()));
    // Non-decreasing probabilities for non-decreasing logits.
    for pair in scalar_result.windows(2) {
        assert!(pair[0] <= pair[1], "Ordering should be preserved");
    }
}
// ============================================================================
// Edge Case Tests
// ============================================================================
#[test]
fn test_activation_empty_input() {
    // Every vectorized activation must handle a zero-length slice.
    let empty: Vec<f32> = Vec::new();
    assert!(relu_vec_reference(&empty).is_empty());
    assert!(silu_vec_reference(&empty).is_empty());
    assert!(gelu_vec_reference(&empty).is_empty());
}
#[test]
fn test_activation_single_element() {
    let single = vec![2.5];
    assert_eq!(relu_vec_reference(&single), vec![2.5]);
    assert_eq!(silu_vec_reference(&single).len(), 1);
    assert_eq!(gelu_vec_reference(&single).len(), 1);
    // Softmax over one element is the trivial distribution [1.0].
    let probs = softmax_reference(&single);
    assert_eq!(probs.len(), 1);
    assert!((probs[0] - 1.0).abs() < 1e-6);
}
#[test]
fn test_activation_all_negative() {
    let input = vec![-5.0, -4.0, -3.0, -2.0, -1.0];
    // ReLU zeroes out every element.
    assert!(relu_vec_reference(&input).iter().all(|&x| x == 0.0));
    // SiLU stays strictly negative (small but non-zero) for negative inputs.
    assert!(silu_vec_reference(&input).iter().all(|&x| x < 0.0));
    // Softmax still normalizes to 1.
    let total: f32 = softmax_reference(&input).iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
}
#[test]
fn test_activation_all_zeros() {
    let input = vec![0.0, 0.0, 0.0, 0.0];
    // ReLU(0) = 0
    assert_eq!(relu_vec_reference(&input), input);
    // SiLU(0) = 0
    assert!(silu_vec_reference(&input).iter().all(|&x| x.abs() < 1e-6));
    // GELU(0) = 0
    assert!(gelu_vec_reference(&input).iter().all(|&x| x.abs() < 1e-6));
    // Equal inputs give the uniform distribution 1/4.
    assert!(softmax_reference(&input)
        .iter()
        .all(|&x| (x - 0.25).abs() < 1e-6));
}
// ============================================================================
// Gradient-like Tests (Derivative Approximation)
// ============================================================================
#[test]
fn test_relu_derivative() {
    // Central finite difference: (f(x+h) - f(x-h)) / 2h.
    let h = 1e-5;
    let slope_at = |x: f32| (relu_reference(x + h) - relu_reference(x - h)) / (2.0 * h);
    // Slope 1 on the positive side.
    assert!((slope_at(2.0) - 1.0).abs() < 0.01);
    // Slope 0 on the negative side.
    assert!(slope_at(-2.0).abs() < 0.01);
}
#[test]
fn test_silu_derivative_at_zero() {
    let h = 1e-5;
    // Finite-difference slope centered at x = 0.
    let slope = (silu_reference(0.0 + h) - silu_reference(0.0 - h)) / (2.0 * h);
    // SiLU'(0) = 0.5
    assert!(
        (slope - 0.5).abs() < 0.01,
        "SiLU derivative at 0 should be 0.5"
    );
}
#[test]
fn test_gelu_derivative_positive() {
    let h = 1e-5;
    let x = 1.0;
    let slope = (gelu_reference(x + h) - gelu_reference(x - h)) / (2.0 * h);
    // GELU'(1) is close to (but not exactly) 1.
    assert!(
        slope > 0.5 && slope < 1.5,
        "GELU derivative at x=1 should be near 1"
    );
}

View File

@@ -0,0 +1,889 @@
//! Attention Tests
//!
//! Tests for Flash Attention, Paged Attention, MQA/GQA implementations,
//! output correctness, memory allocation, pre-allocated buffer reuse, and benchmarks.
use crate::kernels::{
flash_attention_auto, flash_attention_neon, flash_attention_v2, grouped_query_attention_neon,
multi_query_attention_neon, paged_attention_neon, select_block_size, AttentionConfig,
PagedKvCache, BLOCK_SIZE_LARGE, BLOCK_SIZE_MEDIUM, BLOCK_SIZE_SMALL,
};
use std::time::Instant;
// ============================================================================
// Helper Functions
// ============================================================================
/// Scalar single-query attention used as a correctness oracle.
///
/// `key` and `value` are row-major `[kv_len, head_dim]`; returns the
/// attention output of length `head_dim`.
fn attention_reference(
    query: &[f32],
    key: &[f32],
    value: &[f32],
    head_dim: usize,
    scale: f32,
) -> Vec<f32> {
    let kv_len = key.len() / head_dim;
    // Scaled dot-product scores: (Q . K_t) * scale for each timestep t.
    let scores: Vec<f32> = (0..kv_len)
        .map(|t| {
            let row = &key[t * head_dim..(t + 1) * head_dim];
            query.iter().zip(row).map(|(q, k)| q * k * scale).sum()
        })
        .collect();
    // Numerically stable softmax over the scores.
    let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
    let sum_exp: f32 = exp_scores.iter().sum();
    // Accumulate the attention-weighted sum of value rows.
    let mut output = vec![0.0; head_dim];
    for t in 0..kv_len {
        let weight = exp_scores[t] / sum_exp;
        let row = &value[t * head_dim..(t + 1) * head_dim];
        for (acc, v) in output.iter_mut().zip(row) {
            *acc += weight * v;
        }
    }
    output
}
/// Deterministic pseudo-random test data from a 64-bit LCG.
///
/// Returns `(query, key, value)` with lengths `head_dim`,
/// `kv_len * head_dim`, and `kv_len * head_dim`. Samples are uniform in
/// [-1.0, 1.0]; the same seed always yields the same data.
fn generate_test_data(head_dim: usize, kv_len: usize, seed: u64) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let mut rng_state = seed;
    let next_float = |state: &mut u64| -> f32 {
        // Knuth's MMIX LCG constants.
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        // Use the full top 32 bits. The original shifted by 33, keeping only
        // 31 bits, so the pre-scaled value lay in [0, ~0.5) and the "random"
        // data only ever covered [-1, 0) instead of the intended [-1, 1).
        ((*state >> 32) as u32 as f32) / (u32::MAX as f32) * 2.0 - 1.0
    };
    let query: Vec<f32> = (0..head_dim).map(|_| next_float(&mut rng_state)).collect();
    let key: Vec<f32> = (0..kv_len * head_dim)
        .map(|_| next_float(&mut rng_state))
        .collect();
    let value: Vec<f32> = (0..kv_len * head_dim)
        .map(|_| next_float(&mut rng_state))
        .collect();
    (query, key, value)
}
/// True when `a` and `b` have equal length and agree element-wise
/// within `tolerance`.
fn vectors_approx_equal(a: &[f32], b: &[f32], tolerance: f32) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b.iter())
            .all(|(x, y)| (x - y).abs() < tolerance)
}
// ============================================================================
// Flash Attention Basic Tests
// ============================================================================
#[test]
fn test_flash_attention_basic() {
    let head_dim = 16;
    let kv_len = 4;
    let scale = 1.0 / (head_dim as f32).sqrt();
    // Deterministic linear ramps for Q, K, and V.
    let q: Vec<f32> = (0..head_dim).map(|i| i as f32 * 0.1).collect();
    let k: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32 * 0.01).collect();
    let v: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32 * 0.02).collect();
    let out = flash_attention_neon(&q, &k, &v, scale, false);
    assert_eq!(out.len(), head_dim, "Output should have head_dim elements");
    assert!(
        out.iter().all(|&x| x.is_finite()),
        "All outputs should be finite"
    );
}
#[test]
fn test_flash_attention_vs_reference() {
    // NEON flash attention must agree with the scalar oracle.
    let head_dim = 32;
    let kv_len = 16;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (q, k, v) = generate_test_data(head_dim, kv_len, 12345);
    let fast = flash_attention_neon(&q, &k, &v, scale, false);
    let slow = attention_reference(&q, &k, &v, head_dim, scale);
    assert!(
        vectors_approx_equal(&fast, &slow, 1e-3),
        "NEON and reference outputs should match"
    );
}
#[test]
fn test_flash_attention_empty_kv() {
    // Zero-length K/V must not panic.
    let head_dim = 16;
    let q: Vec<f32> = (0..head_dim).map(|i| i as f32).collect();
    let empty_k: Vec<f32> = Vec::new();
    let empty_v: Vec<f32> = Vec::new();
    let scale = 1.0 / (head_dim as f32).sqrt();
    let out = flash_attention_neon(&q, &empty_k, &empty_v, scale, false);
    // Either an empty result or a zero-filled head is acceptable.
    assert!(out.len() == 0 || out.len() == head_dim);
}
#[test]
fn test_flash_attention_single_token() {
    // With exactly one KV token, softmax gives it weight 1.0, so the
    // output must reproduce the value row directly.
    let head_dim = 64;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (q, k, v) = generate_test_data(head_dim, 1, 42);
    let out = flash_attention_neon(&q, &k, &v, scale, false);
    assert!(
        vectors_approx_equal(&out, &v, 1e-5),
        "Single token attention should return value directly"
    );
}
// ============================================================================
// Flash Attention V2 Block Size Tests
// ============================================================================
#[test]
fn test_flash_attention_v2_small_block() {
    // Block size is a tiling detail and must not change the result.
    let head_dim = 64;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (q, k, v) = generate_test_data(head_dim, 100, 111);
    let small = flash_attention_v2(&q, &k, &v, scale, false, BLOCK_SIZE_SMALL);
    let medium = flash_attention_v2(&q, &k, &v, scale, false, BLOCK_SIZE_MEDIUM);
    assert!(
        vectors_approx_equal(&small, &medium, 1e-3),
        "Block sizes should not affect correctness"
    );
}
#[test]
fn test_flash_attention_v2_all_block_sizes() {
    // All three supported tile sizes must agree pairwise.
    let head_dim = 128;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (q, k, v) = generate_test_data(head_dim, 256, 222);
    let small = flash_attention_v2(&q, &k, &v, scale, false, BLOCK_SIZE_SMALL);
    let medium = flash_attention_v2(&q, &k, &v, scale, false, BLOCK_SIZE_MEDIUM);
    let large = flash_attention_v2(&q, &k, &v, scale, false, BLOCK_SIZE_LARGE);
    assert!(vectors_approx_equal(&small, &medium, 1e-3));
    assert!(vectors_approx_equal(&medium, &large, 1e-3));
}
#[test]
fn test_flash_attention_auto_block_selection() {
    // Smoke test: automatic block selection must run for both short and
    // long sequences without panicking.
    let head_dim = 128;
    let scale = 1.0 / (head_dim as f32).sqrt();
    for &(kv_len, seed) in [(32, 333), (1024, 444)].iter() {
        let (q, k, v) = generate_test_data(head_dim, kv_len, seed);
        let _ = flash_attention_auto(&q, &k, &v, scale, false);
    }
}
// ============================================================================
// Block Size Selection Tests
// ============================================================================
#[test]
fn test_select_block_size_short_sequence() {
    // Sequences of <= 64 tokens should pick the smallest tile.
    let head_dim = 128;
    for &seq_len in [32, 64].iter() {
        assert_eq!(select_block_size(seq_len, head_dim), BLOCK_SIZE_SMALL);
    }
}
#[test]
fn test_select_block_size_medium_sequence() {
    // Mid-length sequences should pick the medium tile.
    let head_dim = 128;
    for &seq_len in [128, 256, 512].iter() {
        assert_eq!(select_block_size(seq_len, head_dim), BLOCK_SIZE_MEDIUM);
    }
}
#[test]
fn test_select_block_size_long_sequence() {
    // A small head_dim leaves room in L1 for bigger tiles.
    let chosen = select_block_size(2048, 64);
    assert!(
        chosen >= BLOCK_SIZE_MEDIUM,
        "Long sequences should use at least medium blocks"
    );
}
#[test]
fn test_select_block_size_large_head_dim() {
    // A large head_dim must cap the tile size so it still fits in L1.
    let chosen = select_block_size(2048, 256);
    assert!(chosen <= BLOCK_SIZE_LARGE);
}
// ============================================================================
// Paged KV Cache Tests
// ============================================================================
#[test]
fn test_paged_kv_cache_creation() {
    // A fresh cache records its geometry and holds no tokens or blocks.
    let cache = PagedKvCache::new(16, 4, 64);
    assert_eq!(cache.block_size, 16);
    assert_eq!(cache.num_kv_heads, 4);
    assert_eq!(cache.head_dim, 64);
    assert_eq!(cache.num_tokens, 0);
    assert!(cache.key_blocks.is_empty());
    assert!(cache.value_blocks.is_empty());
}
#[test]
fn test_paged_kv_cache_append() {
    let mut cache = PagedKvCache::new(4, 2, 8);
    // One token = num_kv_heads * head_dim = 16 floats per tensor.
    cache.append(&vec![1.0; 16], &vec![2.0; 16]);
    assert_eq!(cache.num_tokens, 1);
    assert_eq!(cache.key_blocks.len(), 1);
    assert_eq!(cache.value_blocks.len(), 1);
}
#[test]
fn test_paged_kv_cache_append_multiple() {
    let mut cache = PagedKvCache::new(4, 2, 8);
    let per_token = 2 * 8;
    // Five tokens with block_size 4 must spill into a second block.
    for t in 0..5 {
        let k = vec![(t + 1) as f32; per_token];
        let v = vec![(t + 1) as f32 * 2.0; per_token];
        cache.append(&k, &v);
    }
    assert_eq!(cache.num_tokens, 5);
    assert_eq!(cache.key_blocks.len(), 2); // 5 tokens, 4 per block = 2 blocks
}
#[test]
fn test_paged_kv_cache_get_keys() {
    let mut cache = PagedKvCache::new(4, 1, 8);
    // Two tokens with distinct fill values.
    cache.append(&vec![1.0; 8], &vec![10.0; 8]);
    cache.append(&vec![2.0; 8], &vec![20.0; 8]);
    let keys = cache.get_keys();
    // 2 tokens * 1 head * 8 dims, concatenated in append order.
    assert_eq!(keys.len(), 16);
    assert!(keys[..8].iter().all(|&x| x == 1.0));
    assert!(keys[8..].iter().all(|&x| x == 2.0));
}
#[test]
fn test_paged_kv_cache_get_values() {
    let mut cache = PagedKvCache::new(4, 1, 8);
    cache.append(&vec![1.0; 8], &vec![5.0; 8]);
    let values = cache.get_values();
    assert_eq!(values.len(), 8);
    assert!(values.iter().all(|&x| x == 5.0));
}
// ============================================================================
// Paged Attention Tests
// ============================================================================
#[test]
fn test_paged_attention_empty_cache() {
    // Attention over an empty cache returns an all-zero head.
    let cache = PagedKvCache::new(16, 1, 16);
    let query = vec![0.5; 16];
    let output = paged_attention_neon(&query, &cache, &[], 0.25);
    assert_eq!(output.len(), 16);
    assert!(output.iter().all(|&x| x == 0.0));
}
#[test]
fn test_paged_attention_with_cache() {
    let mut cache = PagedKvCache::new(16, 1, 16);
    // Fill the cache with eight identical ramp tokens.
    for _ in 0..8 {
        let k: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
        let v: Vec<f32> = (0..16).map(|i| i as f32 * 0.2).collect();
        cache.append(&k, &v);
    }
    let query: Vec<f32> = (0..16).map(|i| i as f32 * 0.05).collect();
    let output = paged_attention_neon(&query, &cache, &[], 1.0 / 4.0);
    assert_eq!(output.len(), 16);
    assert!(output.iter().all(|&x| x.is_finite()));
}
// ============================================================================
// Multi-Query Attention (MQA) Tests
// ============================================================================
#[test]
fn test_mqa_basic() {
    // MQA: many query heads sharing a single KV head.
    let config = AttentionConfig {
        num_heads: 8,
        num_kv_heads: 1,
        head_dim: 16,
        causal: false,
        ..Default::default()
    };
    let kv_len = 4;
    let queries: Vec<f32> = (0..config.num_heads * config.head_dim)
        .map(|i| i as f32 * 0.01)
        .collect();
    let keys: Vec<f32> = (0..kv_len * config.head_dim)
        .map(|i| i as f32 * 0.01)
        .collect();
    let values: Vec<f32> = (0..kv_len * config.head_dim)
        .map(|i| i as f32 * 0.02)
        .collect();
    let output = multi_query_attention_neon(&queries, &keys, &values, &config);
    assert_eq!(output.len(), config.num_heads * config.head_dim);
    assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_mqa_shared_kv() {
    // If every query head receives the same query vector and all heads
    // share one KV head, every head's output must be identical.
    let config = AttentionConfig {
        num_heads: 4,
        num_kv_heads: 1,
        head_dim: 8,
        causal: false,
        ..Default::default()
    };
    // Repeat a single all-ones query for every head.
    let queries: Vec<f32> = vec![1.0; config.head_dim]
        .into_iter()
        .cycle()
        .take(config.num_heads * config.head_dim)
        .collect();
    let kv_len = 2;
    let keys: Vec<f32> = (0..kv_len * config.head_dim)
        .map(|i| i as f32 * 0.1)
        .collect();
    let values: Vec<f32> = vec![1.0; kv_len * config.head_dim];
    let output = multi_query_attention_neon(&queries, &keys, &values, &config);
    let heads: Vec<&[f32]> = output.chunks(config.head_dim).collect();
    for head in &heads[1..] {
        assert!(
            vectors_approx_equal(heads[0], head, 1e-5),
            "All heads should produce same output with identical queries"
        );
    }
}
// ============================================================================
// Grouped-Query Attention (GQA) Tests
// ============================================================================
#[test]
fn test_gqa_basic() {
    // GQA with a 4:1 query-to-KV head ratio.
    let config = AttentionConfig {
        num_heads: 8,
        num_kv_heads: 2,
        head_dim: 16,
        causal: false,
        ..Default::default()
    };
    assert_eq!(config.gqa_ratio(), 4);
    let kv_len = 4;
    let q_len = config.num_heads * config.head_dim;
    let kv_size = kv_len * config.num_kv_heads * config.head_dim;
    let queries: Vec<f32> = (0..q_len).map(|i| i as f32 * 0.01).collect();
    let keys: Vec<f32> = (0..kv_size).map(|i| i as f32 * 0.01).collect();
    let values: Vec<f32> = (0..kv_size).map(|i| i as f32 * 0.01).collect();
    let output = grouped_query_attention_neon(&queries, &keys, &values, &config);
    assert_eq!(output.len(), q_len);
    assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_gqa_head_grouping() {
    // With a 2:1 ratio, query heads 0-1 read KV head 0 and heads 2-3 read
    // KV head 1. Give each KV head a distinct constant fill so the
    // grouping is observable in the output.
    let config = AttentionConfig {
        num_heads: 4,
        num_kv_heads: 2,
        head_dim: 8,
        causal: false,
        ..Default::default()
    };
    assert_eq!(config.gqa_ratio(), 2);
    let kv_len = 2;
    let kv_stride = config.num_kv_heads * config.head_dim;
    let mut keys = vec![0.0; kv_len * kv_stride];
    let mut values = vec![0.0; kv_len * kv_stride];
    for t in 0..kv_len {
        for head in 0..config.num_kv_heads {
            // KV head 0 is filled with 1.0, KV head 1 with 2.0.
            let fill = (head + 1) as f32;
            let base = t * kv_stride + head * config.head_dim;
            for slot in &mut keys[base..base + config.head_dim] {
                *slot = fill;
            }
            for slot in &mut values[base..base + config.head_dim] {
                *slot = fill;
            }
        }
    }
    // Uniform queries so attention weights are uniform within each group.
    let queries = vec![0.5; config.num_heads * config.head_dim];
    let output = grouped_query_attention_neon(&queries, &keys, &values, &config);
    // The mean of each output head reveals which KV head it attended to.
    let head_means: Vec<f32> = output
        .chunks(config.head_dim)
        .map(|h| h.iter().sum::<f32>() / config.head_dim as f32)
        .collect();
    assert!(
        (head_means[0] - 1.0).abs() < 0.1,
        "Head 0 should use KV head 0"
    );
    assert!(
        (head_means[1] - 1.0).abs() < 0.1,
        "Head 1 should use KV head 0"
    );
    assert!(
        (head_means[2] - 2.0).abs() < 0.1,
        "Head 2 should use KV head 1"
    );
    assert!(
        (head_means[3] - 2.0).abs() < 0.1,
        "Head 3 should use KV head 1"
    );
}
// ============================================================================
// AttentionConfig Tests
// ============================================================================
/// The default `AttentionConfig` should describe Llama-style GQA:
/// 32 query heads over 8 KV heads of dim 128, causal masking enabled.
#[test]
fn test_attention_config_default() {
    let cfg = AttentionConfig::default();
    assert_eq!((cfg.num_heads, cfg.num_kv_heads, cfg.head_dim), (32, 8, 128));
    assert!(cfg.causal);
    assert_eq!(cfg.gqa_ratio(), 4);
}
/// `effective_scale` auto-computes 1/sqrt(head_dim) when `scale` is 0
/// and passes an explicit non-zero scale through unchanged.
#[test]
fn test_attention_config_effective_scale() {
    // scale == 0.0 requests auto-computation.
    let auto_cfg = AttentionConfig {
        head_dim: 64,
        scale: 0.0,
        ..Default::default()
    };
    let auto_expected = 1.0 / (64.0f32).sqrt();
    assert!((auto_cfg.effective_scale() - auto_expected).abs() < 1e-6);
    // A non-zero scale is taken verbatim.
    let explicit_cfg = AttentionConfig {
        head_dim: 64,
        scale: 0.2,
        ..Default::default()
    };
    assert!((explicit_cfg.effective_scale() - 0.2).abs() < 1e-6);
}
/// `gqa_ratio` spans the whole spectrum: standard MHA (1:1), grouped-query
/// attention (4:1, 8:1), and multi-query attention (all heads share one KV).
#[test]
fn test_attention_config_gqa_ratios() {
    let ratio_of = |num_heads, num_kv_heads| {
        AttentionConfig {
            num_heads,
            num_kv_heads,
            ..Default::default()
        }
        .gqa_ratio()
    };
    assert_eq!(ratio_of(32, 32), 1, "standard MHA");
    assert_eq!(ratio_of(32, 8), 4, "GQA 4:1");
    assert_eq!(ratio_of(32, 4), 8, "GQA 8:1");
    assert_eq!(ratio_of(32, 1), 32, "MQA");
}
// ============================================================================
// Memory Allocation Tests
// ============================================================================
/// Flash attention must be deterministic: repeated runs over the same
/// inputs produce identical outputs.
#[test]
fn test_attention_no_extra_allocation() {
    let head_dim = 128;
    let kv_len = 256;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (query, key, value) = generate_test_data(head_dim, kv_len, 555);
    // Three independent runs over identical inputs.
    let runs: Vec<Vec<f32>> = (0..3)
        .map(|_| flash_attention_neon(&query, &key, &value, scale, false))
        .collect();
    assert!(vectors_approx_equal(&runs[0], &runs[1], 1e-6));
    assert!(vectors_approx_equal(&runs[1], &runs[2], 1e-6));
}
/// The attention output for a single query has exactly `head_dim` elements,
/// independent of the KV sequence length.
#[test]
fn test_attention_output_size_correct() {
    let (head_dim, kv_len) = (64, 100);
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (query, key, value) = generate_test_data(head_dim, kv_len, 666);
    let out = flash_attention_neon(&query, &key, &value, scale, false);
    assert_eq!(out.len(), head_dim, "Output should exactly match head_dim");
}
// ============================================================================
// Performance Benchmark Tests
// ============================================================================
/// Smoke benchmark: a short (64-token) KV sequence must average under 1ms.
#[test]
fn test_attention_benchmark_short_sequence() {
    let head_dim = 128;
    let kv_len = 64;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (query, key, value) = generate_test_data(head_dim, kv_len, 777);
    // Warm-up passes so timing excludes first-touch effects.
    for _ in 0..10 {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }
    let iterations = 1000;
    let started = Instant::now();
    for _ in 0..iterations {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }
    let avg_us = started.elapsed().as_micros() as f64 / iterations as f64;
    assert!(
        avg_us < 1000.0,
        "Short sequence attention should be fast: {}us",
        avg_us
    );
}
/// Smoke benchmark: a long (2048-token) KV sequence must average under 50ms.
#[test]
fn test_attention_benchmark_long_sequence() {
    let head_dim = 128;
    let kv_len = 2048;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (query, key, value) = generate_test_data(head_dim, kv_len, 888);
    // Warm-up before timing.
    for _ in 0..5 {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }
    let iterations = 100;
    let started = Instant::now();
    for _ in 0..iterations {
        let _ = flash_attention_neon(&query, &key, &value, scale, false);
    }
    let avg_ms = started.elapsed().as_millis() as f64 / iterations as f64;
    assert!(
        avg_ms < 50.0,
        "Long sequence attention should complete in <50ms: {}ms",
        avg_ms
    );
}
/// Every supported tiling block size should complete a 512-token workload
/// well inside a generous (5s / 100 iterations) time budget.
#[test]
fn test_attention_benchmark_block_sizes() {
    let head_dim = 128;
    let kv_len = 512;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let iterations = 100;
    let (query, key, value) = generate_test_data(head_dim, kv_len, 999);
    // Time `iterations` runs at a given block size.
    let time_block = |block| {
        let started = Instant::now();
        for _ in 0..iterations {
            let _ = flash_attention_v2(&query, &key, &value, scale, false, block);
        }
        started.elapsed()
    };
    for elapsed in [
        time_block(BLOCK_SIZE_SMALL),
        time_block(BLOCK_SIZE_MEDIUM),
        time_block(BLOCK_SIZE_LARGE),
    ] {
        assert!(elapsed.as_millis() < 5000);
    }
}
// ============================================================================
// Numerical Stability Tests
// ============================================================================
/// Saturated Q·K dot products must not overflow: the softmax path should
/// remain finite even with very large logits.
#[test]
fn test_attention_large_logits() {
    let head_dim = 32;
    let kv_len = 8;
    // Query and keys chosen so every dot product is large.
    let query = vec![10.0; head_dim];
    let key = vec![10.0; kv_len * head_dim];
    let value: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32).collect();
    let scale = 1.0 / (head_dim as f32).sqrt();
    let out = flash_attention_neon(&query, &key, &value, scale, false);
    assert!(
        out.iter().all(|x| x.is_finite()),
        "Should handle large dot products"
    );
}
/// Tiny-magnitude queries/keys probe for underflow in score accumulation;
/// the output must stay finite.
#[test]
fn test_attention_small_values() {
    let head_dim = 32;
    let kv_len = 8;
    let query = vec![1e-6; head_dim];
    let key = vec![1e-6; kv_len * head_dim];
    let value: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32).collect();
    let scale = 1.0 / (head_dim as f32).sqrt();
    let out = flash_attention_neon(&query, &key, &value, scale, false);
    assert!(
        out.iter().all(|x| x.is_finite()),
        "Should handle small values"
    );
}
/// Mixed positive/negative inputs exercise cancellation in the dot products;
/// the output must stay finite.
#[test]
fn test_attention_mixed_signs() {
    let head_dim = 32;
    let kv_len = 8;
    // Alternating-sign query against a 2:1 mix of +0.5 / -0.5 keys.
    let query: Vec<f32> = (0..head_dim)
        .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
        .collect();
    let key: Vec<f32> = (0..kv_len * head_dim)
        .map(|i| if i % 3 == 0 { -0.5 } else { 0.5 })
        .collect();
    let value: Vec<f32> = (0..kv_len * head_dim).map(|i| i as f32 * 0.01).collect();
    let scale = 1.0 / (head_dim as f32).sqrt();
    let out = flash_attention_neon(&query, &key, &value, scale, false);
    assert!(out.iter().all(|x| x.is_finite()));
}
// ============================================================================
// Edge Cases
// ============================================================================
/// Degenerate case: scalar queries/keys (head_dim == 1) still produce a
/// single finite output element.
#[test]
fn test_attention_single_head_dim() {
    let query = vec![1.0];
    let key: Vec<f32> = (1..=4).map(|i| i as f32).collect();
    let value: Vec<f32> = (1..=4).map(|i| (i * 10) as f32).collect();
    let out = flash_attention_neon(&query, &key, &value, 1.0, false);
    assert_eq!(out.len(), 1);
    assert!(out[0].is_finite());
}
/// A wider-than-typical head (512) exercises the vector-lane tail handling.
#[test]
fn test_attention_large_head_dim() {
    let head_dim = 512;
    let kv_len = 16;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let (query, key, value) = generate_test_data(head_dim, kv_len, 1111);
    let out = flash_attention_neon(&query, &key, &value, scale, false);
    assert_eq!(out.len(), head_dim);
    assert!(out.iter().all(|x| x.is_finite()));
}
/// Power-of-2 head dims are the common transformer sizes; each must yield
/// a finite, correctly-sized output.
#[test]
fn test_attention_power_of_two_dims() {
    for &head_dim in &[32, 64, 128, 256] {
        let kv_len = 64;
        let scale = 1.0 / (head_dim as f32).sqrt();
        let (query, key, value) = generate_test_data(head_dim, kv_len, head_dim as u64);
        let out = flash_attention_neon(&query, &key, &value, scale, false);
        assert_eq!(out.len(), head_dim);
        assert!(
            out.iter().all(|x| x.is_finite()),
            "Failed for head_dim={}",
            head_dim
        );
    }
}
/// Odd head dims hit the scalar remainder path after the vectorized chunks;
/// each must yield a finite, correctly-sized output.
#[test]
fn test_attention_non_power_of_two_dims() {
    for &head_dim in &[17, 33, 65, 100, 127] {
        let kv_len = 32;
        let scale = 1.0 / (head_dim as f32).sqrt();
        let (query, key, value) = generate_test_data(head_dim, kv_len, head_dim as u64);
        let out = flash_attention_neon(&query, &key, &value, scale, false);
        assert_eq!(out.len(), head_dim);
        assert!(
            out.iter().all(|x| x.is_finite()),
            "Failed for head_dim={}",
            head_dim
        );
    }
}

// ===== extraction artifact: start of next vendored file (token generation tests) =====
//! Token Generation Tests
//!
//! Tests for autoregressive token generation, sampling strategies,
//! streaming callbacks, KV cache integration, and speculative decoding.
use crate::speculative::{
log_softmax, sample_from_probs, softmax, top_k_filter, top_p_filter, AtomicSpeculativeStats,
SpeculationTree, SpeculativeConfig, SpeculativeStats, TreeNode, VerificationResult,
};
use rand::rngs::StdRng;
use rand::SeedableRng;
use std::time::Duration;
// ============================================================================
// Softmax and Sampling Tests
// ============================================================================
/// `softmax` must yield a valid probability distribution that preserves the
/// ordering of the input logits.
#[test]
fn test_softmax_produces_valid_distribution() {
    let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let probs = softmax(&logits);
    // Sums to 1 with strictly positive entries.
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-5,
        "Softmax sum should be 1.0, got {}",
        total
    );
    assert!(
        probs.iter().all(|&p| p > 0.0),
        "All probabilities should be positive"
    );
    // Monotone: ascending logits give ascending probabilities.
    assert!(
        probs.windows(2).all(|w| w[0] < w[1]),
        "Higher logits should have higher probs"
    );
}
/// Logits near 1000 overflow exp() unless the implementation subtracts the
/// running max; the result must remain a finite, normalized distribution.
#[test]
fn test_softmax_handles_large_logits() {
    let logits = vec![1000.0, 1001.0, 1002.0];
    let probs = softmax(&logits);
    assert!(
        probs.iter().all(|p| p.is_finite()),
        "All probs should be finite"
    );
    let total: f32 = probs.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-4,
        "Should handle large logits: sum = {}",
        total
    );
}
/// Negative logits are ordinary inputs: the output still normalizes and
/// keeps the input ordering.
#[test]
fn test_softmax_handles_negative_logits() {
    let logits = vec![-5.0, -3.0, -1.0, 0.0, 1.0];
    let probs = softmax(&logits);
    let total: f32 = probs.iter().sum();
    assert!((total - 1.0).abs() < 1e-5, "Should handle negative logits");
    let lowest = probs.first().copied().unwrap();
    let highest = probs.last().copied().unwrap();
    assert!(highest > lowest, "Larger logit should have higher prob");
}
/// An empty logit slice maps to an empty distribution, not a panic.
#[test]
fn test_softmax_empty_input() {
    let empty: Vec<f32> = Vec::new();
    assert!(
        softmax(&empty).is_empty(),
        "Empty input should return empty output"
    );
}
/// A one-element input trivially normalizes to probability 1.
#[test]
fn test_softmax_single_element() {
    let logits = vec![5.0];
    let probs = softmax(&logits);
    assert_eq!(probs.len(), 1);
    assert!(
        (probs[0] - 1.0).abs() < 1e-5,
        "Single element should have prob 1.0"
    );
}
/// Identity check: log_softmax(x) == ln(softmax(x)) elementwise.
#[test]
fn test_log_softmax_relationship() {
    let logits = vec![1.0, 2.0, 3.0, 4.0];
    let from_log = log_softmax(&logits);
    let from_probs: Vec<f32> = softmax(&logits).iter().map(|p| p.ln()).collect();
    for (a, b) in from_log.iter().zip(from_probs.iter()) {
        assert!(
            (a - b).abs() < 1e-4,
            "log_softmax should match log(softmax)"
        );
    }
}
/// Logits this negative would underflow plain softmax to 0 (ln -> -inf);
/// working in log space must keep every output finite and ordered.
#[test]
fn test_log_softmax_numerical_stability() {
    let logits = vec![-1000.0, -999.0, -998.0];
    let log_probs = log_softmax(&logits);
    assert!(
        log_probs.iter().all(|p| p.is_finite()),
        "log_softmax should handle extreme values"
    );
    // Relative ordering must survive the transform.
    assert!(log_probs.windows(2).all(|w| w[0] < w[1]));
}
// ============================================================================
// Top-K Filtering Tests
// ============================================================================
/// With k=2 exactly the two largest logits (5.0 at index 1, 4.0 at index 3)
/// survive; everything else is masked to -inf.
#[test]
fn test_top_k_filter_basic() {
    let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
    top_k_filter(&mut logits, 2);
    let survivors: Vec<usize> = logits
        .iter()
        .enumerate()
        .filter(|(_, v)| v.is_finite())
        .map(|(i, _)| i)
        .collect();
    assert_eq!(survivors, vec![1, 3], "Only top-k elements should remain");
}
/// Requesting more survivors than there are logits is a no-op.
#[test]
fn test_top_k_filter_k_greater_than_length() {
    let mut logits = vec![1.0, 2.0, 3.0];
    top_k_filter(&mut logits, 10);
    assert!(
        logits.iter().all(|v| v.is_finite()),
        "All should remain when k > length"
    );
}
/// k = 0 means "filtering disabled": nothing is masked.
#[test]
fn test_top_k_filter_k_zero() {
    let mut logits = vec![1.0, 2.0, 3.0];
    top_k_filter(&mut logits, 0);
    assert!(
        logits.iter().all(|v| v.is_finite()),
        "All should remain when k = 0"
    );
}
/// k = 1 degenerates into greedy: only the maximum logit survives.
#[test]
fn test_top_k_filter_k_one() {
    let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
    top_k_filter(&mut logits, 1);
    let survivors: Vec<usize> = logits
        .iter()
        .enumerate()
        .filter(|(_, v)| v.is_finite())
        .map(|(i, _)| i)
        .collect();
    assert_eq!(survivors, vec![1], "Maximum (5.0) should be the only survivor");
}
// ============================================================================
// Top-P (Nucleus) Filtering Tests
// ============================================================================
/// Nucleus filtering must always retain the highest-probability token, here
/// index 0, which carries nearly all the mass.
#[test]
fn test_top_p_filter_basic() {
    let mut logits = vec![10.0, 1.0, 0.0, -1.0, -2.0];
    top_p_filter(&mut logits, 0.9);
    assert!(logits[0].is_finite(), "Highest prob token should remain");
}
/// p = 1.0 includes the entire distribution, so the logits are untouched.
#[test]
fn test_top_p_filter_p_one() {
    let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let mut filtered = original.clone();
    top_p_filter(&mut filtered, 1.0);
    assert_eq!(filtered, original, "All should remain when p = 1.0");
}
/// Even at p = 0 the filter must leave at least one sampleable token.
#[test]
fn test_top_p_filter_p_zero() {
    let mut logits = vec![1.0, 2.0, 3.0];
    top_p_filter(&mut logits, 0.0);
    assert!(
        logits.iter().any(|v| v.is_finite()),
        "At least one token should remain"
    );
}
// ============================================================================
// Sampling Tests
// ============================================================================
/// A one-hot distribution leaves the sampler no choice: index 2 every time.
#[test]
fn test_sample_from_probs_deterministic() {
    let probs = vec![0.0, 0.0, 1.0, 0.0];
    let mut rng = StdRng::seed_from_u64(12345);
    assert!(
        (0..10).all(|_| sample_from_probs(&probs, &mut rng) == 2),
        "Should always sample index 2"
    );
}
/// Sampling a uniform 4-way distribution 10k times should put every bucket
/// within 20% of the expected 2500 hits (seeded RNG keeps this stable).
#[test]
fn test_sample_from_probs_uniform() {
    let probs = vec![0.25, 0.25, 0.25, 0.25];
    let mut rng = StdRng::seed_from_u64(42);
    let mut counts = [0usize; 4];
    for _ in 0..10000 {
        counts[sample_from_probs(&probs, &mut rng)] += 1;
    }
    for (i, &count) in counts.iter().enumerate() {
        let ratio = count as f64 / 2500.0;
        assert!(
            (0.8..=1.2).contains(&ratio),
            "Index {} should be sampled uniformly, got {} (expected ~2500)",
            i,
            count
        );
    }
}
/// With 90% of the mass on index 0, at least 800 of 1000 draws should land
/// there (seeded RNG keeps this deterministic).
#[test]
fn test_sample_from_probs_skewed() {
    let probs = vec![0.9, 0.05, 0.03, 0.02];
    let mut rng = StdRng::seed_from_u64(42);
    let hits_on_zero = (0..1000)
        .filter(|_| sample_from_probs(&probs, &mut rng) == 0)
        .count();
    assert!(hits_on_zero > 800, "Index 0 should be sampled most often");
}
// ============================================================================
// Temperature Scaling Tests
// ============================================================================
/// Dividing logits by a small temperature exaggerates their gaps, pushing
/// nearly all probability onto the argmax.
#[test]
fn test_temperature_scaling_sharpens() {
    let logits = [1.0f32, 2.0, 3.0, 4.0];
    let sharpened: Vec<f32> = logits.iter().map(|l| l / 0.1).collect();
    let probs = softmax(&sharpened);
    assert!(
        probs[3] > 0.99,
        "Low temperature should concentrate probability on max"
    );
}
/// A high temperature shrinks logit gaps, flattening the distribution so
/// the max-min probability spread stays small.
#[test]
fn test_temperature_scaling_flattens() {
    let logits = [1.0f32, 2.0, 3.0, 4.0];
    let flattened: Vec<f32> = logits.iter().map(|l| l / 10.0).collect();
    let probs = softmax(&flattened);
    let lo = probs.iter().cloned().fold(f32::INFINITY, f32::min);
    let hi = probs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    assert!(
        hi - lo < 0.2,
        "High temperature should flatten distribution"
    );
}
/// Temperature 1.0 divides by one — the resulting distribution is identical.
#[test]
fn test_temperature_one_unchanged() {
    let logits = vec![1.0, 2.0, 3.0, 4.0];
    let rescaled_logits: Vec<f32> = logits.iter().map(|&l| l / 1.0).collect();
    let baseline = softmax(&logits);
    let rescaled = softmax(&rescaled_logits);
    for (a, b) in baseline.iter().zip(rescaled.iter()) {
        assert!(
            (a - b).abs() < 1e-6,
            "Temperature 1.0 should not change distribution"
        );
    }
}
// ============================================================================
// Speculative Decoding Config Tests
// ============================================================================
/// The default speculative config: 4-token lookahead (adaptive within 2..=8),
/// 0.5 acceptance threshold, greedy drafting, no tree speculation.
#[test]
fn test_speculative_config_default() {
    let cfg = SpeculativeConfig::default();
    assert_eq!(cfg.lookahead, 4);
    assert!((cfg.acceptance_threshold - 0.5).abs() < 0.01);
    assert_eq!(cfg.draft_temperature, 0.0);
    assert!(!cfg.tree_speculation);
    assert!(cfg.adaptive_lookahead);
    assert_eq!((cfg.min_lookahead, cfg.max_lookahead), (2, 8));
}
/// Struct-update syntax: overridden fields stick, the rest fall back to
/// their defaults.
#[test]
fn test_speculative_config_custom() {
    let cfg = SpeculativeConfig {
        lookahead: 8,
        acceptance_threshold: 0.7,
        draft_temperature: 0.3,
        tree_speculation: true,
        max_tree_depth: 4,
        tree_branching_factor: 3,
        ..Default::default()
    };
    assert_eq!(cfg.lookahead, 8);
    assert!((cfg.acceptance_threshold - 0.7).abs() < 0.01);
    assert!(cfg.tree_speculation);
    assert_eq!(cfg.max_tree_depth, 4);
    assert_eq!(cfg.tree_branching_factor, 3);
}
// ============================================================================
// Speculative Stats Tests
// ============================================================================
/// A fresh tracker starts with all counters and derived rates zeroed.
#[test]
fn test_speculative_stats_new() {
    let tracker = SpeculativeStats::new();
    assert_eq!(tracker.draft_tokens, 0);
    assert_eq!(tracker.accepted_tokens, 0);
    assert_eq!(tracker.main_forward_passes, 0);
    assert_eq!(tracker.acceptance_rate, 0.0);
    assert_eq!(tracker.speedup, 0.0);
}
/// One recorded round (4 drafts, 3 accepted, 10ms) must update every
/// counter and derived statistic consistently.
#[test]
fn test_speculative_stats_record_round() {
    let mut tracker = SpeculativeStats::new();
    tracker.record_round(4, 3, 10.0);
    assert_eq!(tracker.draft_tokens, 4);
    assert_eq!(tracker.accepted_tokens, 3);
    assert!((tracker.acceptance_rate - 0.75).abs() < 0.01);
    assert_eq!(tracker.main_forward_passes, 1);
    // 3 accepted plus the 1 correction token from the main model.
    assert_eq!(tracker.total_tokens_generated, 4);
    assert!((tracker.total_speculation_time_ms - 10.0).abs() < 0.01);
}
/// Counters accumulate across rounds and the acceptance rate is the
/// aggregate ratio, not the last round's.
#[test]
fn test_speculative_stats_multiple_rounds() {
    let mut tracker = SpeculativeStats::new();
    // Perfect acceptance, then 50% acceptance.
    tracker.record_round(4, 4, 10.0);
    tracker.record_round(4, 2, 15.0);
    assert_eq!(tracker.draft_tokens, 8);
    assert_eq!(tracker.accepted_tokens, 6);
    assert!((tracker.acceptance_rate - 0.75).abs() < 0.01); // 6/8
    assert_eq!(tracker.main_forward_passes, 2);
    // Exact total depends on correction-token accounting; sanity-bound it.
    assert!(
        tracker.total_tokens_generated >= 6,
        "Should generate at least accepted tokens"
    );
}
/// `reset` zeroes both raw counters and derived rates.
#[test]
fn test_speculative_stats_reset() {
    let mut tracker = SpeculativeStats::new();
    tracker.record_round(4, 3, 10.0);
    tracker.reset();
    assert_eq!(tracker.draft_tokens, 0);
    assert_eq!(tracker.accepted_tokens, 0);
    assert_eq!(tracker.acceptance_rate, 0.0);
}
/// Two fully-accepted rounds of 4 drafts yield ~5 generated tokens per main
/// forward pass, so the reported speedup must exceed 4x.
#[test]
fn test_speculative_stats_speedup_calculation() {
    let mut tracker = SpeculativeStats::new();
    tracker.record_round(4, 4, 10.0);
    tracker.record_round(4, 4, 10.0);
    assert!(
        tracker.speedup > 4.0,
        "Speedup should reflect tokens per main pass"
    );
}
// ============================================================================
// Atomic Speculative Stats Tests
// ============================================================================
/// A freshly constructed atomic tracker snapshots to all-zero counters.
#[test]
fn test_atomic_stats_new() {
    let snapshot = AtomicSpeculativeStats::new().snapshot();
    assert_eq!(snapshot.draft_tokens, 0);
    assert_eq!(snapshot.accepted_tokens, 0);
}
/// Recording through the atomic API must be reflected in the next snapshot.
#[test]
fn test_atomic_stats_record_round() {
    let tracker = AtomicSpeculativeStats::new();
    tracker.record_round(4, 3, Duration::from_millis(10));
    let snap = tracker.snapshot();
    assert_eq!((snap.draft_tokens, snap.accepted_tokens), (4, 3));
    assert!((snap.acceptance_rate - 0.75).abs() < 0.01);
}
/// Ten threads each record 100 rounds concurrently; the final snapshot must
/// account for every round exactly once (no lost updates), which exercises
/// the tracker's atomic counters.
#[test]
fn test_atomic_stats_thread_safe() {
    use std::sync::Arc;
    use std::thread;
    let stats = Arc::new(AtomicSpeculativeStats::new());
    let mut handles = vec![];
    // Spawn multiple threads recording rounds
    for _ in 0..10 {
        let stats_clone = Arc::clone(&stats);
        handles.push(thread::spawn(move || {
            for _ in 0..100 {
                stats_clone.record_round(4, 3, Duration::from_millis(1));
            }
        }));
    }
    // Wait for every writer before reading the totals.
    for handle in handles {
        handle.join().unwrap();
    }
    let snapshot = stats.snapshot();
    assert_eq!(snapshot.draft_tokens, 4000); // 10 threads * 100 rounds * 4 drafts
    assert_eq!(snapshot.accepted_tokens, 3000);
}
/// `reset` on the atomic tracker zeroes the counters seen by `snapshot`.
#[test]
fn test_atomic_stats_reset() {
    let tracker = AtomicSpeculativeStats::new();
    tracker.record_round(4, 3, Duration::from_millis(10));
    tracker.reset();
    assert_eq!(tracker.snapshot().draft_tokens, 0);
}
// ============================================================================
// Tree Node Tests
// ============================================================================
/// A new node stores its token, probability, and depth, caches the log-prob,
/// and starts childless.
#[test]
fn test_tree_node_new() {
    let node = TreeNode::new(42, 0.8, 0);
    assert_eq!((node.token, node.depth), (42, 0));
    assert!((node.prob - 0.8).abs() < 0.01);
    // logprob is ln(prob), computed at construction.
    assert!((node.logprob - 0.8f32.ln()).abs() < 0.01);
    assert!(node.children.is_empty());
}
/// Children are appended in insertion order at depth parent+1.
#[test]
fn test_tree_node_add_child() {
    let mut root = TreeNode::new(0, 1.0, 0);
    let first = root.add_child(1, 0.6);
    assert_eq!(first.token, 1);
    assert_eq!(first.depth, 1);
    // A second child attaches as a sibling of the first.
    let second = root.add_child(2, 0.4);
    assert_eq!(second.token, 2);
    assert_eq!(root.children.len(), 2);
}
/// `get_paths` enumerates every root-to-leaf token sequence; a two-leaf tree
/// must yield exactly the two expected paths (order unspecified).
#[test]
fn test_tree_node_get_paths() {
    let mut root = TreeNode::new(0, 1.0, 0);
    // Build a tree:
    //       0
    //      / \
    //     1   2
    //    /
    //   3
    // The inner scope ends the mutable borrow of `root` held through
    // `child1` before `root` is borrowed again below.
    {
        let child1 = root.add_child(1, 0.6);
        child1.add_child(3, 0.5);
    }
    root.add_child(2, 0.4);
    let paths = root.get_paths();
    assert_eq!(paths.len(), 2);
    // Should have paths [0, 1, 3] and [0, 2]
    assert!(paths.iter().any(|p| p == &vec![0, 1, 3]));
    assert!(paths.iter().any(|p| p == &vec![0, 2]));
}
/// `best_path` greedily follows the highest-probability child at each level:
/// 0 (root) -> 1 (p=0.6 beats 0.4) -> 3.
#[test]
fn test_tree_node_best_path() {
    let mut root = TreeNode::new(0, 1.0, 0);
    // Build tree with different probabilities
    // (scope ends the `child1` borrow before `root` is reborrowed).
    {
        let child1 = root.add_child(1, 0.6);
        child1.add_child(3, 0.5);
    }
    root.add_child(2, 0.4);
    let best = root.best_path();
    // Should follow highest probability children: 0 -> 1 -> 3
    assert_eq!(best, vec![0, 1, 3]);
}
// ============================================================================
// Speculation Tree Tests
// ============================================================================
/// A new speculation tree records its limits and contains only its root.
#[test]
fn test_speculation_tree_new() {
    let tree = SpeculationTree::new(3, 2);
    assert_eq!(tree.node_count, 1);
    assert_eq!((tree.max_depth, tree.branching_factor), (3, 2));
}
/// `clear` drops all children and resets the count to just the root.
#[test]
fn test_speculation_tree_clear() {
    let mut spec_tree = SpeculationTree::new(3, 2);
    spec_tree.root.add_child(1, 0.5);
    spec_tree.node_count += 1;
    spec_tree.clear();
    assert_eq!(spec_tree.node_count, 1);
    assert!(spec_tree.root.children.is_empty());
}
/// With no children under the root there is nothing to speculate.
#[test]
fn test_speculation_tree_best_path_empty() {
    assert!(
        SpeculationTree::new(3, 2).best_path().is_empty(),
        "Empty tree should have empty best path"
    );
}
/// A single chain 1 -> 2 -> 3 under the root is trivially the best path.
/// Note: the root token is excluded from `SpeculationTree::best_path`,
/// unlike `TreeNode::best_path` which includes its own token.
#[test]
fn test_speculation_tree_best_path_linear() {
    let mut tree = SpeculationTree::new(4, 2);
    // Build linear path: root -> 1 -> 2 -> 3
    // node_count is maintained by hand here since children are added
    // through the node API, which does not touch the tree's counter.
    let node1 = tree.root.add_child(1, 0.8);
    tree.node_count += 1;
    let node2 = node1.add_child(2, 0.7);
    tree.node_count += 1;
    node2.add_child(3, 0.6);
    tree.node_count += 1;
    let path = tree.best_path();
    assert_eq!(path, vec![1, 2, 3]);
}
// ============================================================================
// Verification Result Tests
// ============================================================================
/// Fully-accepted verification: every drafted token matched the main model.
#[test]
fn test_verification_result_all_accepted() {
    let verdict = VerificationResult {
        accepted_count: 4,
        next_token: 100,
        accepted_logprobs: vec![-0.1, -0.2, -0.1, -0.15],
        next_logprob: -0.3,
        all_accepted: true,
    };
    assert!(verdict.all_accepted);
    assert_eq!(verdict.accepted_count, 4);
    assert_eq!(verdict.next_token, 100);
}
/// Partial acceptance: the first two drafts pass, token 50 is the correction
/// that replaces the first rejected draft.
#[test]
fn test_verification_result_partial_accept() {
    let verdict = VerificationResult {
        accepted_count: 2,
        next_token: 50,
        accepted_logprobs: vec![-0.1, -0.2],
        next_logprob: -0.5,
        all_accepted: false,
    };
    assert!(!verdict.all_accepted);
    assert_eq!(verdict.accepted_count, 2);
}
/// Total rejection: the very first draft diverged, so nothing is accepted
/// and the correction token is issued immediately.
#[test]
fn test_verification_result_none_accepted() {
    let verdict = VerificationResult {
        accepted_count: 0,
        next_token: 25,
        accepted_logprobs: vec![],
        next_logprob: -0.4,
        all_accepted: false,
    };
    assert_eq!(verdict.accepted_count, 0);
    assert!(verdict.accepted_logprobs.is_empty());
    assert!(!verdict.all_accepted);
}
// ============================================================================
// Integration Sampling Tests
// ============================================================================
/// End-to-end pipeline: logits -> softmax -> repeated seeded sampling.
/// Checks normalization, positivity, draw counting, and that the highest
/// logit is drawn more often than the lowest.
#[test]
fn test_full_sampling_pipeline() {
    let logits = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let probs = softmax(&logits);
    // Valid distribution.
    let mass: f32 = probs.iter().sum();
    assert!(
        (mass - 1.0).abs() < 1e-4,
        "Softmax should sum to 1.0, got {}",
        mass
    );
    assert!(
        probs.iter().all(|&p| p > 0.0),
        "All probabilities should be positive"
    );
    // Draw 1000 samples with a fixed seed into a histogram.
    let mut rng = StdRng::seed_from_u64(42);
    let mut histogram = [0usize; 5];
    for _ in 0..1000 {
        let idx = sample_from_probs(&probs, &mut rng);
        if let Some(slot) = histogram.get_mut(idx) {
            *slot += 1;
        }
    }
    let drawn: usize = histogram.iter().sum();
    assert_eq!(drawn, 1000, "Should have 1000 total samples");
    // Statistically, the top logit (index 4) dominates the bottom (index 0).
    assert!(
        histogram[4] > histogram[0],
        "Higher logit should be sampled more: idx4={}, idx0={}",
        histogram[4],
        histogram[0]
    );
}
/// Greedy decoding (temperature -> 0) reduces to argmax over the logits;
/// index 3 holds the maximum value 5.0.
#[test]
fn test_greedy_decoding_simulation() {
    let logits = [1.0f32, 3.0, 2.0, 5.0, 4.0];
    // Manual argmax scan (no ties present, so strategy doesn't matter).
    let mut best = 0;
    for (idx, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = idx;
        }
    }
    assert_eq!(best, 3, "Greedy should select index 3 (value 5.0)");
}
/// One beam-search expansion: keep the indices of the `beam_width` highest
/// logits, ordered best-first.
#[test]
fn test_beam_search_simulation() {
    let beam_width = 3;
    let logits = [1.0f32, 5.0, 3.0, 4.0, 2.0];
    // Sort indices by descending logit, then keep the beam.
    let mut order: Vec<usize> = (0..logits.len()).collect();
    order.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());
    order.truncate(beam_width);
    assert_eq!(order, vec![1, 3, 2], "Top-3 should be indices 1, 3, 2");
}
// ============================================================================
// Edge Cases and Error Handling
// ============================================================================
/// `softmax` must treat a `NEG_INFINITY` logit as a hard mask: that slot
/// gets (effectively) zero probability while the rest still normalize to 1.
#[test]
fn test_softmax_with_inf() {
    let logits = vec![f32::NEG_INFINITY, 1.0, 2.0];
    let probs = softmax(&logits);
    // Fix: the original condition `probs[0] < 1e-10 || probs[0].abs() < 1e-10`
    // was redundant and too weak — its first disjunct accepted any large
    // NEGATIVE value, which is not a valid probability. The abs-check alone
    // expresses "approximately zero" correctly.
    assert!(
        probs[0].abs() < 1e-10,
        "NEG_INFINITY should give ~0 probability"
    );
    // The remaining mass must still normalize.
    let sum: f32 = probs.iter().sum();
    assert!((sum - 1.0).abs() < 1e-4, "Sum should be 1.0");
}
/// Mass concentrated almost entirely on one token: cumulative-sum rounding
/// must never produce an out-of-range index or panic.
#[test]
fn test_sample_numerical_edge_case() {
    let probs = vec![0.9999999, 0.0000001];
    let mut rng = StdRng::seed_from_u64(42);
    for _ in 0..100 {
        assert!(
            sample_from_probs(&probs, &mut rng) < 2,
            "Index should be valid"
        );
    }
}
/// Three entries tie for the maximum; k=3 must not drop any of them.
#[test]
fn test_top_k_with_ties() {
    let mut logits = vec![5.0, 5.0, 5.0, 1.0, 2.0];
    top_k_filter(&mut logits, 3);
    let kept = logits.iter().filter(|v| v.is_finite()).count();
    assert!(
        kept >= 3,
        "Should keep at least k elements when ties exist"
    );
}

// ===== extraction artifact: start of next vendored file (GGUF loading tests) =====
//! GGUF Loading Tests
//!
//! Tests for GGUF header/metadata parsing, tensor loading, quantization
//! format handling, architecture detection, memory mapping, and error handling.
use crate::gguf::parser::GgufValueType;
use crate::gguf::{
parse_header, parse_metadata, GgufHeader, GgufQuantType, GgufValue, GGUF_MAGIC, GGUF_VERSION,
};
use std::io::Cursor;
// ============================================================================
// Header Parsing Tests
// ============================================================================
/// A well-formed header (magic, version, tensor and metadata counts, all
/// little-endian) round-trips through `parse_header`.
#[test]
fn test_parse_valid_header() {
    let mut raw = Vec::new();
    raw.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
    raw.extend_from_slice(&GGUF_VERSION.to_le_bytes());
    raw.extend_from_slice(&10u64.to_le_bytes()); // tensor_count
    raw.extend_from_slice(&5u64.to_le_bytes()); // metadata_kv_count
    let header = parse_header(&mut Cursor::new(raw)).unwrap();
    assert_eq!(header.magic, GGUF_MAGIC);
    assert_eq!(header.version, GGUF_VERSION);
    assert_eq!(header.tensor_count, 10);
    assert_eq!(header.metadata_kv_count, 5);
}
/// The GGUF magic is the ASCII bytes "GGUF" read as a little-endian u32.
#[test]
fn test_gguf_magic_is_correct() {
    assert_eq!(GGUF_MAGIC, 0x46554747u32);
    assert_eq!(&GGUF_MAGIC.to_le_bytes(), b"GGUF");
}
/// A header cut off after the magic (no version or counts) must be rejected.
#[test]
fn test_parse_header_truncated() {
    let mut cursor = Cursor::new(b"GGUF".to_vec());
    assert!(
        parse_header(&mut cursor).is_err(),
        "Truncated header should fail"
    );
}
/// Zero bytes of input is an error, not a panic.
#[test]
fn test_parse_header_empty() {
    let mut cursor = Cursor::new(Vec::<u8>::new());
    assert!(parse_header(&mut cursor).is_err(), "Empty input should fail");
}
// ============================================================================
// GgufValue Tests
// ============================================================================
/// A `String` value answers only the string accessor; every typed numeric,
/// boolean, and array view returns `None`.
#[test]
fn test_gguf_value_string() {
    let val = GgufValue::String("test_value".to_string());
    assert_eq!(val.as_str(), Some("test_value"));
    assert!(val.as_u64().is_none());
    assert!(val.as_i64().is_none());
    assert!(val.as_f32().is_none());
    assert!(val.as_bool().is_none());
    assert!(val.as_array().is_none());
}
/// Integer accessors only succeed when the conversion is lossless:
/// non-negative values view as u64, in-range values as i64.
#[test]
fn test_gguf_value_integer_conversions() {
    // U32 widens losslessly into every numeric view.
    let small_u32 = GgufValue::U32(42);
    assert_eq!(small_u32.as_u64(), Some(42));
    assert_eq!(small_u32.as_i64(), Some(42));
    assert_eq!(small_u32.as_f32(), Some(42.0));
    assert_eq!(small_u32.as_str(), None);
    // A negative I32 has no unsigned view.
    let neg_i32 = GgufValue::I32(-5);
    assert_eq!(neg_i32.as_i64(), Some(-5));
    assert_eq!(neg_i32.as_u64(), None);
    // u64::MAX does not fit in i64.
    let huge_u64 = GgufValue::U64(u64::MAX);
    assert_eq!(huge_u64.as_u64(), Some(u64::MAX));
    assert_eq!(huge_u64.as_i64(), None);
    // A negative I64 is signed-only...
    let neg_i64 = GgufValue::I64(-100);
    assert_eq!(neg_i64.as_i64(), Some(-100));
    assert_eq!(neg_i64.as_u64(), None);
    // ...but a non-negative I64 converts both ways.
    let pos_i64 = GgufValue::I64(100);
    assert_eq!(pos_i64.as_i64(), Some(100));
    assert_eq!(pos_i64.as_u64(), Some(100));
}
/// F32 and F64 values are each visible through both float accessors
/// (with precision loss only when narrowing), and never as strings.
#[test]
fn test_gguf_value_float_conversions() {
    let single = GgufValue::F32(3.14);
    assert!((single.as_f32().unwrap() - 3.14).abs() < 0.001);
    assert!((single.as_f64().unwrap() - 3.14).abs() < 0.001);
    assert_eq!(single.as_str(), None);
    let double = GgufValue::F64(2.71828);
    assert!((double.as_f64().unwrap() - 2.71828).abs() < 0.00001);
    // Narrowing to f32 keeps ~3 decimal digits here.
    assert!((double.as_f32().unwrap() - 2.71828).abs() < 0.001);
}
#[test]
fn test_gguf_value_bool() {
    // Explicit Bool values report themselves and nothing else.
    assert_eq!(GgufValue::Bool(true).as_bool(), Some(true));
    assert_eq!(GgufValue::Bool(false).as_bool(), Some(false));
    assert_eq!(GgufValue::Bool(true).as_str(), None);
    // A U8 byte also coerces to bool (0 = false, nonzero = true).
    assert_eq!(GgufValue::U8(1).as_bool(), Some(true));
    assert_eq!(GgufValue::U8(0).as_bool(), Some(false));
}
#[test]
fn test_gguf_value_array() {
    // An Array value exposes its elements in insertion order.
    let val = GgufValue::Array(vec![
        GgufValue::U32(1),
        GgufValue::U32(2),
        GgufValue::U32(3),
    ]);
    let array = val.as_array().expect("Array value should expose elements");
    assert_eq!(array.len(), 3);
    for (i, item) in array.iter().enumerate() {
        assert_eq!(item.as_u64(), Some(i as u64 + 1));
    }
}
#[test]
fn test_gguf_value_small_integers() {
    // Extremes of each narrow integer type widen correctly.
    assert_eq!(GgufValue::U8(255).as_u64(), Some(255));
    let i8_min = GgufValue::I8(-128);
    assert_eq!(i8_min.as_i64(), Some(-128));
    // Negative values never surface through the unsigned accessor.
    assert_eq!(i8_min.as_u64(), None);
    assert_eq!(GgufValue::U16(65535).as_u64(), Some(65535));
    assert_eq!(GgufValue::I16(-32768).as_i64(), Some(-32768));
}
// ============================================================================
// GgufValueType Tests
// ============================================================================
#[test]
fn test_value_type_conversion() {
    // Discriminants 0..=12 map onto the value types in spec order.
    let expected = [
        GgufValueType::U8,
        GgufValueType::I8,
        GgufValueType::U16,
        GgufValueType::I16,
        GgufValueType::U32,
        GgufValueType::I32,
        GgufValueType::F32,
        GgufValueType::Bool,
        GgufValueType::String,
        GgufValueType::Array,
        GgufValueType::U64,
        GgufValueType::I64,
        GgufValueType::F64,
    ];
    for (code, want) in expected.iter().enumerate() {
        assert_eq!(
            GgufValueType::try_from(code as u32).unwrap(),
            *want,
            "discriminant {} should decode",
            code
        );
    }
}
#[test]
fn test_value_type_invalid() {
    // Discriminants past F64 (12) are undefined and must be rejected.
    for bad in [13, 100, 255] {
        assert!(GgufValueType::try_from(bad).is_err(), "{} must be rejected", bad);
    }
}
// ============================================================================
// Quantization Type Tests
// ============================================================================
#[test]
fn test_quant_type_from_u32() {
    // Known ggml type ids: F32=0, F16=1, Q4_0=2, Q4_1=3, Q8_0=8.
    for id in [0u32, 1, 2, 3, 8] {
        assert!(GgufQuantType::try_from(id).is_ok(), "id {} should decode", id);
    }
}
#[test]
fn test_quant_type_block_size() {
    // Float types are unblocked (1 element), legacy quants use 32-element
    // blocks, K-quants use 256-element super-blocks.
    let cases = [
        (GgufQuantType::F32, 1),
        (GgufQuantType::F16, 1),
        (GgufQuantType::Q4_0, 32),
        (GgufQuantType::Q4_1, 32),
        (GgufQuantType::Q8_0, 32),
        (GgufQuantType::Q4_K, 256),
        (GgufQuantType::Q2_K, 256),
        (GgufQuantType::Q3_K, 256),
        (GgufQuantType::Q5_K, 256),
        (GgufQuantType::Q6_K, 256),
    ];
    for (qt, expected) in &cases {
        assert_eq!(qt.block_size(), *expected, "block_size of {:?}", qt);
    }
}
#[test]
fn test_quant_type_type_size() {
    // Bytes per block: scale (+ optional min) + packed weight data.
    let cases = [
        (GgufQuantType::F32, 4),   // one f32 element per "block"
        (GgufQuantType::F16, 2),   // one f16 element per "block"
        (GgufQuantType::Q4_0, 18), // 2B scale + 32 * 4 bit = 16B data
        (GgufQuantType::Q4_1, 20), // 2B scale + 2B min + 16B data
        (GgufQuantType::Q8_0, 34), // 2B scale + 32B data
    ];
    for (qt, expected) in &cases {
        assert_eq!(qt.type_size(), *expected, "type_size of {:?}", qt);
    }
}
#[test]
fn test_quant_type_is_quantized() {
    // Only the raw float formats are not quantized.
    for qt in [GgufQuantType::F32, GgufQuantType::F16] {
        assert!(!qt.is_quantized(), "{:?} is a float format", qt);
    }
    for qt in [
        GgufQuantType::Q4_0,
        GgufQuantType::Q4_1,
        GgufQuantType::Q8_0,
        GgufQuantType::Q4_K,
        GgufQuantType::Q2_K,
    ] {
        assert!(qt.is_quantized(), "{:?} is a quantized format", qt);
    }
}
#[test]
fn test_quant_type_bits_per_weight() {
    // Effective density: type_size * 8 bits spread over block_size weights.
    assert!((GgufQuantType::F32.bits_per_weight() - 32.0).abs() < 0.1);
    assert!((GgufQuantType::F16.bits_per_weight() - 16.0).abs() < 0.1);
    // Q8_0: 34 bytes over 32 elements -> 8.5 bits each.
    assert!((GgufQuantType::Q8_0.bits_per_weight() - 8.5).abs() < 0.1);
    // Cross-check Q4_0 against its raw block layout: 18 bytes over
    // 32 elements -> 4.5 bits each.
    let q4 = GgufQuantType::Q4_0;
    let derived = (q4.type_size() * 8) as f32 / q4.block_size() as f32;
    assert!((derived - 4.5).abs() < 0.1);
}
// ============================================================================
// Architecture Detection Tests
// ============================================================================
#[test]
fn test_architecture_metadata_key() {
    // Common GGUF metadata keys follow "<arch>.<field>" (or "general.*").
    let arch_keys = [
        "general.architecture",
        "llama.context_length",
        "llama.embedding_length",
        "llama.attention.head_count",
        "llama.attention.head_count_kv",
        "llama.block_count",
        "llama.rope.freq_base",
        "mistral.context_length",
        "phi.context_length",
    ];
    for key in arch_keys {
        assert!(!key.is_empty(), "metadata key must be non-empty");
        assert!(
            key.contains('.') || key.starts_with("general"),
            "unexpected key format: {}",
            key
        );
    }
}
#[test]
fn test_architecture_detection_patterns() {
    // Each reported architecture string should normalize onto its family
    // prefix (e.g. "phi2"/"phi3" both match "phi").
    let arch_patterns = [
        ("llama", "llama"),
        ("mistral", "mistral"),
        ("phi", "phi"),
        ("phi2", "phi"),
        ("phi3", "phi"),
        ("qwen", "qwen"),
        ("qwen2", "qwen"),
        ("gemma", "gemma"),
    ];
    for (input, family) in arch_patterns {
        let normalized = input.to_lowercase();
        let matches_family =
            normalized.starts_with(family) || normalized.contains(family);
        assert!(matches_family, "{} should match {} pattern", input, family);
    }
}
// ============================================================================
// Metadata Parsing Tests
// ============================================================================
/// Serialize a single GGUF metadata entry: length-prefixed key (u64 LE),
/// little-endian value-type tag (u32), then the raw value bytes.
fn build_metadata_entry(key: &str, value_type: u32, value_bytes: &[u8]) -> Vec<u8> {
    let mut entry = Vec::with_capacity(8 + key.len() + 4 + value_bytes.len());
    entry.extend_from_slice(&(key.len() as u64).to_le_bytes()); // key length
    entry.extend_from_slice(key.as_bytes()); // key bytes
    entry.extend_from_slice(&value_type.to_le_bytes()); // value type tag
    entry.extend_from_slice(value_bytes); // value payload
    entry
}
#[test]
fn test_parse_metadata_u32() {
    // Type tag 4 = U32.
    let data = build_metadata_entry("test.value", 4, &12345u32.to_le_bytes());
    let metadata = parse_metadata(&mut Cursor::new(data), 1).unwrap();
    let parsed = metadata.get("test.value").expect("key should be present");
    assert_eq!(parsed.as_u64(), Some(12345));
}
#[test]
fn test_parse_metadata_f32() {
    // Type tag 6 = F32.
    let data = build_metadata_entry("test.float", 6, &3.14159f32.to_le_bytes());
    let metadata = parse_metadata(&mut Cursor::new(data), 1).unwrap();
    let parsed = metadata.get("test.float").unwrap().as_f32().unwrap();
    assert!((parsed - 3.14159).abs() < 0.0001);
}
#[test]
fn test_parse_metadata_string() {
    // String payloads are themselves length-prefixed (u64 LE); tag 8.
    let text = "hello_world";
    let mut payload = (text.len() as u64).to_le_bytes().to_vec();
    payload.extend_from_slice(text.as_bytes());
    let data = build_metadata_entry("test.name", 8, &payload);
    let metadata = parse_metadata(&mut Cursor::new(data), 1).unwrap();
    assert_eq!(metadata.get("test.name").unwrap().as_str(), Some("hello_world"));
}
#[test]
fn test_parse_metadata_bool() {
    // Booleans are a single byte; tag 7.
    let data = build_metadata_entry("test.enabled", 7, &[1u8]);
    let metadata = parse_metadata(&mut Cursor::new(data), 1).unwrap();
    assert_eq!(metadata.get("test.enabled").unwrap().as_bool(), Some(true));
}
#[test]
fn test_parse_metadata_multiple_entries() {
    // Two back-to-back entries of different types parse independently.
    let mut data = build_metadata_entry("key1", 4, &42u32.to_le_bytes());
    data.extend(build_metadata_entry("key2", 6, &1.5f32.to_le_bytes()));
    let metadata = parse_metadata(&mut Cursor::new(data), 2).unwrap();
    assert_eq!(metadata.len(), 2);
    assert_eq!(metadata.get("key1").unwrap().as_u64(), Some(42));
    assert!((metadata.get("key2").unwrap().as_f32().unwrap() - 1.5).abs() < 0.001);
}
// ============================================================================
// Error Handling Tests
// ============================================================================
#[test]
fn test_parse_metadata_truncated_key() {
    // The key length field claims 100 bytes but only 4 follow.
    let mut data = 100u64.to_le_bytes().to_vec();
    data.extend_from_slice(b"test");
    let result = parse_metadata(&mut Cursor::new(data), 1);
    assert!(result.is_err(), "Truncated key should fail");
}
#[test]
fn test_parse_metadata_invalid_value_type() {
    // 255 is not a defined GgufValueType discriminant.
    let mut data = 4u64.to_le_bytes().to_vec();
    data.extend_from_slice(b"test");
    data.extend_from_slice(&255u32.to_le_bytes());
    let result = parse_metadata(&mut Cursor::new(data), 1);
    assert!(result.is_err(), "Invalid value type should fail");
}
#[test]
fn test_string_too_long_protection() {
    // A string header claiming 10 MB with no payload behind it must not
    // trigger a huge allocation -- the parser should bail out with an error.
    let key = "malicious.string";
    let mut data = (key.len() as u64).to_le_bytes().to_vec();
    data.extend_from_slice(key.as_bytes());
    data.extend_from_slice(&8u32.to_le_bytes()); // String type tag
    data.extend_from_slice(&10_000_000u64.to_le_bytes()); // claimed length
    let result = parse_metadata(&mut Cursor::new(data), 1);
    assert!(result.is_err(), "Unreasonably long string should fail");
}
// ============================================================================
// TensorInfo Tests
// ============================================================================
#[test]
fn test_tensor_info_byte_size() {
    use crate::gguf::tensors::TensorInfo;
    // Helper: a flat 1024-element tensor with the given dtype.
    let tensor = |dtype| TensorInfo {
        name: "test.weight".to_string(),
        shape: vec![1024],
        dtype,
        offset: 0,
    };
    // F32: 4 bytes per element.
    assert_eq!(tensor(GgufQuantType::F32).byte_size(), 1024 * 4);
    // F16: 2 bytes per element.
    assert_eq!(tensor(GgufQuantType::F16).byte_size(), 1024 * 2);
    // Q4_0: 1024 / 32 blocks, 18 bytes per block.
    assert_eq!(tensor(GgufQuantType::Q4_0).byte_size(), (1024 / 32) * 18);
}
#[test]
fn test_tensor_info_multidimensional() {
    use crate::gguf::tensors::TensorInfo;
    // Element count of a multi-dimensional tensor is the product of dims.
    let info = TensorInfo {
        name: "model.layers.0.attention.wq.weight".to_string(),
        shape: vec![512, 256],
        dtype: GgufQuantType::F32,
        offset: 4096,
    };
    let num_elements: usize = info.shape.iter().product();
    assert_eq!(num_elements, 512 * 256);
    // F32: 4 bytes per element.
    assert_eq!(info.byte_size(), num_elements * 4);
}
// ============================================================================
// Memory Mapping Tests
// ============================================================================
#[test]
fn test_alignment_calculation() {
    // Round `offset` up to the next multiple of `alignment`.
    fn align_offset(offset: u64, alignment: u64) -> u64 {
        let rem = offset % alignment;
        if rem == 0 {
            offset
        } else {
            offset + (alignment - rem)
        }
    }
    let cases = [
        (0, 32, 0),
        (1, 32, 32),
        (31, 32, 32),
        (32, 32, 32),
        (33, 32, 64),
        (100, 64, 128),
    ];
    for (offset, alignment, expected) in cases {
        assert_eq!(align_offset(offset, alignment), expected);
    }
}
#[test]
fn test_default_alignment_constant() {
    use crate::gguf::DEFAULT_ALIGNMENT;
    // Tensor data offsets are padded to this boundary; the crate default
    // is 32 bytes.
    assert_eq!(DEFAULT_ALIGNMENT, 32);
}
// ============================================================================
// Quantization Format Tests
// ============================================================================
#[test]
fn test_all_quantization_types_defined() {
    // Every supported quantization format must report sane block geometry.
    let all_types = [
        GgufQuantType::F32,
        GgufQuantType::F16,
        GgufQuantType::Q4_0,
        GgufQuantType::Q4_1,
        GgufQuantType::Q5_0,
        GgufQuantType::Q5_1,
        GgufQuantType::Q8_0,
        GgufQuantType::Q8_1,
        GgufQuantType::Q2_K,
        GgufQuantType::Q3_K,
        GgufQuantType::Q4_K,
        GgufQuantType::Q5_K,
        GgufQuantType::Q6_K,
    ];
    for qt in &all_types {
        let (block, bytes) = (qt.block_size(), qt.type_size());
        assert!(block > 0, "{:?} should have positive block size", qt);
        assert!(bytes > 0, "{:?} should have positive type size", qt);
    }
}
#[test]
fn test_quantization_type_display() {
    // Debug formatting should identify the variant.
    let formatted = format!("{:?}", GgufQuantType::Q4_K);
    assert!(formatted.contains("Q4_K") || formatted.contains("4"));
}
#[test]
fn test_k_quant_larger_block_size() {
    // K-quants pack 256-element super-blocks vs the legacy 32, so each
    // block also carries more bytes.
    let legacy = GgufQuantType::Q4_0;
    let k_quant = GgufQuantType::Q4_K;
    assert_eq!(legacy.block_size(), 32);
    assert_eq!(k_quant.block_size(), 256);
    assert!(k_quant.type_size() > legacy.type_size());
}
// ============================================================================
// Model Config Tests
// ============================================================================
#[test]
fn test_model_config_default() {
    use crate::gguf::ModelConfig;
    // A default config carries no information: every field is None.
    let config = ModelConfig::default();
    assert!(config.architecture.is_none());
    assert!(config.context_length.is_none());
    assert!(config.embedding_length.is_none());
    assert!(config.head_count.is_none());
    assert!(config.head_count_kv.is_none());
    assert!(config.layer_count.is_none());
    assert!(config.vocab_size.is_none());
    assert!(config.rope_freq_base.is_none());
    assert!(config.feed_forward_length.is_none());
}
#[test]
fn test_model_config_populated() {
    use crate::gguf::ModelConfig;
    // A Llama-7B-shaped configuration: 32 layers, 32 query heads sharing
    // 8 KV heads, 4096-dim embeddings.
    let config = ModelConfig {
        architecture: Some("llama".to_string()),
        context_length: Some(4096),
        embedding_length: Some(4096),
        head_count: Some(32),
        head_count_kv: Some(8),
        layer_count: Some(32),
        vocab_size: Some(32000),
        rope_freq_base: Some(10000.0),
        feed_forward_length: Some(11008),
    };
    assert_eq!(config.architecture.as_deref(), Some("llama"));
    assert_eq!(config.context_length, Some(4096));
    assert_eq!(config.head_count, Some(32));
    assert_eq!(config.head_count_kv, Some(8));
    // GQA ratio: query heads per KV head (32 / 8 = 4).
    let gqa_ratio = config.head_count.unwrap() / config.head_count_kv.unwrap();
    assert_eq!(gqa_ratio, 4);
}
// ============================================================================
// Integration Tests (Without Real Files)
// ============================================================================
#[test]
fn test_complete_header_metadata_flow() {
    // End-to-end: a header plus one string metadata entry, parsed
    // sequentially from the same cursor.
    let (key, value) = ("general.architecture", "llama");
    let mut data = Vec::new();
    // --- header ---
    data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
    data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
    data.extend_from_slice(&0u64.to_le_bytes()); // tensor count
    data.extend_from_slice(&1u64.to_le_bytes()); // metadata entry count
    // --- metadata: general.architecture = "llama" ---
    data.extend_from_slice(&(key.len() as u64).to_le_bytes());
    data.extend_from_slice(key.as_bytes());
    data.extend_from_slice(&8u32.to_le_bytes()); // String type tag
    data.extend_from_slice(&(value.len() as u64).to_le_bytes());
    data.extend_from_slice(value.as_bytes());
    let mut cursor = Cursor::new(data);
    let header = parse_header(&mut cursor).unwrap();
    assert_eq!(header.magic, GGUF_MAGIC);
    assert_eq!(header.metadata_kv_count, 1);
    let metadata = parse_metadata(&mut cursor, header.metadata_kv_count).unwrap();
    assert_eq!(metadata.get(key).unwrap().as_str(), Some(value));
}
// ============================================================================
// Edge Cases
// ============================================================================
#[test]
fn test_empty_string_value() {
let key = "test.empty";
let value = "";
let mut value_bytes = vec![];
value_bytes.extend_from_slice(&0u64.to_le_bytes()); // length = 0
let data = build_metadata_entry(key, 8, &value_bytes);
let mut cursor = Cursor::new(data);
let metadata = parse_metadata(&mut cursor, 1).unwrap();
assert_eq!(metadata.get(key).unwrap().as_str(), Some(""));
}
#[test]
fn test_zero_tensor_count() {
let mut data = vec![];
data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
data.extend_from_slice(&0u64.to_le_bytes()); // Zero tensors
data.extend_from_slice(&0u64.to_le_bytes()); // Zero metadata
let mut cursor = Cursor::new(data);
let header = parse_header(&mut cursor).unwrap();
assert_eq!(header.tensor_count, 0);
assert_eq!(header.metadata_kv_count, 0);
}
#[test]
fn test_large_tensor_count() {
// Should parse headers with large counts (though reading would require actual data)
let mut data = vec![];
data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
data.extend_from_slice(&1000u64.to_le_bytes()); // 1000 tensors
data.extend_from_slice(&500u64.to_le_bytes()); // 500 metadata entries
let mut cursor = Cursor::new(data);
let header = parse_header(&mut cursor).unwrap();
assert_eq!(header.tensor_count, 1000);
assert_eq!(header.metadata_kv_count, 500);
}

View File

@@ -0,0 +1,19 @@
//! Comprehensive test suite for RuvLLM
//!
//! This module organizes all unit tests for the RuvLLM crate.
mod activation_tests;
mod attention_tests;
mod generation_tests;
mod gguf_tests;
mod witness_log_tests;
// Basic lib configuration tests (moved from lib.rs)
use crate::RuvLLMConfig;
#[test]
fn test_config_default() {
    // Baseline defaults: 1000 concurrent sessions, 768-dim embeddings.
    let config = RuvLLMConfig::default();
    assert_eq!(config.max_sessions, 1000);
    assert_eq!(config.embedding_dim, 768);
}

View File

@@ -0,0 +1,713 @@
//! Witness Log Tests
//!
//! Tests for async write batching, flush on shutdown, backpressure handling,
//! and the overall witness logging system.
use crate::types::ModelSize;
use crate::witness_log::{
AsyncWriteConfig, LatencyBreakdown, RoutingDecision, WitnessEntry, WitnessLog,
};
use std::time::Instant;
// ============================================================================
// LatencyBreakdown Tests
// ============================================================================
#[test]
fn test_latency_breakdown_default() {
    // Every stage timer starts at zero.
    let latency = LatencyBreakdown::default();
    for value in [
        latency.embedding_ms,
        latency.retrieval_ms,
        latency.routing_ms,
        latency.attention_ms,
        latency.generation_ms,
        latency.total_ms,
    ] {
        assert_eq!(value, 0.0);
    }
}
#[test]
fn test_latency_breakdown_compute_total() {
    // compute_total() sums the five stage timings into total_ms.
    let mut latency = LatencyBreakdown {
        embedding_ms: 10.0,
        retrieval_ms: 5.0,
        routing_ms: 2.0,
        attention_ms: 50.0,
        generation_ms: 100.0,
        total_ms: 0.0,
    };
    latency.compute_total();
    assert_eq!(latency.total_ms, 10.0 + 5.0 + 2.0 + 50.0 + 100.0);
}
#[test]
fn test_latency_breakdown_exceeds_threshold() {
    // Total of 167 ms: over a 100 ms budget, within a 200 ms one.
    let latency = LatencyBreakdown {
        embedding_ms: 10.0,
        retrieval_ms: 5.0,
        routing_ms: 2.0,
        attention_ms: 50.0,
        generation_ms: 100.0,
        total_ms: 167.0,
    };
    assert!(latency.exceeds_threshold(100.0));
    assert!(!latency.exceeds_threshold(200.0));
}
#[test]
fn test_latency_breakdown_slowest_component() {
    // Generation (100 ms) dominates this profile.
    let latency = LatencyBreakdown {
        embedding_ms: 10.0,
        retrieval_ms: 5.0,
        routing_ms: 2.0,
        attention_ms: 50.0,
        generation_ms: 100.0,
        total_ms: 167.0,
    };
    let (slowest, ms) = latency.slowest_component();
    assert_eq!(slowest, "generation");
    assert_eq!(ms, 100.0);
}
#[test]
fn test_latency_breakdown_slowest_component_attention() {
    // With attention at 200 ms, it overtakes generation.
    let latency = LatencyBreakdown {
        embedding_ms: 10.0,
        retrieval_ms: 5.0,
        routing_ms: 2.0,
        attention_ms: 200.0,
        generation_ms: 100.0,
        total_ms: 317.0,
    };
    let (slowest, _) = latency.slowest_component();
    assert_eq!(slowest, "attention");
}
#[test]
fn test_latency_breakdown_all_zeros() {
    // Degenerate case: a fresh breakdown still reports a (zero) maximum.
    let (_, ms) = LatencyBreakdown::default().slowest_component();
    assert_eq!(ms, 0.0);
}
// ============================================================================
// RoutingDecision Tests
// ============================================================================
#[test]
fn test_routing_decision_default() {
    // Defaults: smallest model, standard sampling params, uniform prior
    // over the four model sizes.
    let decision = RoutingDecision::default();
    assert_eq!(decision.model, ModelSize::Small);
    assert_eq!(decision.context_size, 0);
    assert!((decision.temperature - 0.7).abs() < 0.01);
    assert!((decision.top_p - 0.9).abs() < 0.01);
    assert!((decision.confidence - 0.5).abs() < 0.01);
    assert_eq!(decision.model_probs, [0.25, 0.25, 0.25, 0.25]);
}
#[test]
fn test_routing_decision_custom() {
    let decision = RoutingDecision {
        model: ModelSize::Large,
        context_size: 4096,
        temperature: 0.3,
        top_p: 0.95,
        confidence: 0.85,
        model_probs: [0.1, 0.1, 0.2, 0.6],
    };
    assert_eq!(decision.model, ModelSize::Large);
    assert_eq!(decision.context_size, 4096);
    assert!((decision.confidence - 0.85).abs() < 0.01);
    // The per-model probability distribution must be normalized.
    let total: f32 = decision.model_probs.iter().sum();
    assert!((total - 1.0).abs() < 0.01);
}
#[test]
fn test_routing_decision_serialization() {
    // serde roundtrip: serialize to JSON, deserialize, compare fields.
    let original = RoutingDecision {
        model: ModelSize::Medium,
        context_size: 2048,
        temperature: 0.5,
        top_p: 0.85,
        confidence: 0.7,
        model_probs: [0.2, 0.3, 0.3, 0.2],
    };
    let json = serde_json::to_string(&original).unwrap();
    assert!(json.contains("context_size"));
    let restored: RoutingDecision = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.context_size, 2048);
    assert!((restored.temperature - 0.5).abs() < 0.01);
}
// ============================================================================
// WitnessEntry Tests
// ============================================================================
#[test]
fn test_witness_entry_new() {
    // A fresh entry gets a non-nil request UUID, keeps its inputs, and
    // starts as a success with no quality score or error recorded.
    let entry = WitnessEntry::new(
        "session-123".to_string(),
        vec![0.1; 768],
        RoutingDecision::default(),
    );
    assert!(!entry.request_id.is_nil());
    assert_eq!(entry.session_id, "session-123");
    assert_eq!(entry.query_embedding.len(), 768);
    assert_eq!(entry.model_used, ModelSize::Small);
    assert_eq!(entry.quality_score, 0.0);
    assert!(entry.is_success());
    assert!(entry.error.is_none());
}
#[test]
fn test_witness_entry_with_quality() {
    // with_quality() records the score; the threshold check is inclusive.
    let entry = WitnessEntry::new(
        "session-456".to_string(),
        vec![0.5; 768],
        RoutingDecision::default(),
    )
    .with_quality(0.85);
    assert!((entry.quality_score - 0.85).abs() < 0.01);
    assert!(entry.meets_quality_threshold(0.8));
    assert!(!entry.meets_quality_threshold(0.9));
}
#[test]
fn test_witness_entry_with_latency() {
    // with_latency() stores the full breakdown untouched.
    let breakdown = LatencyBreakdown {
        embedding_ms: 5.0,
        retrieval_ms: 10.0,
        routing_ms: 1.0,
        attention_ms: 30.0,
        generation_ms: 50.0,
        total_ms: 96.0,
    };
    let entry = WitnessEntry::new(
        "session-789".to_string(),
        vec![0.0; 768],
        RoutingDecision::default(),
    )
    .with_latency(breakdown);
    assert_eq!(entry.latency.total_ms, 96.0);
    assert_eq!(entry.latency.generation_ms, 50.0);
}
#[test]
fn test_witness_entry_with_error() {
    use crate::types::ErrorInfo;
    // Attaching an error flips the entry to a failure.
    let timeout = ErrorInfo {
        code: "TIMEOUT".to_string(),
        message: "Request timed out".to_string(),
        stack_trace: None,
        recovery_attempted: false,
    };
    let entry = WitnessEntry::new(
        "session-error".to_string(),
        vec![0.0; 768],
        RoutingDecision::default(),
    )
    .with_error(timeout);
    assert!(!entry.is_success());
    let stored = entry.error.as_ref().expect("error should be recorded");
    assert_eq!(stored.code, "TIMEOUT");
}
#[test]
fn test_witness_entry_quality_threshold_edge_cases() {
    // The threshold comparison is inclusive at both ends of [0, 1].
    let scored = |score| {
        WitnessEntry::new(
            "session".to_string(),
            vec![0.0; 768],
            RoutingDecision::default(),
        )
        .with_quality(score)
    };
    let lowest = scored(0.0);
    assert!(lowest.meets_quality_threshold(0.0));
    assert!(!lowest.meets_quality_threshold(0.1));
    let highest = scored(1.0);
    assert!(highest.meets_quality_threshold(1.0));
    assert!(highest.meets_quality_threshold(0.99));
}
#[test]
fn test_witness_entry_timestamp() {
    // The creation timestamp falls inside the construction window.
    let lower = chrono::Utc::now();
    let entry = WitnessEntry::new(
        "session".to_string(),
        vec![0.0; 768],
        RoutingDecision::default(),
    );
    let upper = chrono::Utc::now();
    assert!(lower <= entry.timestamp && entry.timestamp <= upper);
}
#[test]
fn test_witness_entry_unique_ids() {
    // Identical inputs must still yield distinct request UUIDs.
    let first = WitnessEntry::new("s1".to_string(), vec![0.0; 768], RoutingDecision::default());
    let second = WitnessEntry::new("s1".to_string(), vec![0.0; 768], RoutingDecision::default());
    assert_ne!(first.request_id, second.request_id);
}
// ============================================================================
// AsyncWriteConfig Tests
// ============================================================================
#[test]
fn test_async_write_config_default() {
    // Defaults: 100-entry batches, 1 s wait/flush cadence, 10k queue cap,
    // fsync disabled.
    let config = AsyncWriteConfig::default();
    assert_eq!(config.max_batch_size, 100);
    assert_eq!(config.max_wait_ms, 1000);
    assert_eq!(config.max_queue_depth, 10000);
    assert!(!config.fsync_critical);
    assert_eq!(config.flush_interval_ms, 1000);
}
#[test]
fn test_async_write_config_custom() {
    // All fields are plain data and keep whatever the caller sets.
    let config = AsyncWriteConfig {
        max_batch_size: 50,
        max_wait_ms: 500,
        max_queue_depth: 5000,
        fsync_critical: true,
        flush_interval_ms: 250,
    };
    assert_eq!(config.max_batch_size, 50);
    assert!(config.fsync_critical);
}
// ============================================================================
// WritebackQueue Behavior Tests (Indirect via WitnessLog)
// ============================================================================
#[test]
fn test_writeback_batching_behavior() {
    // Model the writer loop: flush whenever the batch reaches capacity,
    // leaving any remainder pending.
    const MAX_BATCH_SIZE: usize = 10;
    let mut pending: Vec<WitnessEntry> = Vec::new();
    for i in 0..15 {
        pending.push(WitnessEntry::new(
            format!("session-{}", i),
            vec![i as f32 / 100.0; 768],
            RoutingDecision::default(),
        ));
        if pending.len() >= MAX_BATCH_SIZE {
            assert_eq!(pending.len(), 10);
            pending.clear();
        }
    }
    // 15 entries = one full flush of 10 plus 5 left over.
    assert_eq!(pending.len(), 5);
}
#[test]
fn test_backpressure_behavior() {
    // Once the queue hits its depth limit, further entries are dropped.
    const MAX_QUEUE_DEPTH: usize = 100;
    let mut queued = 0;
    let mut dropped = 0;
    for _ in 0..150 {
        if queued < MAX_QUEUE_DEPTH {
            queued += 1;
        } else {
            dropped += 1;
        }
    }
    // 150 offered = 100 accepted + 50 shed.
    assert_eq!(queued, 100);
    assert_eq!(dropped, 50);
}
#[test]
fn test_time_based_flush_simulation() {
    use std::time::Duration;
    // Verify the flush-decision predicate (elapsed >= max_wait) with pure
    // Duration arithmetic. The previous version used real `sleep` calls,
    // which made the test slow (~110 ms) and flaky: on a loaded host a
    // 50 ms sleep can overshoot past the 100 ms deadline and fail the
    // `elapsed < max_wait` assertion. The logic under test is just time
    // comparison, so no wall clock is needed.
    let max_wait = Duration::from_millis(100);
    // Before the deadline: no flush yet.
    let elapsed = Duration::from_millis(50);
    assert!(elapsed < max_wait, "Not yet time to flush");
    // After another 60 ms the deadline has passed: a flush is due.
    let elapsed = elapsed + Duration::from_millis(60);
    assert!(elapsed >= max_wait, "Should flush by now");
}
// ============================================================================
// WitnessLog Stats Tests
// ============================================================================
#[test]
fn test_witness_log_stats_structure() {
    use crate::witness_log::WitnessLogStats;
    // Success and error counts partition the total.
    let stats = WitnessLogStats {
        total_entries: 1000,
        success_count: 950,
        error_count: 50,
        success_rate: 0.95,
        pending_writes: 25,
        dropped_entries: 0,
        background_running: false,
    };
    assert_eq!(stats.total_entries, 1000);
    assert_eq!(stats.success_count + stats.error_count, stats.total_entries);
    assert!((stats.success_rate - 0.95).abs() < 0.01);
}
#[test]
fn test_witness_log_stats_default() {
    use crate::witness_log::WitnessLogStats;
    // A default snapshot is all-zero with no background writer running.
    let stats = WitnessLogStats::default();
    assert_eq!(stats.total_entries, 0);
    assert_eq!(stats.success_count, 0);
    assert_eq!(stats.error_count, 0);
    assert_eq!(stats.success_rate, 0.0);
    assert_eq!(stats.pending_writes, 0);
    assert_eq!(stats.dropped_entries, 0);
    assert!(!stats.background_running);
}
#[test]
fn test_witness_log_stats_serialization() {
    use crate::witness_log::WitnessLogStats;
    // JSON roundtrip preserves counters and field names.
    let stats = WitnessLogStats {
        total_entries: 100,
        success_count: 95,
        error_count: 5,
        success_rate: 0.95,
        pending_writes: 10,
        dropped_entries: 0,
        background_running: false,
    };
    let json = serde_json::to_string(&stats).unwrap();
    assert!(json.contains("total_entries"));
    assert!(json.contains("success_rate"));
    let restored: WitnessLogStats = serde_json::from_str(&json).unwrap();
    assert_eq!(restored.total_entries, 100);
}
// ============================================================================
// Concurrent Access Simulation Tests
// ============================================================================
#[test]
fn test_concurrent_entry_creation() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    // 10 threads x 100 entries each: the shared counter must see all 1000.
    let created = Arc::new(AtomicUsize::new(0));
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let created = Arc::clone(&created);
            thread::spawn(move || {
                for _ in 0..100 {
                    let _ = WitnessEntry::new(
                        "session".to_string(),
                        vec![0.0; 768],
                        RoutingDecision::default(),
                    );
                    created.fetch_add(1, Ordering::Relaxed);
                }
            })
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }
    assert_eq!(created.load(Ordering::Relaxed), 1000);
}
#[test]
fn test_unique_ids_concurrent() {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};
    use std::thread;
    // Request IDs must stay unique even when generated from many threads.
    let seen = Arc::new(Mutex::new(HashSet::new()));
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let seen = Arc::clone(&seen);
            thread::spawn(move || {
                for _ in 0..100 {
                    let entry = WitnessEntry::new(
                        "session".to_string(),
                        vec![0.0; 768],
                        RoutingDecision::default(),
                    );
                    seen.lock().unwrap().insert(entry.request_id);
                }
            })
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }
    assert_eq!(seen.lock().unwrap().len(), 1000, "All IDs should be unique");
}
// ============================================================================
// Error Handling Tests
// ============================================================================
#[test]
fn test_witness_entry_error_chain() {
    use crate::types::ErrorInfo;
    // All builder methods compose on one entry, and an attached error
    // marks it as failed even though quality and latency are set.
    let breakdown = LatencyBreakdown {
        embedding_ms: 10.0,
        retrieval_ms: 5.0,
        routing_ms: 2.0,
        attention_ms: 30.0,
        generation_ms: 50.0,
        total_ms: 97.0,
    };
    let failure = ErrorInfo {
        code: "GEN_FAILED".to_string(),
        message: "Generation failed".to_string(),
        stack_trace: None,
        recovery_attempted: false,
    };
    let entry = WitnessEntry::new(
        "session".to_string(),
        vec![0.0; 768],
        RoutingDecision::default(),
    )
    .with_quality(0.5)
    .with_latency(breakdown)
    .with_error(failure);
    assert!((entry.quality_score - 0.5).abs() < 0.01);
    assert_eq!(entry.latency.total_ms, 97.0);
    assert!(!entry.is_success());
    assert_eq!(entry.error.as_ref().unwrap().code, "GEN_FAILED");
}
// ============================================================================
// Tag Filtering Tests
// ============================================================================
#[test]
fn test_witness_entry_tags() {
    // Tags are an open list of strings callers may append to.
    let mut entry = WitnessEntry::new(
        "session".to_string(),
        vec![0.0; 768],
        RoutingDecision::default(),
    );
    for tag in ["production", "high-priority", "api-v2"] {
        entry.tags.push(tag.to_string());
    }
    assert_eq!(entry.tags.len(), 3);
    assert!(entry.tags.iter().any(|t| t == "production"));
}
#[test]
fn test_witness_entry_filter_by_tag() {
    // Tag-based filtering over a batch of entries.
    let entries: Vec<WitnessEntry> = (0..10)
        .map(|i| {
            let mut entry = WitnessEntry::new(
                format!("session-{}", i),
                vec![0.0; 768],
                RoutingDecision::default(),
            );
            let parity = if i % 2 == 0 { "even" } else { "odd" };
            entry.tags.push(parity.to_string());
            entry
        })
        .collect();
    let even_count = entries
        .iter()
        .filter(|e| e.tags.iter().any(|t| t == "even"))
        .count();
    assert_eq!(even_count, 5);
}
// ============================================================================
// Performance Measurement Tests
// ============================================================================
#[test]
fn test_entry_creation_performance() {
    // Measure the average cost of constructing a WitnessEntry.
    // `black_box` keeps the optimizer from eliding the otherwise-unused
    // entry in release builds, which would make the timing meaningless.
    let iterations = 10000;
    let start = Instant::now();
    for _ in 0..iterations {
        let entry = WitnessEntry::new(
            "session".to_string(),
            vec![0.0; 768],
            RoutingDecision::default(),
        );
        std::hint::black_box(entry);
    }
    let duration = start.elapsed();
    let avg_us = duration.as_micros() as f64 / iterations as f64;
    assert!(
        avg_us < 100.0,
        "Entry creation should be fast: {}us",
        avg_us
    );
}
#[test]
fn test_latency_breakdown_performance() {
    // Time compute_total() + slowest_component() over many iterations.
    // The results are fed through `black_box` so the optimizer cannot
    // dead-code-eliminate the work being measured in release builds.
    let iterations = 100000;
    let start = Instant::now();
    for _ in 0..iterations {
        let mut latency = LatencyBreakdown {
            embedding_ms: 10.0,
            retrieval_ms: 5.0,
            routing_ms: 2.0,
            attention_ms: 50.0,
            generation_ms: 100.0,
            total_ms: 0.0,
        };
        latency.compute_total();
        std::hint::black_box(latency.slowest_component());
        std::hint::black_box(latency.total_ms);
    }
    let duration = start.elapsed();
    let avg_ns = duration.as_nanos() as f64 / iterations as f64;
    assert!(
        avg_ns < 1000.0,
        "Latency operations should be fast: {}ns",
        avg_ns
    );
}
// ============================================================================
// Edge Cases
// ============================================================================
#[test]
fn test_empty_embedding() {
    // A zero-length embedding is accepted by the constructor and stored as-is.
    let embedding: Vec<f32> = Vec::new();
    let entry = WitnessEntry::new("session".to_string(), embedding, RoutingDecision::default());
    assert!(entry.query_embedding.is_empty());
}
#[test]
fn test_large_embedding() {
    // A 4K-dimension embedding is stored without truncation. Ownership is
    // moved into the entry directly — the previous `.clone()` was a
    // redundant full copy of 4096 floats (clippy: redundant_clone).
    let large_embedding = vec![0.1; 4096];
    let entry = WitnessEntry::new(
        "session".to_string(),
        large_embedding,
        RoutingDecision::default(),
    );
    assert_eq!(entry.query_embedding.len(), 4096);
}
#[test]
fn test_empty_session_id() {
    // The constructor tolerates an empty session identifier.
    let entry = WitnessEntry::new(String::new(), vec![0.0; 768], RoutingDecision::default());
    assert!(entry.session_id.is_empty());
}
#[test]
fn test_long_session_id() {
    // A 1000-character session id is stored untruncated. `long_id` is moved
    // in directly — the previous `.clone()` was a redundant allocation
    // (clippy: redundant_clone), since the original was never used again.
    let long_id = "x".repeat(1000);
    let entry = WitnessEntry::new(long_id, vec![0.0; 768], RoutingDecision::default());
    assert_eq!(entry.session_id.len(), 1000);
}
#[test]
fn test_extreme_latency_values() {
    // A very large (but representable) component value must remain finite
    // when stored in the breakdown struct.
    let huge = f32::MAX / 10.0;
    let latency = LatencyBreakdown {
        embedding_ms: huge,
        retrieval_ms: 0.0,
        routing_ms: 0.0,
        attention_ms: 0.0,
        generation_ms: 0.0,
        total_ms: 0.0,
    };
    assert!(latency.embedding_ms.is_finite());
}
#[test]
fn test_zero_confidence_routing() {
    // Confidence at the lower bound (0.0) round-trips through the struct.
    let decision = RoutingDecision {
        confidence: 0.0,
        model: ModelSize::Tiny,
        ..Default::default()
    };
    assert_eq!(decision.confidence, 0.0);
}
#[test]
fn test_max_confidence_routing() {
    // Confidence at the upper bound (1.0) round-trips through the struct.
    let decision = RoutingDecision {
        confidence: 1.0,
        model: ModelSize::Large,
        ..Default::default()
    };
    assert_eq!(decision.confidence, 1.0);
}