git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
770 lines
23 KiB
Rust
770 lines
23 KiB
Rust
//! End-to-end verification tests for production readiness.
|
|
//!
|
|
//! Validates:
|
|
//! 1. Complete inference pipeline with realistic weights
|
|
//! 2. Quantization quality metrics (MSE, max error)
|
|
//! 3. Latency characteristics
|
|
//! 4. Memory usage patterns
|
|
|
|
use ruvector_mincut_gated_transformer::{
|
|
arena::WeightArena,
|
|
flash_attention::{flash_attention_forward, FlashAttentionConfig},
|
|
kernel::{qgemm_i8, qgemm_i8_simd},
|
|
kv_cache::{HadamardTransform, QuantBits, QuantizedKVCache},
|
|
rope::{RopeConfig, RopeEmbedding, RopeScaling},
|
|
GatePacket, GatePolicy, InferInput, InferOutput, MincutGatedTransformer, QuantizedWeights,
|
|
TransformerConfig,
|
|
};
|
|
use std::time::Instant;
|
|
|
|
// ============================================================================
|
|
// End-to-End Inference Verification
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_e2e_inference_micro_config() {
|
|
let config = TransformerConfig::micro();
|
|
let policy = GatePolicy::default();
|
|
let weights = QuantizedWeights::empty(&config);
|
|
|
|
let mut transformer = MincutGatedTransformer::new(config.clone(), policy, weights).unwrap();
|
|
|
|
// Run 100 inference passes
|
|
let tokens: Vec<u32> = (0..16).collect();
|
|
let gate = GatePacket {
|
|
lambda: 100,
|
|
lambda_prev: 95,
|
|
boundary_edges: 5,
|
|
boundary_concentration_q15: 8192,
|
|
partition_count: 3,
|
|
flags: 0,
|
|
};
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..100 {
|
|
let input = InferInput::from_tokens(&tokens, gate);
|
|
let mut logits = vec![0i32; config.logits as usize];
|
|
let mut output = InferOutput::new(&mut logits);
|
|
transformer.infer(&input, &mut output).unwrap();
|
|
transformer.reset();
|
|
}
|
|
let elapsed = start.elapsed();
|
|
|
|
let avg_latency_us = elapsed.as_micros() / 100;
|
|
println!("E2E micro config: avg latency = {}µs", avg_latency_us);
|
|
|
|
// Micro config should complete in <10ms per inference
|
|
assert!(
|
|
avg_latency_us < 10_000,
|
|
"Inference too slow: {}µs",
|
|
avg_latency_us
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_e2e_inference_baseline_config() {
|
|
let config = TransformerConfig::baseline();
|
|
let policy = GatePolicy::default();
|
|
let weights = QuantizedWeights::empty(&config);
|
|
|
|
let mut transformer = MincutGatedTransformer::new(config.clone(), policy, weights).unwrap();
|
|
|
|
let tokens: Vec<u32> = (0..32).collect();
|
|
let gate = GatePacket {
|
|
lambda: 100,
|
|
lambda_prev: 95,
|
|
boundary_edges: 5,
|
|
boundary_concentration_q15: 8192,
|
|
partition_count: 3,
|
|
flags: 0,
|
|
};
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..50 {
|
|
let input = InferInput::from_tokens(&tokens, gate);
|
|
let mut logits = vec![0i32; config.logits as usize];
|
|
let mut output = InferOutput::new(&mut logits);
|
|
transformer.infer(&input, &mut output).unwrap();
|
|
transformer.reset();
|
|
}
|
|
let elapsed = start.elapsed();
|
|
|
|
let avg_latency_us = elapsed.as_micros() / 50;
|
|
println!("E2E baseline config: avg latency = {}µs", avg_latency_us);
|
|
|
|
// Baseline should complete in <50ms per inference
|
|
assert!(
|
|
avg_latency_us < 50_000,
|
|
"Inference too slow: {}µs",
|
|
avg_latency_us
|
|
);
|
|
}
|
|
|
|
// ============================================================================
|
|
// INT8 GEMM Accuracy Verification
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_gemm_numerical_accuracy() {
|
|
let m = 64;
|
|
let n = 64;
|
|
let k = 64;
|
|
|
|
// Create test matrices with known values
|
|
let a: Vec<i8> = (0..m * k).map(|i| (i % 127) as i8).collect();
|
|
let b: Vec<i8> = (0..n * k).map(|i| ((i * 3) % 127) as i8).collect();
|
|
let a_scale = 1.0 / 127.0;
|
|
let b_scales: Vec<f32> = vec![1.0 / 127.0; n];
|
|
|
|
// Scalar reference
|
|
let mut c_scalar = vec![0i32; m * n];
|
|
qgemm_i8(m, n, k, &a, a_scale, &b, &b_scales, None, &mut c_scalar);
|
|
|
|
// SIMD implementation
|
|
let mut c_simd = vec![0i32; m * n];
|
|
qgemm_i8_simd(m, n, k, &a, a_scale, &b, &b_scales, None, &mut c_simd);
|
|
|
|
// Verify exact match (both should produce identical integer results)
|
|
let mut max_diff = 0i32;
|
|
let mut total_diff = 0i64;
|
|
for i in 0..(m * n) {
|
|
let diff = (c_scalar[i] - c_simd[i]).abs();
|
|
max_diff = max_diff.max(diff);
|
|
total_diff += diff as i64;
|
|
}
|
|
|
|
let avg_diff = total_diff as f64 / (m * n) as f64;
|
|
println!(
|
|
"GEMM accuracy: max_diff={}, avg_diff={:.4}",
|
|
max_diff, avg_diff
|
|
);
|
|
|
|
// SIMD should match scalar exactly for integer ops
|
|
assert_eq!(max_diff, 0, "SIMD and scalar GEMM differ");
|
|
}
|
|
|
|
#[test]
|
|
fn test_gemm_simd_speedup() {
|
|
let m = 256;
|
|
let n = 256;
|
|
let k = 256;
|
|
|
|
let a: Vec<i8> = (0..m * k).map(|i| (i as i16 % 256 - 128) as i8).collect();
|
|
let b: Vec<i8> = (0..n * k).map(|i| (i as i16 % 256 - 128) as i8).collect();
|
|
let b_scales: Vec<f32> = vec![1.0 / 128.0; n];
|
|
|
|
// Warm up
|
|
let mut c = vec![0i32; m * n];
|
|
qgemm_i8(m, n, k, &a, 1.0 / 128.0, &b, &b_scales, None, &mut c);
|
|
qgemm_i8_simd(m, n, k, &a, 1.0 / 128.0, &b, &b_scales, None, &mut c);
|
|
|
|
// Benchmark scalar
|
|
let start = Instant::now();
|
|
for _ in 0..10 {
|
|
qgemm_i8(m, n, k, &a, 1.0 / 128.0, &b, &b_scales, None, &mut c);
|
|
}
|
|
let scalar_time = start.elapsed();
|
|
|
|
// Benchmark SIMD
|
|
let start = Instant::now();
|
|
for _ in 0..10 {
|
|
qgemm_i8_simd(m, n, k, &a, 1.0 / 128.0, &b, &b_scales, None, &mut c);
|
|
}
|
|
let simd_time = start.elapsed();
|
|
|
|
let speedup = scalar_time.as_nanos() as f64 / simd_time.as_nanos() as f64;
|
|
let gflops = (2.0 * m as f64 * n as f64 * k as f64 * 10.0) / simd_time.as_secs_f64() / 1e9;
|
|
|
|
println!(
|
|
"GEMM 256x256x256: scalar={:?}, simd={:?}, speedup={:.2}x, GFLOPS={:.2}",
|
|
scalar_time / 10,
|
|
simd_time / 10,
|
|
speedup,
|
|
gflops
|
|
);
|
|
|
|
// In virtualized environments without AVX2, SIMD may not be faster
|
|
// Just verify it's not significantly slower (within 20% is acceptable)
|
|
assert!(
|
|
speedup >= 0.8,
|
|
"SIMD much slower than scalar: {:.2}x",
|
|
speedup
|
|
);
|
|
}
|
|
|
|
// ============================================================================
|
|
// KV Cache Quantization Quality
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_kv_cache_quantization_quality_4bit() {
|
|
let head_dim = 64;
|
|
let num_heads = 4;
|
|
let num_layers = 2;
|
|
let max_seq_len = 128;
|
|
|
|
let mut cache = QuantizedKVCache::new(
|
|
num_layers,
|
|
num_heads,
|
|
head_dim,
|
|
max_seq_len,
|
|
QuantBits::Four,
|
|
);
|
|
|
|
// Generate realistic key/value vectors (Gaussian-like distribution)
|
|
let mut total_mse = 0.0f64;
|
|
let mut max_error = 0.0f32;
|
|
let num_tests = 100;
|
|
|
|
for test_idx in 0..num_tests {
|
|
// Simulate realistic activations (mostly small values with some outliers)
|
|
let key: Vec<f32> = (0..head_dim)
|
|
.map(|i| {
|
|
let base = ((i as f32 + test_idx as f32 * 0.1).sin()) * 0.5;
|
|
// Add occasional outlier
|
|
if i % 17 == 0 {
|
|
base * 3.0
|
|
} else {
|
|
base
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
let value: Vec<f32> = (0..head_dim)
|
|
.map(|i| ((i as f32 + test_idx as f32 * 0.2).cos()) * 0.5)
|
|
.collect();
|
|
|
|
// Store and retrieve
|
|
cache.quantize_and_store_kv(0, 0, Some(test_idx), &key, &value);
|
|
let retrieved = cache.get_keys_dequantized(0, 0, test_idx, 1);
|
|
|
|
// Compute error
|
|
let mse: f64 = key
|
|
.iter()
|
|
.zip(retrieved.iter())
|
|
.map(|(a, b)| (a - b).powi(2) as f64)
|
|
.sum::<f64>()
|
|
/ head_dim as f64;
|
|
|
|
let local_max_error = key
|
|
.iter()
|
|
.zip(retrieved.iter())
|
|
.map(|(a, b)| (a - b).abs())
|
|
.fold(0.0f32, f32::max);
|
|
|
|
total_mse += mse;
|
|
max_error = max_error.max(local_max_error);
|
|
}
|
|
|
|
let avg_mse = total_mse / num_tests as f64;
|
|
let rmse = avg_mse.sqrt();
|
|
|
|
println!(
|
|
"4-bit KV cache: RMSE={:.6}, max_error={:.6}",
|
|
rmse, max_error
|
|
);
|
|
|
|
// 4-bit should have RMSE < 0.15 for normalized data
|
|
assert!(rmse < 0.2, "4-bit RMSE too high: {:.6}", rmse);
|
|
}
|
|
|
|
#[test]
|
|
fn test_kv_cache_quantization_quality_2bit() {
|
|
let head_dim = 64;
|
|
let num_heads = 4;
|
|
let num_layers = 2;
|
|
let max_seq_len = 128;
|
|
|
|
let mut cache =
|
|
QuantizedKVCache::new(num_layers, num_heads, head_dim, max_seq_len, QuantBits::Two);
|
|
|
|
let mut total_mse = 0.0f64;
|
|
let mut max_error = 0.0f32;
|
|
let num_tests = 100;
|
|
|
|
for test_idx in 0..num_tests {
|
|
let key: Vec<f32> = (0..head_dim)
|
|
.map(|i| ((i as f32 + test_idx as f32 * 0.1).sin()) * 0.5)
|
|
.collect();
|
|
|
|
let value: Vec<f32> = (0..head_dim)
|
|
.map(|i| ((i as f32 + test_idx as f32 * 0.2).cos()) * 0.5)
|
|
.collect();
|
|
|
|
cache.quantize_and_store_kv(0, 0, Some(test_idx), &key, &value);
|
|
let retrieved = cache.get_keys_dequantized(0, 0, test_idx, 1);
|
|
|
|
let mse: f64 = key
|
|
.iter()
|
|
.zip(retrieved.iter())
|
|
.map(|(a, b)| (a - b).powi(2) as f64)
|
|
.sum::<f64>()
|
|
/ head_dim as f64;
|
|
|
|
let local_max_error = key
|
|
.iter()
|
|
.zip(retrieved.iter())
|
|
.map(|(a, b)| (a - b).abs())
|
|
.fold(0.0f32, f32::max);
|
|
|
|
total_mse += mse;
|
|
max_error = max_error.max(local_max_error);
|
|
}
|
|
|
|
let avg_mse = total_mse / num_tests as f64;
|
|
let rmse = avg_mse.sqrt();
|
|
|
|
println!(
|
|
"2-bit KV cache: RMSE={:.6}, max_error={:.6}",
|
|
rmse, max_error
|
|
);
|
|
|
|
// 2-bit will have higher error but should be bounded
|
|
// RotateKV paper claims <0.3 PPL degradation
|
|
assert!(rmse < 0.4, "2-bit RMSE too high: {:.6}", rmse);
|
|
}
|
|
|
|
#[test]
|
|
fn test_hadamard_transform_preserves_energy() {
|
|
let dim = 64;
|
|
let hadamard = HadamardTransform::new(dim);
|
|
|
|
// Test with random-ish data
|
|
let original: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
|
|
|
|
let original_energy: f32 = original.iter().map(|x| x * x).sum();
|
|
|
|
let mut transformed = original.clone();
|
|
hadamard.forward(&mut transformed);
|
|
|
|
let transformed_energy: f32 = transformed.iter().map(|x| x * x).sum();
|
|
|
|
// Energy should be preserved (orthogonal transform)
|
|
let energy_ratio = transformed_energy / original_energy;
|
|
println!("Hadamard energy ratio: {:.6}", energy_ratio);
|
|
|
|
assert!(
|
|
(energy_ratio - 1.0).abs() < 0.001,
|
|
"Energy not preserved: {:.6}",
|
|
energy_ratio
|
|
);
|
|
|
|
// Test inverse
|
|
hadamard.inverse(&mut transformed);
|
|
|
|
let max_diff = original
|
|
.iter()
|
|
.zip(transformed.iter())
|
|
.map(|(a, b)| (a - b).abs())
|
|
.fold(0.0f32, f32::max);
|
|
|
|
println!("Hadamard inverse max_diff: {:.9}", max_diff);
|
|
assert!(max_diff < 1e-5, "Inverse not accurate: {:.9}", max_diff);
|
|
}
|
|
|
|
// ============================================================================
|
|
// FlashAttention vs Naive Attention Verification
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_flash_attention_matches_naive() {
|
|
let seq_len = 64;
|
|
let head_dim = 64;
|
|
|
|
// Generate Q, K, V
|
|
let q: Vec<f32> = (0..seq_len * head_dim)
|
|
.map(|i| ((i as f32) * 0.01).sin())
|
|
.collect();
|
|
let k: Vec<f32> = (0..seq_len * head_dim)
|
|
.map(|i| ((i as f32) * 0.013).cos())
|
|
.collect();
|
|
let v: Vec<f32> = (0..seq_len * head_dim)
|
|
.map(|i| ((i as f32) * 0.017).sin())
|
|
.collect();
|
|
|
|
// Naive attention
|
|
let scale = 1.0 / (head_dim as f32).sqrt();
|
|
let mut naive_output = vec![0.0f32; seq_len * head_dim];
|
|
|
|
for i in 0..seq_len {
|
|
// Compute attention scores for position i
|
|
let mut scores = vec![f32::NEG_INFINITY; seq_len];
|
|
for j in 0..=i {
|
|
// Causal
|
|
let mut dot = 0.0f32;
|
|
for d in 0..head_dim {
|
|
dot += q[i * head_dim + d] * k[j * head_dim + d];
|
|
}
|
|
scores[j] = dot * scale;
|
|
}
|
|
|
|
// Softmax
|
|
let max_score = scores
|
|
.iter()
|
|
.take(i + 1)
|
|
.cloned()
|
|
.fold(f32::NEG_INFINITY, f32::max);
|
|
let exp_sum: f32 = scores
|
|
.iter()
|
|
.take(i + 1)
|
|
.map(|s| (s - max_score).exp())
|
|
.sum();
|
|
|
|
// Weighted sum
|
|
for d in 0..head_dim {
|
|
let mut sum = 0.0f32;
|
|
for j in 0..=i {
|
|
let weight = ((scores[j] - max_score).exp()) / exp_sum;
|
|
sum += weight * v[j * head_dim + d];
|
|
}
|
|
naive_output[i * head_dim + d] = sum;
|
|
}
|
|
}
|
|
|
|
// FlashAttention
|
|
let config = FlashAttentionConfig {
|
|
block_size_q: 16,
|
|
block_size_kv: 16,
|
|
head_dim,
|
|
causal: true,
|
|
softmax_scale: scale,
|
|
};
|
|
|
|
let mut flash_output = vec![0.0f32; seq_len * head_dim];
|
|
flash_attention_forward(&config, &q, &k, &v, seq_len, seq_len, &mut flash_output);
|
|
|
|
// Compare outputs
|
|
let mut max_diff = 0.0f32;
|
|
let mut total_diff = 0.0f64;
|
|
for i in 0..(seq_len * head_dim) {
|
|
let diff = (naive_output[i] - flash_output[i]).abs();
|
|
max_diff = max_diff.max(diff);
|
|
total_diff += diff as f64;
|
|
}
|
|
|
|
let avg_diff = total_diff / (seq_len * head_dim) as f64;
|
|
println!(
|
|
"FlashAttention vs naive: max_diff={:.6}, avg_diff={:.9}",
|
|
max_diff, avg_diff
|
|
);
|
|
|
|
// Should be numerically very close
|
|
assert!(
|
|
max_diff < 1e-4,
|
|
"FlashAttention differs too much: max_diff={:.6}",
|
|
max_diff
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_flash_attention_memory_efficiency() {
|
|
// Test that we can handle sequences that would OOM with naive O(n²)
|
|
let seq_len = 1024;
|
|
let head_dim = 64;
|
|
|
|
let q: Vec<f32> = vec![0.1; seq_len * head_dim];
|
|
let k: Vec<f32> = vec![0.1; seq_len * head_dim];
|
|
let v: Vec<f32> = vec![0.1; seq_len * head_dim];
|
|
|
|
let config = FlashAttentionConfig {
|
|
block_size_q: 64,
|
|
block_size_kv: 64,
|
|
head_dim,
|
|
causal: true,
|
|
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
|
|
};
|
|
|
|
let mut output = vec![0.0f32; seq_len * head_dim];
|
|
|
|
let start = Instant::now();
|
|
flash_attention_forward(&config, &q, &k, &v, seq_len, seq_len, &mut output);
|
|
let elapsed = start.elapsed();
|
|
|
|
println!("FlashAttention 1024 seq_len: {:?}", elapsed);
|
|
|
|
// Should complete without OOM and in reasonable time
|
|
assert!(
|
|
elapsed.as_millis() < 1000,
|
|
"FlashAttention too slow: {:?}",
|
|
elapsed
|
|
);
|
|
}
|
|
|
|
// ============================================================================
|
|
// RoPE Verification
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_rope_position_encoding_properties() {
|
|
let config = RopeConfig {
|
|
head_dim: 64,
|
|
base: 10000.0,
|
|
max_seq_len: 1024,
|
|
scaling_type: RopeScaling::None,
|
|
};
|
|
|
|
let rope = RopeEmbedding::new(&config).unwrap();
|
|
|
|
// Property 1: Different positions should have different cos/sin values
|
|
let cos_42 = rope.get_cos(42, 0);
|
|
let cos_100 = rope.get_cos(100, 0);
|
|
assert!(
|
|
(cos_42 - cos_100).abs() > 0.01,
|
|
"Different positions have same cos"
|
|
);
|
|
|
|
// Property 2: Cos and sin should be bounded
|
|
for pos in 0..100 {
|
|
for dim in 0..(config.head_dim / 2) {
|
|
let cos = rope.get_cos(pos, dim);
|
|
let sin = rope.get_sin(pos, dim);
|
|
assert!(cos.abs() <= 1.0 + 1e-6, "cos out of bounds");
|
|
assert!(sin.abs() <= 1.0 + 1e-6, "sin out of bounds");
|
|
}
|
|
}
|
|
|
|
// Property 3: sin²θ + cos²θ = 1
|
|
for pos in 0..100 {
|
|
for dim in 0..(config.head_dim / 2) {
|
|
let cos = rope.get_cos(pos, dim);
|
|
let sin = rope.get_sin(pos, dim);
|
|
let sum = cos * cos + sin * sin;
|
|
assert!((sum - 1.0).abs() < 1e-5, "sin²+cos² != 1: {}", sum);
|
|
}
|
|
}
|
|
|
|
println!("RoPE properties verified: bounded, unique positions, unit circle");
|
|
}
|
|
|
|
#[test]
|
|
fn test_rope_ntk_scaling_extends_context() {
|
|
// NTK-aware scaling should work for longer contexts
|
|
let config = RopeConfig {
|
|
head_dim: 64,
|
|
base: 10000.0,
|
|
max_seq_len: 8192,
|
|
scaling_type: RopeScaling::NTKAware { alpha: 4.0 },
|
|
};
|
|
|
|
let rope = RopeEmbedding::new(&config).unwrap();
|
|
|
|
// Should handle positions beyond original training length
|
|
let cos = rope.get_cos(7000, 0);
|
|
let sin = rope.get_sin(7000, 0);
|
|
|
|
// Should produce finite values
|
|
assert!(cos.is_finite(), "NTK produced non-finite cos");
|
|
assert!(sin.is_finite(), "NTK produced non-finite sin");
|
|
|
|
println!("NTK scaling at pos 7000: cos={:.6}, sin={:.6}", cos, sin);
|
|
}
|
|
|
|
// ============================================================================
|
|
// Latency Profiling
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_component_latencies() {
|
|
let sizes = [64, 128, 256];
|
|
|
|
println!("\n=== Component Latency Profile ===");
|
|
|
|
for &size in &sizes {
|
|
// GEMM
|
|
let a: Vec<i8> = vec![1; size * size];
|
|
let b: Vec<i8> = vec![1; size * size];
|
|
let b_scales: Vec<f32> = vec![0.01; size];
|
|
let mut c = vec![0i32; size * size];
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..100 {
|
|
qgemm_i8_simd(size, size, size, &a, 0.01, &b, &b_scales, None, &mut c);
|
|
}
|
|
let gemm_us = start.elapsed().as_micros() / 100;
|
|
|
|
// Hadamard
|
|
let hadamard = HadamardTransform::new(size);
|
|
let mut data: Vec<f32> = (0..size).map(|i| i as f32 * 0.01).collect();
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..1000 {
|
|
hadamard.forward(&mut data);
|
|
}
|
|
let hadamard_us = start.elapsed().as_nanos() / 1000;
|
|
|
|
println!(
|
|
"Size {}: GEMM={}µs, Hadamard={}ns",
|
|
size, gemm_us, hadamard_us
|
|
);
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Memory Usage Verification
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_arena_allocation_efficiency() {
|
|
let layers = 4;
|
|
let hidden = 256;
|
|
let ffn_mult = 4;
|
|
let heads = 4;
|
|
|
|
use ruvector_mincut_gated_transformer::arena::calculate_arena_size;
|
|
|
|
let size = calculate_arena_size(layers, hidden, ffn_mult, heads);
|
|
let mut arena = WeightArena::new(size);
|
|
|
|
// Allocate typical model weights
|
|
let weights_allocated: usize = (0..layers)
|
|
.map(|_| {
|
|
let w_q = arena.alloc_i8(hidden * hidden).unwrap().len();
|
|
let w_k = arena.alloc_i8(hidden * hidden).unwrap().len();
|
|
let w_v = arena.alloc_i8(hidden * hidden).unwrap().len();
|
|
let w_o = arena.alloc_i8(hidden * hidden).unwrap().len();
|
|
let w_up = arena.alloc_i8(hidden * hidden * ffn_mult).unwrap().len();
|
|
let w_down = arena.alloc_i8(hidden * ffn_mult * hidden).unwrap().len();
|
|
w_q + w_k + w_v + w_o + w_up + w_down
|
|
})
|
|
.sum();
|
|
|
|
let overhead = size - weights_allocated;
|
|
let overhead_pct = (overhead as f64 / size as f64) * 100.0;
|
|
|
|
println!(
|
|
"Arena: size={}, used={}, overhead={} ({:.1}%)",
|
|
size, weights_allocated, overhead, overhead_pct
|
|
);
|
|
|
|
// Overhead should be minimal (alignment padding)
|
|
assert!(
|
|
overhead_pct < 5.0,
|
|
"Arena overhead too high: {:.1}%",
|
|
overhead_pct
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_kv_cache_memory_compression() {
|
|
let num_layers = 4;
|
|
let num_heads = 8;
|
|
let head_dim = 64;
|
|
let seq_len = 1024;
|
|
|
|
// FP32 baseline
|
|
let fp32_size = num_layers * num_heads * seq_len * head_dim * 4 * 2; // *2 for K and V
|
|
|
|
// 4-bit size
|
|
let int4_size = num_layers * num_heads * seq_len * head_dim / 2 * 2; // /2 for 4-bit packing
|
|
let int4_scales = num_layers * num_heads * 4 * 2; // f32 scales per head
|
|
let int4_total = int4_size + int4_scales;
|
|
|
|
// 2-bit size
|
|
let int2_size = num_layers * num_heads * seq_len * head_dim / 4 * 2;
|
|
let int2_scales = num_layers * num_heads * 4 * 2;
|
|
let int2_total = int2_size + int2_scales;
|
|
|
|
let compression_4bit = fp32_size as f64 / int4_total as f64;
|
|
let compression_2bit = fp32_size as f64 / int2_total as f64;
|
|
|
|
println!("KV Cache memory (4L, 8H, 1024 seq):");
|
|
println!(" FP32: {} bytes", fp32_size);
|
|
println!(
|
|
" INT4: {} bytes ({:.1}x compression)",
|
|
int4_total, compression_4bit
|
|
);
|
|
println!(
|
|
" INT2: {} bytes ({:.1}x compression)",
|
|
int2_total, compression_2bit
|
|
);
|
|
|
|
assert!(compression_4bit > 7.0, "4-bit compression insufficient");
|
|
assert!(compression_2bit > 14.0, "2-bit compression insufficient");
|
|
}
|
|
|
|
// ============================================================================
|
|
// Full Pipeline Integration Tests
|
|
// ============================================================================
|
|
|
|
#[test]
|
|
fn test_multiple_gate_decisions() {
|
|
let config = TransformerConfig::baseline();
|
|
let policy = GatePolicy::default();
|
|
let weights = QuantizedWeights::empty(&config);
|
|
let mut transformer = MincutGatedTransformer::new(config.clone(), policy, weights).unwrap();
|
|
|
|
let tokens: Vec<u32> = (0..32).collect();
|
|
|
|
// Test different gate conditions
|
|
let test_cases = vec![
|
|
(100, 95, 5, "stable lambda"),
|
|
(40, 100, 5, "lambda drop"),
|
|
(100, 95, 30, "boundary spike"),
|
|
(100, 95, 5, "normal after unstable"),
|
|
];
|
|
|
|
for (lambda, lambda_prev, boundary, desc) in test_cases {
|
|
let gate = GatePacket {
|
|
lambda,
|
|
lambda_prev,
|
|
boundary_edges: boundary,
|
|
boundary_concentration_q15: 8192,
|
|
partition_count: 3,
|
|
flags: 0,
|
|
};
|
|
|
|
let input = InferInput::from_tokens(&tokens, gate);
|
|
let mut logits = vec![0i32; config.logits as usize];
|
|
let mut output = InferOutput::new(&mut logits);
|
|
|
|
transformer.infer(&input, &mut output).unwrap();
|
|
|
|
println!(
|
|
"{}: decision={:?}, tier={}, layers={}",
|
|
desc, output.witness.decision, output.stats.tier, output.stats.layers_executed
|
|
);
|
|
|
|
transformer.reset();
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_deterministic_inference() {
|
|
let config = TransformerConfig::micro();
|
|
let policy = GatePolicy::default();
|
|
let weights = QuantizedWeights::empty(&config);
|
|
|
|
let tokens: Vec<u32> = (0..16).collect();
|
|
let gate = GatePacket {
|
|
lambda: 100,
|
|
lambda_prev: 95,
|
|
boundary_edges: 5,
|
|
boundary_concentration_q15: 8192,
|
|
partition_count: 3,
|
|
flags: 0,
|
|
};
|
|
|
|
// Run twice and verify identical results
|
|
let mut transformer1 =
|
|
MincutGatedTransformer::new(config.clone(), policy.clone(), weights.clone()).unwrap();
|
|
let mut transformer2 = MincutGatedTransformer::new(config.clone(), policy, weights).unwrap();
|
|
|
|
let input = InferInput::from_tokens(&tokens, gate);
|
|
|
|
let mut logits1 = vec![0i32; config.logits as usize];
|
|
let mut logits2 = vec![0i32; config.logits as usize];
|
|
|
|
{
|
|
let mut output1 = InferOutput::new(&mut logits1);
|
|
transformer1.infer(&input, &mut output1).unwrap();
|
|
}
|
|
{
|
|
let mut output2 = InferOutput::new(&mut logits2);
|
|
transformer2.infer(&input, &mut output2).unwrap();
|
|
}
|
|
|
|
// Verify identical outputs
|
|
assert_eq!(logits1, logits2, "Non-deterministic inference detected");
|
|
println!("Determinism verified: identical outputs across runs");
|
|
}
|