git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
1062 lines
32 KiB
Rust
1062 lines
32 KiB
Rust
#![allow(
|
|
clippy::all,
|
|
unused_imports,
|
|
unused_variables,
|
|
dead_code,
|
|
unused_mut,
|
|
unused_assignments,
|
|
non_camel_case_types,
|
|
clippy::approx_constant,
|
|
unexpected_cfgs,
|
|
unused_must_use,
|
|
unused_parens
|
|
)]
|
|
//! RuvLTRA-Small End-to-End Tests
|
|
//!
|
|
//! This module provides comprehensive end-to-end tests for the RuvLTRA-Small
|
|
//! inference pipeline, including full generation, streaming, and quality validation.
|
|
//!
|
|
//! ## Test Categories
|
|
//!
|
|
//! - **Full Inference Pipeline**: End-to-end generation from prompt to output
|
|
//! - **Streaming Generation**: Token-by-token streaming with callback validation
|
|
//! - **Quality Validation**: Perplexity checks, coherence scoring, output quality
|
|
//! - **Memory Validation**: Memory usage within bounds during inference
|
|
//!
|
|
//! ## Running Tests
|
|
//!
|
|
//! ```bash
|
|
//! # Run all E2E tests (some require model files)
|
|
//! cargo test --package ruvllm ruvltra_e2e
|
|
//!
|
|
//! # Run only tests that don't require model files
|
|
//! cargo test --package ruvllm ruvltra_e2e -- --skip model_required
|
|
//!
|
|
//! # Run with full features on Apple Silicon
|
|
//! cargo test --package ruvllm --features coreml,hybrid-ane ruvltra_e2e
|
|
//! ```
|
|
|
|
use ruvllm::backends::{GenerateParams, LlmBackend, ModelArchitecture, ModelConfig, Quantization};
|
|
use ruvllm::error::{Result, RuvLLMError};
|
|
use ruvllm::gguf::quantization::GgufQuantType;
|
|
use ruvllm::kernels::is_ane_available;
|
|
|
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
use std::sync::{Arc, Mutex};
|
|
use std::time::{Duration, Instant};
|
|
|
|
// ============================================================================
|
|
// Test Configuration and Fixtures
|
|
// ============================================================================
|
|
|
|
/// Sample prompts for testing
|
|
mod test_prompts {
|
|
/// Simple completion prompt
|
|
pub const SIMPLE_COMPLETION: &str = "The quick brown fox";
|
|
|
|
/// Instruction-following prompt
|
|
pub const INSTRUCTION: &str = "Write a haiku about programming:";
|
|
|
|
/// Question-answering prompt
|
|
pub const QA_PROMPT: &str = "Q: What is the capital of France?\nA:";
|
|
|
|
/// Code generation prompt
|
|
pub const CODE_PROMPT: &str = "def fibonacci(n):\n '''Return the nth Fibonacci number.'''\n";
|
|
|
|
/// Multi-turn conversation
|
|
pub const CONVERSATION: &str =
|
|
"User: Hello!\nAssistant: Hi there! How can I help you today?\nUser: Tell me a joke.\nAssistant:";
|
|
|
|
/// Edge case: very short prompt
|
|
pub const MINIMAL: &str = "Hi";
|
|
|
|
/// Edge case: prompt with special characters
|
|
pub const SPECIAL_CHARS: &str = "Translate: \"Hello, world!\" -> French: \"";
|
|
|
|
/// Stress test: longer prompt
|
|
pub const LONG_PROMPT: &str = "The following is a detailed explanation of machine learning. \
|
|
Machine learning is a subset of artificial intelligence that enables systems to learn \
|
|
and improve from experience without being explicitly programmed. It focuses on the \
|
|
development of computer programs that can access data and use it to learn for themselves. \
|
|
Continue:";
|
|
}
|
|
|
|
/// Expected output patterns for validation
|
|
mod expected_patterns {
|
|
/// Words that should commonly appear after "The quick brown fox"
|
|
pub const SIMPLE_COMPLETION_WORDS: &[&str] = &["jumps", "jumped", "runs", "ran", "the", "a"];
|
|
|
|
/// Common haiku-related words
|
|
pub const HAIKU_WORDS: &[&str] = &["code", "bug", "screen", "night", "debug", "compile"];
|
|
|
|
/// Capital of France
|
|
pub const FRANCE_CAPITAL: &str = "Paris";
|
|
|
|
/// Fibonacci-related words in code
|
|
pub const FIBONACCI_WORDS: &[&str] = &["return", "if", "else", "n", "fib", "0", "1"];
|
|
}
|
|
|
|
/// Quality thresholds
|
|
mod quality_thresholds {
|
|
/// Maximum acceptable perplexity for coherent output
|
|
pub const MAX_PERPLEXITY: f32 = 50.0;
|
|
|
|
/// Minimum output length for generation tests
|
|
pub const MIN_OUTPUT_TOKENS: usize = 5;
|
|
|
|
/// Maximum output length for bounded tests
|
|
pub const MAX_OUTPUT_TOKENS: usize = 1000;
|
|
|
|
/// Minimum probability for top token
|
|
pub const MIN_TOP_PROBABILITY: f32 = 0.01;
|
|
|
|
/// Maximum time for single token generation (ms)
|
|
pub const MAX_TOKEN_LATENCY_MS: u64 = 500;
|
|
|
|
/// Maximum memory increase during generation (bytes)
|
|
pub const MAX_MEMORY_INCREASE: usize = 500_000_000; // 500MB
|
|
}
|
|
|
|
// ============================================================================
|
|
// Full Inference Pipeline Tests
|
|
// ============================================================================
|
|
|
|
mod full_inference_pipeline {
|
|
use super::*;
|
|
|
|
/// Simulated inference result for testing pipeline behavior
|
|
#[derive(Debug, Clone)]
|
|
struct InferenceResult {
|
|
tokens: Vec<u32>,
|
|
text: String,
|
|
total_time: Duration,
|
|
tokens_per_second: f32,
|
|
}
|
|
|
|
/// Simulated model for testing (no actual weights needed)
|
|
struct MockModel {
|
|
vocab_size: usize,
|
|
config: ModelConfig,
|
|
}
|
|
|
|
impl MockModel {
|
|
fn new(config: ModelConfig) -> Self {
|
|
Self {
|
|
vocab_size: config.vocab_size.unwrap_or(32000),
|
|
config,
|
|
}
|
|
}
|
|
|
|
/// Simulate token generation
|
|
fn generate_mock_tokens(&self, prompt: &str, max_tokens: usize) -> Vec<u32> {
|
|
// Generate deterministic "tokens" based on prompt hash
|
|
let hash = prompt
|
|
.bytes()
|
|
.fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64));
|
|
let mut tokens = Vec::with_capacity(max_tokens);
|
|
let mut state = hash;
|
|
|
|
for _ in 0..max_tokens {
|
|
state = state.wrapping_mul(1103515245).wrapping_add(12345);
|
|
let token = (state % self.vocab_size as u64) as u32;
|
|
tokens.push(token);
|
|
|
|
// Simulate EOS token occasionally
|
|
if state % 100 < 5 {
|
|
break;
|
|
}
|
|
}
|
|
|
|
tokens
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_pipeline_initialization() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 8192,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: true,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config.clone());
|
|
|
|
assert_eq!(model.vocab_size, 32000);
|
|
assert_eq!(model.config.max_sequence_length, 8192);
|
|
}
|
|
|
|
#[test]
|
|
fn test_simple_completion_pipeline() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: false,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = test_prompts::SIMPLE_COMPLETION;
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 50);
|
|
|
|
// Verify tokens are valid
|
|
assert!(!tokens.is_empty(), "Should generate at least one token");
|
|
assert!(tokens.len() <= 50, "Should respect max tokens");
|
|
|
|
for token in &tokens {
|
|
assert!(*token < 32000, "Token {} exceeds vocab size", token);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_instruction_following_pipeline() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: true,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = test_prompts::INSTRUCTION;
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 100);
|
|
|
|
assert!(!tokens.is_empty());
|
|
assert!(tokens.iter().all(|t| *t < 32000));
|
|
}
|
|
|
|
#[test]
|
|
fn test_qa_pipeline() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: false,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = test_prompts::QA_PROMPT;
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 20);
|
|
|
|
assert!(!tokens.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_code_generation_pipeline() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: true,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = test_prompts::CODE_PROMPT;
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 100);
|
|
|
|
assert!(!tokens.is_empty());
|
|
assert!(
|
|
tokens.len() >= quality_thresholds::MIN_OUTPUT_TOKENS,
|
|
"Code generation should produce at least {} tokens",
|
|
quality_thresholds::MIN_OUTPUT_TOKENS
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_conversation_pipeline() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: true,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = test_prompts::CONVERSATION;
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 50);
|
|
|
|
assert!(!tokens.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_minimal_prompt_handling() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: false,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = test_prompts::MINIMAL;
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 20);
|
|
|
|
// Should handle minimal input gracefully
|
|
assert!(!tokens.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_long_prompt_handling() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: true,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = test_prompts::LONG_PROMPT;
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 100);
|
|
|
|
assert!(!tokens.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_empty_prompt_handling() {
|
|
let config = ModelConfig {
|
|
architecture: ModelArchitecture::Llama,
|
|
quantization: Some(Quantization::Q4K),
|
|
max_sequence_length: 4096,
|
|
vocab_size: Some(32000),
|
|
use_flash_attention: false,
|
|
..Default::default()
|
|
};
|
|
|
|
let model = MockModel::new(config);
|
|
let prompt = "";
|
|
|
|
let tokens = model.generate_mock_tokens(prompt, 20);
|
|
|
|
// Empty prompt should still produce some output
|
|
// (implementation-dependent behavior)
|
|
let _ = tokens;
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Streaming Generation Tests
|
|
// ============================================================================
|
|
|
|
mod streaming_generation {
|
|
use super::*;
|
|
|
|
/// Token callback for streaming tests
|
|
type TokenCallback = Box<dyn FnMut(u32, &str) + Send>;
|
|
|
|
/// Streaming state tracker
|
|
struct StreamingState {
|
|
tokens_received: Vec<u32>,
|
|
chunks_received: Vec<String>,
|
|
total_latency: Duration,
|
|
first_token_time: Option<Duration>,
|
|
}
|
|
|
|
impl StreamingState {
|
|
fn new() -> Self {
|
|
Self {
|
|
tokens_received: Vec::new(),
|
|
chunks_received: Vec::new(),
|
|
total_latency: Duration::ZERO,
|
|
first_token_time: None,
|
|
}
|
|
}
|
|
|
|
fn record_token(&mut self, token: u32, chunk: &str, latency: Duration) {
|
|
if self.first_token_time.is_none() {
|
|
self.first_token_time = Some(latency);
|
|
}
|
|
self.tokens_received.push(token);
|
|
self.chunks_received.push(chunk.to_string());
|
|
self.total_latency += latency;
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_callback_invocation() {
|
|
let callback_count = Arc::new(AtomicUsize::new(0));
|
|
let count_clone = callback_count.clone();
|
|
|
|
// Simulate streaming with callback
|
|
let mut callback = move |_token: u32, _chunk: &str| {
|
|
count_clone.fetch_add(1, Ordering::SeqCst);
|
|
};
|
|
|
|
// Simulate 10 tokens
|
|
for i in 0..10 {
|
|
callback(i as u32, &format!("token_{}", i));
|
|
}
|
|
|
|
assert_eq!(callback_count.load(Ordering::SeqCst), 10);
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_state_tracking() {
|
|
let state = Arc::new(Mutex::new(StreamingState::new()));
|
|
|
|
// Simulate token stream
|
|
let tokens = [(1u32, "Hello"), (2, " "), (3, "world"), (4, "!")];
|
|
|
|
for (token, chunk) in &tokens {
|
|
let latency = Duration::from_millis(50);
|
|
state.lock().unwrap().record_token(*token, chunk, latency);
|
|
}
|
|
|
|
let final_state = state.lock().unwrap();
|
|
assert_eq!(final_state.tokens_received.len(), 4);
|
|
assert_eq!(final_state.chunks_received.len(), 4);
|
|
assert!(final_state.first_token_time.is_some());
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_first_token_latency() {
|
|
let start = Instant::now();
|
|
|
|
// Simulate first token generation
|
|
std::thread::sleep(Duration::from_millis(10));
|
|
let first_token_time = start.elapsed();
|
|
|
|
// First token should come quickly (for mock)
|
|
assert!(
|
|
first_token_time < Duration::from_millis(100),
|
|
"First token took {:?}",
|
|
first_token_time
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_inter_token_latency() {
|
|
let mut latencies = Vec::new();
|
|
|
|
// Simulate token stream timing
|
|
for _ in 0..10 {
|
|
let start = Instant::now();
|
|
// Simulate token processing
|
|
std::thread::sleep(Duration::from_micros(100));
|
|
latencies.push(start.elapsed());
|
|
}
|
|
|
|
// All latencies should be below threshold
|
|
for (i, latency) in latencies.iter().enumerate() {
|
|
assert!(
|
|
*latency < Duration::from_millis(quality_thresholds::MAX_TOKEN_LATENCY_MS),
|
|
"Token {} latency {:?} exceeds threshold",
|
|
i,
|
|
latency
|
|
);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_cancellation() {
|
|
let cancelled = Arc::new(AtomicUsize::new(0));
|
|
let tokens_generated = Arc::new(AtomicUsize::new(0));
|
|
|
|
let cancelled_clone = cancelled.clone();
|
|
let tokens_clone = tokens_generated.clone();
|
|
|
|
// Simulate streaming with early cancellation
|
|
for i in 0..100 {
|
|
if cancelled_clone.load(Ordering::SeqCst) > 0 {
|
|
break;
|
|
}
|
|
|
|
tokens_clone.fetch_add(1, Ordering::SeqCst);
|
|
|
|
// Cancel after 5 tokens
|
|
if i == 4 {
|
|
cancelled_clone.store(1, Ordering::SeqCst);
|
|
}
|
|
}
|
|
|
|
assert_eq!(tokens_generated.load(Ordering::SeqCst), 5);
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_buffer_accumulation() {
|
|
let mut buffer = String::new();
|
|
let chunks = ["Hello", ", ", "how ", "are ", "you", "?"];
|
|
|
|
for chunk in &chunks {
|
|
buffer.push_str(chunk);
|
|
}
|
|
|
|
assert_eq!(buffer, "Hello, how are you?");
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_unicode_handling() {
|
|
let mut state = StreamingState::new();
|
|
|
|
// Unicode tokens
|
|
let unicode_chunks = [
|
|
(1, "Hello"),
|
|
(2, " "),
|
|
(3, "\u{1F44B}"), // Wave emoji
|
|
(4, " World"),
|
|
(5, "\u{1F310}"), // Globe emoji
|
|
];
|
|
|
|
for (token, chunk) in &unicode_chunks {
|
|
state.record_token(*token as u32, chunk, Duration::from_millis(10));
|
|
}
|
|
|
|
let full_text: String = state.chunks_received.join("");
|
|
assert!(full_text.contains('\u{1F44B}'));
|
|
assert!(full_text.contains('\u{1F310}'));
|
|
}
|
|
|
|
#[test]
|
|
fn test_streaming_empty_chunks() {
|
|
let mut state = StreamingState::new();
|
|
|
|
// Some implementations may emit empty chunks
|
|
let chunks = [(1, "Hello"), (2, ""), (3, " "), (4, ""), (5, "World")];
|
|
|
|
for (token, chunk) in &chunks {
|
|
state.record_token(*token as u32, chunk, Duration::from_millis(10));
|
|
}
|
|
|
|
let non_empty: Vec<_> = state
|
|
.chunks_received
|
|
.iter()
|
|
.filter(|c| !c.is_empty())
|
|
.collect();
|
|
|
|
assert_eq!(non_empty.len(), 3);
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Quality Validation Tests
|
|
// ============================================================================
|
|
|
|
mod quality_validation {
|
|
use super::*;
|
|
|
|
/// Calculate perplexity from log probabilities
|
|
fn calculate_perplexity(log_probs: &[f32]) -> f32 {
|
|
if log_probs.is_empty() {
|
|
return f32::INFINITY;
|
|
}
|
|
let avg_neg_log_prob = -log_probs.iter().sum::<f32>() / log_probs.len() as f32;
|
|
avg_neg_log_prob.exp()
|
|
}
|
|
|
|
/// Check if output contains expected patterns
|
|
fn contains_expected_patterns(output: &str, patterns: &[&str]) -> bool {
|
|
let output_lower = output.to_lowercase();
|
|
patterns
|
|
.iter()
|
|
.any(|p| output_lower.contains(&p.to_lowercase()))
|
|
}
|
|
|
|
#[test]
|
|
fn test_perplexity_calculation() {
|
|
// Good log probs (high probability = low perplexity)
|
|
let good_log_probs = vec![-1.0, -0.5, -1.0, -0.8, -1.2];
|
|
let good_ppl = calculate_perplexity(&good_log_probs);
|
|
|
|
// Bad log probs (low probability = high perplexity)
|
|
let bad_log_probs = vec![-5.0, -6.0, -4.5, -7.0, -5.5];
|
|
let bad_ppl = calculate_perplexity(&bad_log_probs);
|
|
|
|
assert!(good_ppl < bad_ppl, "Good text should have lower perplexity");
|
|
assert!(good_ppl.is_finite());
|
|
assert!(bad_ppl.is_finite());
|
|
}
|
|
|
|
#[test]
|
|
fn test_perplexity_threshold() {
|
|
// Simulate reasonable log probs
|
|
let log_probs: Vec<f32> = (0..100).map(|_| -2.5).collect();
|
|
let ppl = calculate_perplexity(&log_probs);
|
|
|
|
assert!(
|
|
ppl < quality_thresholds::MAX_PERPLEXITY,
|
|
"Perplexity {} exceeds threshold",
|
|
ppl
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_perplexity_empty_input() {
|
|
let empty: Vec<f32> = vec![];
|
|
let ppl = calculate_perplexity(&empty);
|
|
assert!(ppl.is_infinite());
|
|
}
|
|
|
|
#[test]
|
|
fn test_output_coherence_simple() {
|
|
// Test expected patterns for simple completion
|
|
let output = "jumps over the lazy dog";
|
|
assert!(contains_expected_patterns(
|
|
output,
|
|
expected_patterns::SIMPLE_COMPLETION_WORDS
|
|
));
|
|
}
|
|
|
|
#[test]
|
|
fn test_output_coherence_qa() {
|
|
// Test expected patterns for QA
|
|
let output = "The capital of France is Paris.";
|
|
assert!(output.contains(expected_patterns::FRANCE_CAPITAL));
|
|
}
|
|
|
|
#[test]
|
|
fn test_output_coherence_code() {
|
|
// Test expected patterns for code
|
|
let output = "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)";
|
|
assert!(contains_expected_patterns(
|
|
output,
|
|
expected_patterns::FIBONACCI_WORDS
|
|
));
|
|
}
|
|
|
|
#[test]
|
|
fn test_probability_distribution_valid() {
|
|
// Simulated softmax probabilities
|
|
let probs = vec![0.4, 0.3, 0.15, 0.1, 0.05];
|
|
|
|
// Sum should be ~1.0
|
|
let sum: f32 = probs.iter().sum();
|
|
assert!((sum - 1.0).abs() < 0.01);
|
|
|
|
// All probabilities should be valid
|
|
for p in &probs {
|
|
assert!(*p >= 0.0 && *p <= 1.0);
|
|
}
|
|
|
|
// Top probability should meet threshold
|
|
let top_prob = probs.iter().cloned().fold(0.0f32, f32::max);
|
|
assert!(top_prob >= quality_thresholds::MIN_TOP_PROBABILITY);
|
|
}
|
|
|
|
#[test]
|
|
fn test_output_length_bounds() {
|
|
// Simulated output
|
|
let output_tokens = vec![1u32; 50];
|
|
|
|
assert!(output_tokens.len() >= quality_thresholds::MIN_OUTPUT_TOKENS);
|
|
assert!(output_tokens.len() <= quality_thresholds::MAX_OUTPUT_TOKENS);
|
|
}
|
|
|
|
#[test]
|
|
fn test_no_garbled_output() {
|
|
// Check for common garbled patterns
|
|
fn is_garbled(text: &str) -> bool {
|
|
// Check for excessive repetition
|
|
let words: Vec<&str> = text.split_whitespace().collect();
|
|
if words.len() > 5 {
|
|
let mut consecutive_repeats = 0;
|
|
for i in 1..words.len() {
|
|
if words[i] == words[i - 1] {
|
|
consecutive_repeats += 1;
|
|
if consecutive_repeats > 3 {
|
|
return true;
|
|
}
|
|
} else {
|
|
consecutive_repeats = 0;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check for excessive special characters
|
|
let special_ratio = text
|
|
.chars()
|
|
.filter(|c| !c.is_alphanumeric() && !c.is_whitespace())
|
|
.count() as f32
|
|
/ text.len().max(1) as f32;
|
|
if special_ratio > 0.5 {
|
|
return true;
|
|
}
|
|
|
|
false
|
|
}
|
|
|
|
// Good outputs
|
|
assert!(!is_garbled("The quick brown fox jumps over the lazy dog."));
|
|
assert!(!is_garbled("Hello, how are you today?"));
|
|
|
|
// Garbled outputs
|
|
assert!(is_garbled("the the the the the the"));
|
|
assert!(is_garbled("!@#$%^&*()!@#$%^&*()!@#$%^&*()"));
|
|
}
|
|
|
|
#[test]
|
|
fn test_repetition_penalty_effectiveness() {
|
|
// Simulate output with and without repetition penalty
|
|
|
|
// Without penalty: likely to have repetitions
|
|
let output_no_penalty = "the the the quick brown fox";
|
|
|
|
// With penalty: less likely to repeat
|
|
let output_with_penalty = "the quick brown fox jumps over";
|
|
|
|
fn count_word_repetitions(text: &str) -> usize {
|
|
let words: Vec<&str> = text.split_whitespace().collect();
|
|
let mut repetitions = 0;
|
|
for i in 1..words.len() {
|
|
if words[i] == words[i - 1] {
|
|
repetitions += 1;
|
|
}
|
|
}
|
|
repetitions
|
|
}
|
|
|
|
let reps_no_penalty = count_word_repetitions(output_no_penalty);
|
|
let reps_with_penalty = count_word_repetitions(output_with_penalty);
|
|
|
|
assert!(reps_no_penalty >= reps_with_penalty);
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Memory Validation Tests
|
|
// ============================================================================
|
|
|
|
mod memory_validation {
|
|
use super::*;
|
|
|
|
/// Get approximate memory usage (platform-dependent)
|
|
fn get_memory_usage() -> usize {
|
|
// In real implementation, this would query actual process memory
|
|
// For testing, we'll use allocation tracking
|
|
std::mem::size_of::<u8>() * 1000 // Placeholder
|
|
}
|
|
|
|
#[test]
|
|
fn test_memory_increase_bounded() {
|
|
let initial_memory = get_memory_usage();
|
|
|
|
// Simulate memory allocation during inference
|
|
let mut allocations: Vec<Vec<f32>> = Vec::new();
|
|
for _ in 0..10 {
|
|
allocations.push(vec![0.0f32; 10000]);
|
|
}
|
|
|
|
// Memory increase should be bounded
|
|
let memory_increase = allocations.len() * 10000 * std::mem::size_of::<f32>();
|
|
assert!(
|
|
memory_increase < quality_thresholds::MAX_MEMORY_INCREASE,
|
|
"Memory increase {} exceeds bound",
|
|
memory_increase
|
|
);
|
|
|
|
// Clean up
|
|
drop(allocations);
|
|
}
|
|
|
|
#[test]
|
|
fn test_kv_cache_memory() {
|
|
// Simulated KV cache parameters
|
|
let num_layers = 22;
|
|
let num_kv_heads = 8;
|
|
let head_dim = 64;
|
|
let max_seq_len = 8192;
|
|
|
|
// KV cache size: 2 (K+V) * layers * kv_heads * head_dim * seq_len * sizeof(f16)
|
|
let kv_cache_bytes = 2 * num_layers * num_kv_heads * head_dim * max_seq_len * 2;
|
|
|
|
// Should be reasonable
|
|
assert!(
|
|
kv_cache_bytes < 500_000_000,
|
|
"KV cache {} bytes too large",
|
|
kv_cache_bytes
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_activation_memory() {
|
|
// Simulated activation memory for forward pass
|
|
let batch_size = 1;
|
|
let seq_len = 1024;
|
|
let hidden_size = 2048;
|
|
|
|
// Activation: batch * seq * hidden * sizeof(f32)
|
|
let activation_bytes = batch_size * seq_len * hidden_size * 4;
|
|
|
|
assert!(
|
|
activation_bytes < 100_000_000,
|
|
"Activation memory {} too large",
|
|
activation_bytes
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_memory_cleanup_after_generation() {
|
|
// Simulate allocation and cleanup
|
|
{
|
|
let _temp_buffer = vec![0.0f32; 100000];
|
|
// Buffer goes out of scope
|
|
}
|
|
|
|
// In real implementation, verify memory is freed
|
|
// For tests, this mainly verifies no panic during cleanup
|
|
}
|
|
|
|
#[test]
|
|
fn test_quantized_model_memory_savings() {
|
|
let vocab_size = 32000;
|
|
let hidden_size = 2048;
|
|
|
|
// Embedding size comparison
|
|
let f32_size = vocab_size * hidden_size * 4;
|
|
let q4k_size = GgufQuantType::Q4_K.tensor_size(vocab_size * hidden_size);
|
|
|
|
let savings_ratio = 1.0 - (q4k_size as f32 / f32_size as f32);
|
|
|
|
// Q4_K should save at least 70% memory
|
|
assert!(
|
|
savings_ratio > 0.7,
|
|
"Q4_K savings ratio {} below expected",
|
|
savings_ratio
|
|
);
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Error Handling Tests
|
|
// ============================================================================
|
|
|
|
mod error_handling {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_invalid_token_handling() {
|
|
let vocab_size = 32000;
|
|
let invalid_tokens = [u32::MAX, vocab_size as u32, vocab_size as u32 + 1000];
|
|
|
|
for token in invalid_tokens {
|
|
assert!(
|
|
token >= vocab_size as u32,
|
|
"Token {} should be invalid for vocab size {}",
|
|
token,
|
|
vocab_size
|
|
);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_context_overflow_handling() {
|
|
let max_context = 4096;
|
|
let prompt_length = 5000;
|
|
|
|
// Should detect overflow
|
|
let overflow = prompt_length > max_context;
|
|
assert!(overflow, "Should detect context overflow");
|
|
|
|
// Calculate truncation
|
|
let truncated_length = prompt_length.min(max_context);
|
|
assert!(truncated_length <= max_context);
|
|
}
|
|
|
|
#[test]
|
|
fn test_out_of_memory_simulation() {
|
|
// Simulate OOM by attempting very large allocation
|
|
// Note: This won't actually allocate, just test the check
|
|
let requested_size: usize = 1_000_000_000_000; // 1TB
|
|
|
|
// Check if allocation would exceed bounds
|
|
let would_oom = requested_size > quality_thresholds::MAX_MEMORY_INCREASE;
|
|
assert!(would_oom, "Should detect potential OOM");
|
|
}
|
|
|
|
#[test]
|
|
fn test_nan_inf_detection() {
|
|
let test_values: [f32; 4] = [1.0, f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
|
|
|
|
let finite_count = test_values.iter().filter(|v| v.is_finite()).count();
|
|
let nan_count = test_values.iter().filter(|v| v.is_nan()).count();
|
|
let inf_count = test_values.iter().filter(|v| v.is_infinite()).count();
|
|
|
|
assert_eq!(finite_count, 1);
|
|
assert_eq!(nan_count, 1);
|
|
assert_eq!(inf_count, 2);
|
|
}
|
|
|
|
#[test]
|
|
fn test_graceful_degradation() {
|
|
// Simulate graceful degradation when ANE unavailable
|
|
let ane_available = is_ane_available();
|
|
|
|
// Should work regardless of ANE availability
|
|
let fallback_used = !ane_available;
|
|
let computation_succeeded = true; // Mock
|
|
|
|
assert!(computation_succeeded || fallback_used);
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Integration Stress Tests
|
|
// ============================================================================
|
|
|
|
mod stress_tests {
|
|
use super::*;
|
|
use std::thread;
|
|
|
|
#[test]
|
|
fn test_rapid_sequential_generations() {
|
|
let iterations = 100;
|
|
|
|
for i in 0..iterations {
|
|
// Simulate rapid generation
|
|
let prompt = format!("Test prompt {}", i);
|
|
let hash = prompt
|
|
.bytes()
|
|
.fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64));
|
|
let _ = hash % 32000; // Mock token
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_concurrent_inference() {
|
|
let handles: Vec<_> = (0..4)
|
|
.map(|i| {
|
|
thread::spawn(move || {
|
|
for j in 0..25 {
|
|
let prompt = format!("Thread {} prompt {}", i, j);
|
|
let hash = prompt
|
|
.bytes()
|
|
.fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64));
|
|
let _ = hash % 32000;
|
|
}
|
|
})
|
|
})
|
|
.collect();
|
|
|
|
for handle in handles {
|
|
handle.join().expect("Thread should complete");
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_varied_prompt_lengths() {
|
|
let lengths = [1, 10, 100, 1000];
|
|
|
|
for len in lengths {
|
|
let prompt: String = (0..len)
|
|
.map(|i| char::from((b'a' + (i % 26) as u8)))
|
|
.collect();
|
|
assert_eq!(prompt.len(), len);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_varied_generation_lengths() {
|
|
let max_tokens_variants = [1, 10, 50, 100, 500];
|
|
|
|
for max_tokens in max_tokens_variants {
|
|
let generated: Vec<u32> = (0..max_tokens.min(100))
|
|
.map(|i| (i % 32000) as u32)
|
|
.collect();
|
|
assert!(generated.len() <= max_tokens);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
#[ignore] // Run with: cargo test --release -- --ignored
|
|
fn test_extended_generation_stability() {
|
|
// Test stability over many tokens
|
|
let total_tokens = 10000;
|
|
let mut tokens = Vec::with_capacity(total_tokens);
|
|
|
|
for i in 0..total_tokens {
|
|
let token = (i * 17 + 13) % 32000;
|
|
tokens.push(token as u32);
|
|
}
|
|
|
|
// All tokens should be valid
|
|
assert!(tokens.iter().all(|t| *t < 32000));
|
|
assert_eq!(tokens.len(), total_tokens);
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Benchmark Tests (Ignored by Default)
|
|
// ============================================================================
|
|
|
|
mod benchmarks {
|
|
use super::*;
|
|
|
|
#[test]
|
|
#[ignore]
|
|
fn benchmark_token_generation_rate() {
|
|
let iterations = 1000;
|
|
let start = Instant::now();
|
|
|
|
for i in 0..iterations {
|
|
// Simulate token generation
|
|
let token = (i * 31 + 17) % 32000;
|
|
let _ = token;
|
|
}
|
|
|
|
let duration = start.elapsed();
|
|
let tokens_per_sec = iterations as f64 / duration.as_secs_f64();
|
|
|
|
println!("Token generation rate: {:.2} tokens/sec", tokens_per_sec);
|
|
|
|
// Should achieve reasonable throughput
|
|
assert!(tokens_per_sec > 1000.0);
|
|
}
|
|
|
|
#[test]
|
|
#[ignore]
|
|
fn benchmark_prompt_encoding() {
|
|
let prompt = test_prompts::LONG_PROMPT;
|
|
let iterations = 100;
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
// Simulate tokenization
|
|
let _tokens: Vec<u32> = prompt.bytes().map(|b| b as u32).collect();
|
|
}
|
|
let duration = start.elapsed();
|
|
|
|
let avg_time = duration / iterations;
|
|
println!("Average prompt encoding time: {:?}", avg_time);
|
|
|
|
assert!(avg_time < Duration::from_millis(10));
|
|
}
|
|
|
|
#[test]
|
|
#[ignore]
|
|
fn benchmark_memory_allocation() {
|
|
let iterations = 100;
|
|
let buffer_size = 4096 * 4096;
|
|
|
|
let start = Instant::now();
|
|
for _ in 0..iterations {
|
|
let buffer = vec![0.0f32; buffer_size];
|
|
drop(buffer);
|
|
}
|
|
let duration = start.elapsed();
|
|
|
|
let avg_time = duration / iterations;
|
|
println!("Average buffer allocation time: {:?}", avg_time);
|
|
}
|
|
}
|