// wifi-densepose/crates/ruvllm/tests/e2e_integration_test.rs

#![allow(
clippy::all,
unused_imports,
unused_variables,
dead_code,
unused_mut,
unused_assignments,
non_camel_case_types,
clippy::approx_constant,
unexpected_cfgs,
unused_must_use,
unused_parens
)]
//! End-to-end Integration Tests for RuvLLM
//!
//! Tests the complete inference pipeline including:
//! - GGUF file parsing and loading
//! - Token generation with various configurations
//! - Streaming generation with callbacks
//! - Speculative decoding pipeline
//! - KV cache persistence and continuation
//! - Batch generation processing
//! - Stop sequence handling
//! - Temperature sampling verification
//!
//! ## Running Tests
//!
//! ### Without a real model (uses the mock backend simulation defined in this file):
//! ```bash
//! cargo test -p ruvllm --test e2e_integration_test
//! ```
//!
//! ### With a real model file:
//! ```bash
//! TEST_MODEL_PATH=/path/to/model.gguf cargo test -p ruvllm --test e2e_integration_test -- --ignored
//! ```
//!
//! ### Run specific test with model:
//! ```bash
//! TEST_MODEL_PATH=/path/to/model.gguf cargo test -p ruvllm --test e2e_integration_test test_real_model_generation -- --ignored
//! ```
use ruvllm::{
// Backends
backends::{
GenerateParams, GeneratedToken, LlmBackend, ModelArchitecture, ModelConfig, Quantization,
SpecialTokens, StreamEvent, TokenStream, Tokenizer,
},
// Error handling
error::{Result, RuvLLMError},
// KV Cache
kv_cache::{KvCacheConfig, TwoTierKvCache},
// Serving
serving::{
InferenceRequest, KvCachePoolConfig, Priority, ServingEngine, ServingEngineConfig,
TokenOutput,
},
// Speculative decoding
speculative::{
log_softmax, sample_from_probs, softmax, top_k_filter, top_p_filter,
AtomicSpeculativeStats, SpeculationTree, SpeculativeConfig, SpeculativeDecoder,
SpeculativeStats, TreeNode,
},
};
use std::collections::HashMap;
use std::env;
use std::path::Path;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
// ============================================================================
// Test Fixtures and Helpers
// ============================================================================
/// GGUF magic number "GGUF" in little-endian
const GGUF_MAGIC: u32 = 0x46554747;
/// Supported GGUF version
const GGUF_VERSION: u32 = 3;
/// GGUF metadata value types
#[repr(u32)]
enum GgufMetadataType {
Uint32 = 4,
String = 8,
}
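// GGUF v3 layout assumed by the fixtures below (all fields little-endian):
// - header: u32 magic ("GGUF"), u32 version, u64 tensor_count, u64 metadata_kv_count
// - metadata entry: u64 key_len + key bytes, u32 value_type, value bytes
// - tensor info: u64 name_len + name bytes, u32 n_dims, u64 dims[n_dims], u32 dtype, u64 data_offset
// Only the subset exercised by these tests is modeled; a real GGUF reader should not rely on this sketch.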
/// Create a minimal valid GGUF file for testing (header only, no tensors)
fn create_minimal_test_gguf() -> Vec<u8> {
let mut data = Vec::new();
// Magic number
data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
// Version
data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
// Tensor count: 0
data.extend_from_slice(&0u64.to_le_bytes());
// Metadata KV count: 0
data.extend_from_slice(&0u64.to_le_bytes());
data
}
/// Create a GGUF file with metadata (architecture, context length, etc.)
fn create_test_gguf_with_metadata() -> Vec<u8> {
let mut data = Vec::new();
// Header
data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
// Tensor count: 1 (we'll add a small embedding)
data.extend_from_slice(&1u64.to_le_bytes());
// Metadata count: 3
data.extend_from_slice(&3u64.to_le_bytes());
// Metadata 1: general.architecture = "llama" (string)
let key1 = "general.architecture";
data.extend_from_slice(&(key1.len() as u64).to_le_bytes());
data.extend_from_slice(key1.as_bytes());
data.extend_from_slice(&(GgufMetadataType::String as u32).to_le_bytes());
let value1 = "llama";
data.extend_from_slice(&(value1.len() as u64).to_le_bytes());
data.extend_from_slice(value1.as_bytes());
// Metadata 2: llama.context_length = 4096 (u32)
let key2 = "llama.context_length";
data.extend_from_slice(&(key2.len() as u64).to_le_bytes());
data.extend_from_slice(key2.as_bytes());
data.extend_from_slice(&(GgufMetadataType::Uint32 as u32).to_le_bytes());
data.extend_from_slice(&4096u32.to_le_bytes());
// Metadata 3: llama.embedding_length = 4096 (u32)
let key3 = "llama.embedding_length";
data.extend_from_slice(&(key3.len() as u64).to_le_bytes());
data.extend_from_slice(key3.as_bytes());
data.extend_from_slice(&(GgufMetadataType::Uint32 as u32).to_le_bytes());
data.extend_from_slice(&4096u32.to_le_bytes());
// Tensor info for a small embedding tensor
let tensor_name = "model.embed_tokens.weight";
data.extend_from_slice(&(tensor_name.len() as u64).to_le_bytes());
data.extend_from_slice(tensor_name.as_bytes());
data.extend_from_slice(&2u32.to_le_bytes()); // n_dims
data.extend_from_slice(&32u64.to_le_bytes()); // vocab_size (small for test)
data.extend_from_slice(&16u64.to_le_bytes()); // hidden_size (small for test)
data.extend_from_slice(&0u32.to_le_bytes()); // F32 type
data.extend_from_slice(&0u64.to_le_bytes()); // offset
data
}
/// Create a GGUF file with Q4_0 quantized tensor
fn create_test_gguf_q4_quantized() -> Vec<u8> {
let mut data = Vec::new();
// Header
data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
data.extend_from_slice(&GGUF_VERSION.to_le_bytes());
data.extend_from_slice(&1u64.to_le_bytes()); // 1 tensor
data.extend_from_slice(&1u64.to_le_bytes()); // 1 metadata
// Metadata: architecture
let key = "general.architecture";
data.extend_from_slice(&(key.len() as u64).to_le_bytes());
data.extend_from_slice(key.as_bytes());
data.extend_from_slice(&(GgufMetadataType::String as u32).to_le_bytes());
let value = "llama";
data.extend_from_slice(&(value.len() as u64).to_le_bytes());
data.extend_from_slice(value.as_bytes());
// Tensor info (Q4_0 quantized)
let tensor_name = "model.layers.0.self_attn.q_proj.weight";
data.extend_from_slice(&(tensor_name.len() as u64).to_le_bytes());
data.extend_from_slice(tensor_name.as_bytes());
data.extend_from_slice(&2u32.to_le_bytes()); // n_dims
data.extend_from_slice(&64u64.to_le_bytes()); // dim0
data.extend_from_slice(&64u64.to_le_bytes()); // dim1
data.extend_from_slice(&2u32.to_le_bytes()); // Q4_0 type
data.extend_from_slice(&0u64.to_le_bytes()); // offset
data
}
/// Mock tokenizer for testing
struct MockTokenizer {
vocab: HashMap<String, u32>,
reverse_vocab: HashMap<u32, String>,
}
impl MockTokenizer {
fn new() -> Self {
let mut vocab = HashMap::new();
let mut reverse_vocab = HashMap::new();
// Add common tokens
let tokens = [
("<s>", 1),
("</s>", 2),
("<pad>", 0),
("Hello", 100),
(",", 101),
(" ", 102),
("world", 103),
("!", 104),
("The", 105),
("quick", 106),
("brown", 107),
("fox", 108),
("jumps", 109),
("over", 110),
("lazy", 111),
("dog", 112),
(".", 113),
("test", 114),
("model", 115),
("output", 116),
];
for (text, id) in tokens {
vocab.insert(text.to_string(), id);
reverse_vocab.insert(id, text.to_string());
}
Self {
vocab,
reverse_vocab,
}
}
}
impl Tokenizer for MockTokenizer {
fn encode(&self, text: &str) -> Result<Vec<u32>> {
// Simple word-based tokenization for testing
let mut tokens = Vec::new();
for word in text.split_whitespace() {
if let Some(&id) = self.vocab.get(word) {
tokens.push(id);
} else {
// Unknown word - hash it to a pseudo-ID
let hash = word
.bytes()
.fold(200u32, |acc, b| acc.wrapping_add(b as u32));
tokens.push(hash % 1000 + 200);
}
}
Ok(tokens)
}
fn decode(&self, tokens: &[u32]) -> Result<String> {
let words: Vec<String> = tokens
.iter()
.filter_map(|&id| {
self.reverse_vocab
.get(&id)
.cloned()
.or_else(|| Some(format!("[{}]", id)))
})
.collect();
Ok(words.join(" "))
}
fn vocab_size(&self) -> usize {
32000 // Standard vocab size
}
fn special_tokens(&self) -> SpecialTokens {
SpecialTokens {
bos_token_id: Some(1),
eos_token_id: Some(2),
pad_token_id: Some(0),
unk_token_id: Some(3),
}
}
}
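/// A minimal round-trip sketch for the mock tokenizer above. It only uses
/// words from the fixed vocabulary, so encode/decode is exact; unknown words
/// would be hashed to pseudo-IDs and decoded as "[id]" placeholders instead.
#[test]
fn test_mock_tokenizer_round_trip() {
    let tokenizer = MockTokenizer::new();
    let ids = tokenizer.encode("Hello world").unwrap();
    assert_eq!(ids, vec![100, 103]);
    assert_eq!(tokenizer.decode(&ids).unwrap(), "Hello world");
}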
/// Mock LLM backend that generates deterministic tokens based on context
struct MockLlmBackend {
tokenizer: MockTokenizer,
model_loaded: AtomicBool,
generation_count: AtomicUsize,
}
impl MockLlmBackend {
fn new() -> Self {
Self {
tokenizer: MockTokenizer::new(),
model_loaded: AtomicBool::new(false),
generation_count: AtomicUsize::new(0),
}
}
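    /// Hash the context plus a seed offset into a reproducible pseudo-token id
    /// in the range [100, 30100); identical inputs always yield the same
    /// "generated" token, which the determinism tests below rely on.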
fn deterministic_token(&self, context: &[u32], seed_offset: usize) -> u32 {
let hash = context.iter().fold(seed_offset as u32, |acc, &t| {
acc.wrapping_add(t).wrapping_mul(31)
});
// Generate tokens in reasonable vocabulary range
(hash % 30000) + 100
}
}
impl LlmBackend for MockLlmBackend {
fn load_model(&mut self, _model_id: &str, _config: ModelConfig) -> Result<()> {
self.model_loaded.store(true, Ordering::SeqCst);
Ok(())
}
fn generate(&self, prompt: &str, params: GenerateParams) -> Result<String> {
if !self.model_loaded.load(Ordering::SeqCst) {
return Err(RuvLLMError::Config("Model not loaded".to_string()));
}
let count = self.generation_count.fetch_add(1, Ordering::SeqCst);
let prompt_tokens = self.tokenizer.encode(prompt)?;
// Generate deterministic tokens
let mut output_tokens = Vec::new();
let mut context = prompt_tokens.clone();
for i in 0..params.max_tokens {
let token = self.deterministic_token(&context, count + i);
// Check for stop
if token == 2 {
// EOS
break;
}
output_tokens.push(token);
context.push(token);
}
// Decode output
self.tokenizer.decode(&output_tokens)
}
fn generate_stream(
&self,
prompt: &str,
params: GenerateParams,
) -> Result<Box<dyn Iterator<Item = Result<GeneratedToken>> + Send + '_>> {
let count = self.generation_count.fetch_add(1, Ordering::SeqCst);
let prompt_tokens = self.tokenizer.encode(prompt)?;
Ok(Box::new(MockStreamIterator {
backend: self,
context: prompt_tokens,
remaining: params.max_tokens,
seed_offset: count,
finished: false,
}))
}
fn generate_stream_v2(&self, prompt: &str, params: GenerateParams) -> Result<TokenStream> {
let (tx, stream) = TokenStream::channel();
let count = self.generation_count.fetch_add(1, Ordering::SeqCst);
let prompt_tokens = self.tokenizer.encode(prompt)?;
let max_tokens = params.max_tokens;
// Pre-generate all tokens (deterministic, so we can compute them ahead of time)
let mut context = prompt_tokens;
let mut tokens_to_send = Vec::new();
let start = Instant::now();
for i in 0..max_tokens {
let token = self.deterministic_token(&context, count + i);
let text = self.tokenizer.decode(&[token]).unwrap_or_default();
let is_eos = token == 2;
tokens_to_send.push((token, text, is_eos));
if is_eos {
break;
}
context.push(token);
}
let token_count = tokens_to_send.len();
let duration = start.elapsed();
// Spawn thread to send tokens (only uses owned data now)
std::thread::spawn(move || {
for (token, text, is_eos) in tokens_to_send {
let event = StreamEvent::Token(GeneratedToken {
id: token,
text,
logprob: Some(-0.5), // Dummy logprob
is_special: is_eos,
});
if tx.send(event).is_err() {
break;
}
}
let _ = tx.send(StreamEvent::Done {
total_tokens: token_count,
duration_ms: duration.as_millis() as u64,
tokens_per_second: token_count as f64 / duration.as_secs_f64().max(0.001),
});
});
Ok(stream)
}
fn get_embeddings(&self, text: &str) -> Result<Vec<f32>> {
// Generate deterministic embeddings
let tokens = self.tokenizer.encode(text)?;
let dim = 768; // Standard embedding dim
let mut embeddings = vec![0.0f32; dim];
for (i, &t) in tokens.iter().enumerate() {
for j in 0..dim {
let idx = (i * 100 + j) % dim;
embeddings[idx] += (t as f32 * 0.001) * ((j as f32 + 1.0).sin());
}
}
// Normalize
let norm: f32 = embeddings.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for e in &mut embeddings {
*e /= norm;
}
}
Ok(embeddings)
}
fn tokenizer(&self) -> Option<&dyn Tokenizer> {
Some(&self.tokenizer)
}
fn is_model_loaded(&self) -> bool {
self.model_loaded.load(Ordering::SeqCst)
}
fn model_info(&self) -> Option<ruvllm::backends::ModelInfo> {
if self.is_model_loaded() {
Some(ruvllm::backends::ModelInfo {
name: "MockModel-7B".to_string(),
architecture: ModelArchitecture::Llama,
num_parameters: 7_000_000_000,
vocab_size: 32000,
hidden_size: 4096,
num_layers: 32,
max_context_length: 8192,
quantization: Some(Quantization::Q4K),
memory_usage: 4_000_000_000,
})
} else {
None
}
}
fn unload_model(&mut self) {
self.model_loaded.store(false, Ordering::SeqCst);
}
}
struct MockStreamIterator<'a> {
backend: &'a MockLlmBackend,
context: Vec<u32>,
remaining: usize,
seed_offset: usize,
finished: bool,
}
impl<'a> Iterator for MockStreamIterator<'a> {
type Item = Result<GeneratedToken>;
fn next(&mut self) -> Option<Self::Item> {
if self.finished || self.remaining == 0 {
return None;
}
let token = self
.backend
.deterministic_token(&self.context, self.seed_offset);
self.seed_offset += 1;
self.remaining -= 1;
let text = self.backend.tokenizer.decode(&[token]).unwrap_or_default();
let is_eos = token == 2;
if is_eos {
self.finished = true;
}
self.context.push(token);
Some(Ok(GeneratedToken {
id: token,
text,
logprob: Some(-0.5),
is_special: is_eos,
}))
}
}
/// Create a test serving engine with mock backend
fn create_mock_serving_engine() -> (ServingEngine, Arc<MockLlmBackend>) {
let backend = Arc::new(MockLlmBackend::new());
let config = ServingEngineConfig {
kv_cache: KvCachePoolConfig {
num_slots: 8,
max_seq_len: 512,
block_size: 16,
total_blocks: 128,
num_kv_heads: 4,
head_dim: 64,
num_layers: 8,
},
max_concurrent_requests: 16,
enable_speculative: false, // Disable for basic tests
..Default::default()
};
let engine = ServingEngine::new(backend.clone() as Arc<dyn LlmBackend>, config);
(engine, backend)
}
// ============================================================================
// GGUF Loading Tests
// ============================================================================
#[test]
fn test_gguf_load_and_generate_basic() {
// Test: Load a minimal GGUF, verify parsing works, then generate tokens
let gguf_data = create_minimal_test_gguf();
// Parse GGUF header
assert!(gguf_data.len() >= 24); // Minimum header size
let magic = u32::from_le_bytes([gguf_data[0], gguf_data[1], gguf_data[2], gguf_data[3]]);
assert_eq!(magic, GGUF_MAGIC, "Magic number should match");
let version = u32::from_le_bytes([gguf_data[4], gguf_data[5], gguf_data[6], gguf_data[7]]);
assert_eq!(version, GGUF_VERSION, "Version should be 3");
// Create mock backend and generate
let mut backend = MockLlmBackend::new();
backend
.load_model("test-model", ModelConfig::default())
.unwrap();
let params = GenerateParams::default().with_max_tokens(10);
let output = backend.generate("Hello world", params).unwrap();
assert!(!output.is_empty(), "Should generate some output");
}
#[test]
fn test_gguf_load_with_metadata() {
// Test: Load GGUF with metadata, verify extraction
let gguf_data = create_test_gguf_with_metadata();
// The data should be large enough to contain metadata
assert!(gguf_data.len() > 100, "Should have metadata");
// Verify magic
let magic = u32::from_le_bytes([gguf_data[0], gguf_data[1], gguf_data[2], gguf_data[3]]);
assert_eq!(magic, GGUF_MAGIC);
// Count metadata (at offset 16)
let metadata_count = u64::from_le_bytes(gguf_data[16..24].try_into().unwrap());
assert_eq!(metadata_count, 3, "Should have 3 metadata entries");
}
#[test]
fn test_gguf_load_with_quantization() {
// Test: Verify Q4_0, Q4_K, Q8_0 quantized model metadata parsing
let gguf_data = create_test_gguf_q4_quantized();
// Parse and verify header
let magic = u32::from_le_bytes([gguf_data[0], gguf_data[1], gguf_data[2], gguf_data[3]]);
assert_eq!(magic, GGUF_MAGIC);
let tensor_count = u64::from_le_bytes(gguf_data[8..16].try_into().unwrap());
assert_eq!(tensor_count, 1, "Should have 1 quantized tensor");
// Test quantization type bytes_per_weight
assert_eq!(Quantization::Q4.bytes_per_weight(), 0.5);
assert_eq!(Quantization::Q4K.bytes_per_weight(), 0.5);
assert_eq!(Quantization::Q8.bytes_per_weight(), 1.0);
assert!(Quantization::Q4.is_gguf());
assert!(Quantization::Q4K.is_gguf());
assert!(Quantization::Q8.is_gguf());
assert!(!Quantization::F16.is_gguf());
}
// ============================================================================
// Streaming Generation Tests
// ============================================================================
#[test]
fn test_streaming_generation() {
// Test: Streaming callback generation works correctly
let mut backend = MockLlmBackend::new();
backend
.load_model("test-model", ModelConfig::default())
.unwrap();
let params = GenerateParams::default()
.with_max_tokens(20)
.with_temperature(0.7);
// Collect streaming output
let mut tokens_received = Vec::new();
let stream = backend.generate_stream("Hello world", params).unwrap();
for result in stream {
let token = result.expect("Stream should not error");
tokens_received.push(token);
}
assert!(!tokens_received.is_empty(), "Should receive tokens");
assert!(tokens_received.len() <= 20, "Should respect max_tokens");
// Verify each token has valid fields
for token in &tokens_received {
assert!(token.id > 0 || token.is_special, "Token ID should be valid");
}
}
#[test]
fn test_streaming_generation_v2() {
// Test: New TokenStream interface
let mut backend = MockLlmBackend::new();
backend
.load_model("test-model", ModelConfig::default())
.unwrap();
let params = GenerateParams::default()
.with_max_tokens(10)
.with_temperature(0.5);
let mut stream = backend.generate_stream_v2("Test prompt", params).unwrap();
let mut token_count = 0;
let mut received_done = false;
// Poll with recv_timeout so the test never blocks forever
let deadline = Instant::now() + Duration::from_secs(5);
while Instant::now() < deadline && !stream.is_finished() {
if let Some(result) = stream.recv_timeout(Duration::from_millis(100)) {
match result {
Ok(StreamEvent::Token(token)) => {
token_count += 1;
assert!(!token.text.is_empty() || token.is_special);
}
Ok(StreamEvent::Done { total_tokens, .. }) => {
received_done = true;
assert_eq!(total_tokens, token_count);
break;
}
Ok(StreamEvent::Error(e)) => {
panic!("Stream error: {}", e);
}
Err(e) => {
panic!("Result error: {:?}", e);
}
}
}
}
assert!(received_done, "Should receive Done event");
assert!(token_count > 0, "Should receive at least one token");
}
#[test]
fn test_streaming_with_callback() {
// Test: Streaming with callback in serving engine
let (engine, backend) = create_mock_serving_engine();
// Load model through backend
backend.model_loaded.store(true, Ordering::SeqCst);
let tokens_received = Arc::new(AtomicUsize::new(0));
let tokens_clone = tokens_received.clone();
let params = GenerateParams::default().with_max_tokens(5);
let request = InferenceRequest::new(vec![100, 101, 102], params);
let callback: Box<dyn Fn(TokenOutput) + Send + Sync> = Box::new(move |_output| {
tokens_clone.fetch_add(1, Ordering::Relaxed);
});
let _ = engine.submit_with_callback(request, callback);
// Run several iterations
for _ in 0..30 {
let _ = engine.run_iteration();
}
// Should have received some callbacks
let _received = tokens_received.load(Ordering::Relaxed);
// May or may not have tokens depending on timing
}
// ============================================================================
// Speculative Decoding Tests
// ============================================================================
#[test]
fn test_speculative_decoding_config() {
// Test: Speculative decoding configuration
let config = SpeculativeConfig::default();
assert!(config.lookahead >= 2, "Lookahead should be at least 2");
assert!(config.lookahead <= 16, "Lookahead should be reasonable");
assert!(config.acceptance_threshold > 0.0 && config.acceptance_threshold <= 1.0);
assert!(
config.adaptive_lookahead,
"Adaptive lookahead should be on by default"
);
}
#[test]
fn test_speculative_stats() {
// Test: Statistics tracking for speculative decoding
let mut stats = SpeculativeStats::new();
assert_eq!(stats.draft_tokens, 0);
assert_eq!(stats.accepted_tokens, 0);
assert_eq!(stats.acceptance_rate, 0.0);
// Record some speculation rounds
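    // As used here, record_round(drafted, accepted, latency) records one
    // speculation round: 4 draft tokens proposed, 3 accepted by the main
    // model (the latency argument's units are assumed to be milliseconds).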
stats.record_round(4, 3, 10.0);
assert_eq!(stats.draft_tokens, 4);
assert_eq!(stats.accepted_tokens, 3);
assert!((stats.acceptance_rate - 0.75).abs() < 0.01);
assert_eq!(stats.total_tokens_generated, 4); // 3 accepted + 1 correction
stats.record_round(4, 4, 8.0);
assert_eq!(stats.draft_tokens, 8);
assert_eq!(stats.accepted_tokens, 7);
// Reset
stats.reset();
assert_eq!(stats.draft_tokens, 0);
}
#[test]
fn test_atomic_speculative_stats() {
// Test: Thread-safe atomic statistics
let stats = AtomicSpeculativeStats::new();
// Record from multiple threads
let stats_arc = Arc::new(stats);
let mut handles = vec![];
for _ in 0..4 {
let stats_clone = stats_arc.clone();
let handle = std::thread::spawn(move || {
for _ in 0..10 {
stats_clone.record_round(4, 3, Duration::from_millis(10));
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
let snapshot = stats_arc.snapshot();
assert_eq!(snapshot.draft_tokens, 4 * 10 * 4);
assert_eq!(snapshot.accepted_tokens, 3 * 10 * 4);
assert_eq!(snapshot.main_forward_passes, 10 * 4);
}
#[test]
fn test_speculation_tree() {
// Test: Tree-based speculation structure
let mut tree = SpeculationTree::new(4, 2);
assert_eq!(tree.node_count, 1);
assert_eq!(tree.max_depth, 4);
assert_eq!(tree.branching_factor, 2);
// Add children to root
tree.root.add_child(100, 0.8);
tree.root.add_child(101, 0.6);
tree.node_count += 2;
assert_eq!(tree.root.children.len(), 2);
// Get paths
let paths = tree.get_candidate_paths();
assert_eq!(paths.len(), 2); // Two leaf paths
// Best path should be the one with higher probability
let best = tree.best_path();
assert!(
best.is_empty() || best[0] == 100,
"Best path should start with high-prob token"
);
}
#[test]
fn test_tree_node_operations() {
// Test: TreeNode building and traversal
let mut root = TreeNode::new(0, 1.0, 0);
assert_eq!(root.token, 0);
assert_eq!(root.depth, 0);
assert!(root.children.is_empty());
// Build a small tree
let child1 = root.add_child(10, 0.7);
child1.add_child(20, 0.8);
child1.add_child(21, 0.4);
let child2 = root.add_child(11, 0.5);
child2.add_child(22, 0.9);
// Get all paths
let paths = root.get_paths();
assert_eq!(paths.len(), 3); // 3 leaf nodes
// Best path should maximize probability
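    // Assuming best_path scores a path by the product of its node probabilities:
    // root(1.0) -> 10(0.7) -> 20(0.8) = 0.56 beats root(1.0) -> 11(0.5) -> 22(0.9) = 0.45.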
let best = root.best_path();
assert_eq!(best.len(), 3); // root -> child -> leaf
}
#[test]
fn test_speculative_decoding_e2e() {
// Test: Full speculative decoding pipeline (mock)
    // Load both models before sharing them: mutating through Arc::as_ptr is
    // undefined behavior, so construct each backend, load it, then wrap it in
    // an Arc.
    let mut main_backend = MockLlmBackend::new();
    main_backend
        .load_model("main", ModelConfig::default())
        .unwrap();
    let main_model = Arc::new(main_backend);
    let mut draft_backend = MockLlmBackend::new();
    draft_backend
        .load_model("draft", ModelConfig::default())
        .unwrap();
    let draft_model = Arc::new(draft_backend);
let config = SpeculativeConfig {
lookahead: 4,
acceptance_threshold: 0.5,
draft_temperature: 0.0,
tree_speculation: false,
adaptive_lookahead: true,
min_lookahead: 2,
max_lookahead: 8,
..Default::default()
};
let decoder = SpeculativeDecoder::new(main_model, draft_model, config);
// Verify configuration
let cfg = decoder.config();
assert_eq!(cfg.lookahead, 4);
// Check tokenizer availability
assert!(decoder.tokenizer().is_some());
// Get initial stats
let stats = decoder.stats();
assert_eq!(stats.draft_tokens, 0);
}
// ============================================================================
// KV Cache Tests
// ============================================================================
#[test]
fn test_kv_cache_persistence() {
// Test: Generate, cache, continue generating
let config = KvCacheConfig {
tail_length: 16,
max_tokens: 64,
num_kv_heads: 2,
head_dim: 32,
migration_batch: 8,
..Default::default()
};
let cache = TwoTierKvCache::new(config);
// Add initial context
for i in 0..10 {
let keys = vec![i as f32 * 0.1; 2 * 32];
let values = vec![i as f32 * 0.2; 2 * 32];
cache.append(&keys, &values).unwrap();
}
let stats1 = cache.stats();
assert_eq!(stats1.total_tokens, 10);
// Query with current cache (simulating continuation)
// Query size should match num_kv_heads * head_dim = 2 * 32 = 64
let query = vec![0.5f32; 2 * 32];
let scale = 1.0 / 32.0f32.sqrt();
let output1 = cache.attend(&query, scale).unwrap();
assert_eq!(output1.len(), 2 * 32);
// Add more tokens (continuation)
for i in 10..20 {
let keys = vec![i as f32 * 0.1; 2 * 32];
let values = vec![i as f32 * 0.2; 2 * 32];
cache.append(&keys, &values).unwrap();
}
let stats2 = cache.stats();
assert_eq!(stats2.total_tokens, 20);
// Query again - should now attend over more tokens
let output2 = cache.attend(&query, scale).unwrap();
assert_eq!(output2.len(), 2 * 32);
// Outputs should be different due to more context
let diff: f32 = output1
.iter()
.zip(output2.iter())
.map(|(a, b)| (a - b).abs())
.sum();
// Could be same if attention weights distribute similarly, so just check finite
assert!(diff.is_finite());
}
#[test]
fn test_kv_cache_two_tier_migration() {
// Test: Verify tail -> store migration
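    // As exercised here, the two-tier cache keeps at most `tail_length` recent
    // tokens in the fast tail and migrates older entries (in `migration_batch`
    // chunks) into the backing store; stats() reports both tiers plus the total.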
let config = KvCacheConfig {
tail_length: 4,
max_tokens: 100,
num_kv_heads: 1,
head_dim: 8,
migration_batch: 2,
..Default::default()
};
let cache = TwoTierKvCache::new(config);
// Add enough tokens to trigger migration
for i in 0..10 {
let keys = vec![i as f32; 8];
let values = vec![i as f32 * 2.0; 8];
cache.append(&keys, &values).unwrap();
}
let stats = cache.stats();
// Tail should be limited, store should have overflow
assert!(stats.tail_tokens <= 4, "Tail should respect limit");
assert!(stats.store_tokens > 0, "Store should have migrated tokens");
assert_eq!(stats.total_tokens, 10);
}
#[test]
fn test_kv_cache_concurrent_access() {
// Test: Concurrent KV cache operations
let config = KvCacheConfig {
tail_length: 32,
max_tokens: 256,
num_kv_heads: 4,
head_dim: 64,
migration_batch: 16,
..Default::default()
};
let cache = Arc::new(TwoTierKvCache::new(config));
let mut handles = vec![];
// Spawn concurrent writers
for t in 0..4 {
let cache_clone = cache.clone();
let handle = std::thread::spawn(move || {
for i in 0..25 {
let keys = vec![(t * 100 + i) as f32; 4 * 64];
let values = vec![(t * 100 + i) as f32 * 2.0; 4 * 64];
cache_clone.append(&keys, &values).unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
let stats = cache.stats();
assert_eq!(stats.total_tokens, 100); // 4 threads * 25 tokens
}
// ============================================================================
// Batch Generation Tests
// ============================================================================
#[test]
fn test_batch_generation() {
// Test: Multiple prompts processed in batch
let (engine, backend) = create_mock_serving_engine();
backend.model_loaded.store(true, Ordering::SeqCst);
// Submit multiple requests
let mut request_ids = Vec::new();
let prompts = vec![
vec![100, 101, 102], // "Hello , "
vec![105, 106, 107], // "The quick brown"
vec![114, 115, 116], // "test model output"
];
for prompt in prompts {
let params = GenerateParams::default().with_max_tokens(5);
let request = InferenceRequest::new(prompt, params);
let id = engine.submit(request).unwrap();
request_ids.push(id);
}
// Run iterations to process all
for _ in 0..50 {
let _ = engine.run_iteration();
}
// Check metrics
let stats = engine.stats();
// Should have processed requests
assert!(
stats.running_requests > 0 || stats.completed_requests > 0 || stats.pending_requests > 0,
"Should have processed some requests"
);
}
#[test]
fn test_batch_priority_ordering() {
// Test: Higher priority requests are processed first
let (engine, backend) = create_mock_serving_engine();
backend.model_loaded.store(true, Ordering::SeqCst);
// Submit low priority first
let params = GenerateParams::default().with_max_tokens(3);
let mut low_req = InferenceRequest::new(vec![100], params.clone());
low_req.priority = Priority::Low;
let _low_id = engine.submit(low_req).unwrap();
// Submit high priority second
let mut high_req = InferenceRequest::new(vec![101], params);
high_req.priority = Priority::High;
let _high_id = engine.submit(high_req).unwrap();
// Priority values
assert!(Priority::High.value() > Priority::Low.value());
assert!(Priority::Critical.value() > Priority::High.value());
}
// ============================================================================
// Stop Sequence Tests
// ============================================================================
#[test]
fn test_stop_sequences() {
// Test: Generation stops at stop sequences
let mut backend = MockLlmBackend::new();
backend.load_model("test", ModelConfig::default()).unwrap();
let params = GenerateParams::default()
.with_max_tokens(100)
.with_stop_sequence("\n\n")
.with_stop_sequence("END");
// Generate - the mock backend won't actually hit stop sequences
// but we verify the params are stored correctly
assert_eq!(params.stop_sequences.len(), 2);
assert!(params.stop_sequences.contains(&"\n\n".to_string()));
assert!(params.stop_sequences.contains(&"END".to_string()));
}
#[test]
fn test_multiple_stop_sequences() {
// Test: Multiple stop sequences configuration
let params = GenerateParams::default()
.with_stop_sequence("<|end|>")
.with_stop_sequence("</s>")
.with_stop_sequence("STOP")
.with_stop_sequence("\n---\n");
assert_eq!(params.stop_sequences.len(), 4);
// Verify each sequence is present
for seq in &["<|end|>", "</s>", "STOP", "\n---\n"] {
assert!(
params.stop_sequences.contains(&seq.to_string()),
"Should contain {}",
seq
);
}
}
// ============================================================================
// Temperature Sampling Tests
// ============================================================================
#[test]
fn test_temperature_sampling() {
// Test: Temperature affects output diversity
let mut backend = MockLlmBackend::new();
backend.load_model("test", ModelConfig::default()).unwrap();
// Low temperature (more deterministic)
let low_temp_params = GenerateParams::default()
.with_max_tokens(10)
.with_temperature(0.1);
// High temperature (more random)
let high_temp_params = GenerateParams::default()
.with_max_tokens(10)
.with_temperature(1.5);
// Our mock backend doesn't actually use temperature, but we verify params
assert!(low_temp_params.temperature < high_temp_params.temperature);
assert!(low_temp_params.temperature < 0.5);
assert!(high_temp_params.temperature > 1.0);
}
#[test]
fn test_softmax_temperature_effect() {
// Test: Verify softmax correctly concentrates/diffuses with temperature
let logits = vec![1.0f32, 2.0, 3.0, 4.0];
// Standard softmax
let probs = softmax(&logits);
let sum: f32 = probs.iter().sum();
assert!((sum - 1.0).abs() < 0.001, "Softmax should sum to 1");
// Verify ordering preserved
assert!(probs[3] > probs[2]);
assert!(probs[2] > probs[1]);
assert!(probs[1] > probs[0]);
// Test with scaled logits (simulating low temperature)
let scaled: Vec<f32> = logits.iter().map(|&x| x * 5.0).collect();
let probs_sharp = softmax(&scaled);
// Sharp distribution should have higher max probability
assert!(
probs_sharp[3] > probs[3],
"Lower temperature should concentrate probability"
);
}
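/// Sketch of explicit temperature scaling: sampling at temperature T divides
/// the logits by T before softmax, so T < 1 sharpens the distribution and
/// T > 1 flattens it. Uses only the `softmax` helper imported above; the
/// temperature values and test name are illustrative.
#[test]
fn test_softmax_explicit_temperature_scaling() {
    let logits = vec![1.0f32, 2.0, 3.0, 4.0];
    let base = softmax(&logits);
    // T = 0.5: sharper, more mass on the argmax
    let sharp = softmax(&logits.iter().map(|&x| x / 0.5).collect::<Vec<_>>());
    // T = 2.0: flatter, mass spread more evenly
    let flat = softmax(&logits.iter().map(|&x| x / 2.0).collect::<Vec<_>>());
    assert!(sharp[3] > base[3]);
    assert!(flat[3] < base[3]);
    for probs in [&sharp, &flat] {
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-3, "Scaled softmax should still sum to 1");
    }
}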
#[test]
fn test_log_softmax() {
// Test: Log softmax for numerical stability
let logits = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let log_probs = log_softmax(&logits);
// All log probs should be <= 0
for &lp in &log_probs {
assert!(lp <= 0.0, "Log probability should be <= 0");
assert!(lp.is_finite(), "Log probability should be finite");
}
// exp(log_softmax) should equal softmax
let probs_from_log: Vec<f32> = log_probs.iter().map(|&lp| lp.exp()).collect();
let probs = softmax(&logits);
for (a, b) in probs_from_log.iter().zip(probs.iter()) {
assert!(
(a - b).abs() < 0.001,
"exp(log_softmax) should equal softmax"
);
}
}
#[test]
fn test_top_k_filtering() {
// Test: Top-k sampling correctly filters
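    // As asserted below, top_k_filter keeps the k largest logits and masks the
    // rest to -infinity (non-finite), leaving exactly k finite entries.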
let mut logits = vec![1.0f32, 5.0, 3.0, 4.0, 2.0];
top_k_filter(&mut logits, 2);
// Only top 2 should remain finite
let finite_count = logits.iter().filter(|x| x.is_finite()).count();
assert_eq!(finite_count, 2, "Top-k should keep exactly k values");
// The top 2 values (5.0 and 4.0 at indices 1 and 3) should be finite
assert!(logits[1].is_finite()); // 5.0
assert!(logits[3].is_finite()); // 4.0
}
#[test]
fn test_top_p_filtering() {
// Test: Nucleus (top-p) sampling correctly filters
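    // Nucleus sampling keeps the smallest set of highest-probability tokens
    // whose cumulative probability reaches p and masks the rest; here the
    // first logit (10.0) dominates after softmax, so most entries get filtered.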
let mut logits = vec![10.0f32, 5.0, 3.0, 2.0, 1.0];
top_p_filter(&mut logits, 0.9);
// Most probability mass should be preserved
let finite_count = logits.iter().filter(|x| x.is_finite()).count();
assert!(finite_count >= 1, "Top-p should keep at least one value");
assert!(finite_count < 5, "Top-p with 0.9 should filter some values");
}
#[test]
fn test_sampling_from_probabilities() {
// Test: Sample from probability distribution
use rand::SeedableRng;
let probs = vec![0.1f32, 0.2, 0.3, 0.4];
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let mut counts = vec![0usize; 4];
// Sample many times
for _ in 0..1000 {
let idx = sample_from_probs(&probs, &mut rng);
counts[idx] += 1;
}
// Higher probability indices should be sampled more often
// With these probabilities: idx 3 (0.4) > idx 2 (0.3) > idx 1 (0.2) > idx 0 (0.1)
assert!(
counts[3] > counts[0],
"Higher prob should be sampled more: {} vs {}",
counts[3],
counts[0]
);
}
#[test]
fn test_deterministic_generation_with_seed() {
// Test: Same seed produces same output
let mut backend1 = MockLlmBackend::new();
let mut backend2 = MockLlmBackend::new();
backend1.load_model("test", ModelConfig::default()).unwrap();
backend2.load_model("test", ModelConfig::default()).unwrap();
let params = GenerateParams::default().with_max_tokens(10).with_seed(42);
let output1 = backend1.generate("Hello", params.clone()).unwrap();
let output2 = backend2.generate("Hello", params).unwrap();
// With mock backend using deterministic generation, outputs should match
assert_eq!(output1, output2, "Same seed should produce same output");
}
// ============================================================================
// Real Model Tests (Requires TEST_MODEL_PATH)
// ============================================================================
#[test]
#[ignore = "Requires GGUF model file at TEST_MODEL_PATH environment variable"]
fn test_real_model_generation() {
// Test: Load actual GGUF model and generate
let model_path =
env::var("TEST_MODEL_PATH").expect("TEST_MODEL_PATH environment variable must be set");
let path = Path::new(&model_path);
assert!(path.exists(), "Model file should exist: {}", model_path);
    // For now, just verify the file exists and can be opened
    use std::io::Read;
    let mut file = std::fs::File::open(path).expect("Should open model file");
    let metadata = file.metadata().expect("Should read metadata");
    assert!(
        metadata.len() > 1024,
        "Model file should be larger than 1KB"
    );
    // Read and verify GGUF magic
    let mut buffer = [0u8; 4];
    file.read_exact(&mut buffer).expect("Should read magic");
let magic = u32::from_le_bytes(buffer);
assert_eq!(magic, GGUF_MAGIC, "Should have valid GGUF magic");
}
#[test]
#[ignore = "Requires GGUF model file at TEST_MODEL_PATH environment variable"]
fn test_real_model_streaming() {
// Test: Stream generation from real model
let model_path =
env::var("TEST_MODEL_PATH").expect("TEST_MODEL_PATH environment variable must be set");
// Would need real model loading here
// For now, verify environment is set correctly
assert!(
!model_path.is_empty(),
"TEST_MODEL_PATH should not be empty"
);
}
#[test]
#[ignore = "Requires GGUF model file at TEST_MODEL_PATH environment variable"]
fn test_real_model_quantization() {
// Test: Load quantized model and verify inference
let _model_path =
env::var("TEST_MODEL_PATH").expect("TEST_MODEL_PATH environment variable must be set");
// Verify quantization types
assert!(Quantization::Q4K.is_gguf());
assert!(Quantization::Q8.is_gguf());
// Memory estimation for different quantizations
let param_count: f64 = 7_000_000_000.0; // 7B params
let q4k_memory = param_count * Quantization::Q4K.bytes_per_weight() as f64;
let q8_memory = param_count * Quantization::Q8.bytes_per_weight() as f64;
let f16_memory = param_count * Quantization::F16.bytes_per_weight() as f64;
assert!(q4k_memory < q8_memory);
assert!(q8_memory < f16_memory);
// ~3.5GB for Q4K, ~7GB for Q8, ~14GB for F16
assert!(q4k_memory < 5_000_000_000.0);
assert!(q8_memory < 10_000_000_000.0);
assert!(f16_memory < 20_000_000_000.0);
}
// ============================================================================
// Integration Tests - Full Pipeline
// ============================================================================
#[test]
fn test_full_pipeline_mock() {
// Test: Complete pipeline from request to completion
let (engine, backend) = create_mock_serving_engine();
backend.model_loaded.store(true, Ordering::SeqCst);
// Create and submit request
let params = GenerateParams::default()
.with_max_tokens(10)
.with_temperature(0.7)
.with_top_p(0.9);
let request = InferenceRequest::new(vec![100, 101, 102, 103, 104], params);
let request_id = engine.submit(request).unwrap();
// Process until completion or timeout
let deadline = Instant::now() + Duration::from_secs(5);
while Instant::now() < deadline {
let _ = engine.run_iteration();
if engine.is_complete(request_id) {
break;
}
std::thread::sleep(Duration::from_millis(10));
}
// Should have made progress
let stats = engine.stats();
assert!(
stats.running_requests > 0 || stats.completed_requests > 0 || stats.pending_requests > 0
);
}
#[test]
fn test_engine_metrics() {
// Test: Serving engine metrics collection
let (engine, backend) = create_mock_serving_engine();
backend.model_loaded.store(true, Ordering::SeqCst);
// Initial metrics
let metrics = engine.metrics();
assert_eq!(metrics.pending_requests, 0);
assert_eq!(metrics.running_requests, 0);
assert!(metrics.uptime_seconds >= 0.0);
// Submit some requests
for _ in 0..3 {
let params = GenerateParams::default().with_max_tokens(5);
let request = InferenceRequest::new(vec![100, 101], params);
engine.submit(request).unwrap();
}
// Run a few iterations
for _ in 0..10 {
let _ = engine.run_iteration();
}
// Check updated metrics
let metrics = engine.metrics();
// Requests may have completed by now, so check all states
assert!(
metrics.pending_requests > 0
|| metrics.running_requests > 0
|| metrics.completed_requests > 0
|| metrics.total_requests_processed > 0,
"Should have requests processed, pending, running, or completed: {:?}",
(
metrics.pending_requests,
metrics.running_requests,
metrics.completed_requests,
metrics.total_requests_processed
)
);
}
#[test]
fn test_request_cancellation() {
// Test: Request can be cancelled mid-generation
let (engine, backend) = create_mock_serving_engine();
backend.model_loaded.store(true, Ordering::SeqCst);
let params = GenerateParams::default().with_max_tokens(100);
let request = InferenceRequest::new(vec![100, 101, 102], params);
let request_id = engine.submit(request).unwrap();
// Start processing
for _ in 0..5 {
let _ = engine.run_iteration();
}
// Cancel
let cancelled = engine.cancel(request_id);
assert!(cancelled, "Should successfully cancel request");
}
#[test]
fn test_concurrent_engine_operations() {
// Test: Engine handles concurrent submissions
let (engine, backend) = create_mock_serving_engine();
backend.model_loaded.store(true, Ordering::SeqCst);
let engine = Arc::new(engine);
let mut handles = vec![];
// Spawn concurrent submitters
for i in 0..4 {
let engine_clone = engine.clone();
let handle = std::thread::spawn(move || {
let params = GenerateParams::default().with_max_tokens(5);
let request = InferenceRequest::new(vec![100 + i as u32], params);
engine_clone.submit(request)
});
handles.push(handle);
}
// All submissions should succeed
for handle in handles {
let result = handle.join().unwrap();
assert!(result.is_ok(), "Concurrent submission should succeed");
}
// Process
for _ in 0..50 {
let _ = engine.run_iteration();
}
}
// ============================================================================
// Error Handling Tests
// ============================================================================
#[test]
fn test_error_handling_unloaded_model() {
// Test: Proper error when model not loaded
let backend = MockLlmBackend::new();
// Don't load model
let params = GenerateParams::default();
let result = backend.generate("Hello", params);
assert!(result.is_err());
match result {
Err(RuvLLMError::Config(msg)) => {
assert!(msg.contains("not loaded"));
}
_ => panic!("Expected Config error for unloaded model"),
}
}
#[test]
fn test_error_handling_invalid_params() {
// Test: Handle edge case parameters
let params = GenerateParams::default()
.with_max_tokens(0) // Edge case: 0 tokens
.with_temperature(0.0); // Edge case: zero temperature (greedy)
assert_eq!(params.max_tokens, 0);
assert_eq!(params.temperature, 0.0);
// These should be handled gracefully by the backend
let mut backend = MockLlmBackend::new();
backend.load_model("test", ModelConfig::default()).unwrap();
let result = backend.generate("Hello", params);
// With max_tokens=0, should return empty or minimal output
assert!(result.is_ok());
}
#[test]
fn test_embeddings_generation() {
// Test: Embedding extraction works correctly
let mut backend = MockLlmBackend::new();
backend.load_model("test", ModelConfig::default()).unwrap();
let embeddings = backend.get_embeddings("Hello world").unwrap();
assert_eq!(embeddings.len(), 768); // Standard embedding dim
// Embeddings should be normalized
let norm: f32 = embeddings.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(norm - 1.0).abs() < 0.01,
"Embeddings should be normalized, got norm {}",
norm
);
// Different texts should produce different embeddings
let embeddings2 = backend.get_embeddings("Different text here").unwrap();
let diff: f32 = embeddings
.iter()
.zip(embeddings2.iter())
.map(|(a, b)| (a - b).abs())
.sum();
assert!(
diff > 0.1,
"Different texts should have different embeddings"
);
}