#![allow(
    clippy::all,
    unused_imports,
    unused_variables,
    dead_code,
    unused_mut,
    unused_assignments,
    non_camel_case_types,
    clippy::approx_constant,
    unexpected_cfgs,
    unused_must_use,
    unused_parens
)]

//! End-to-end integration tests for RuvLLM
//!
//! Tests the complete inference pipeline including model loading,
//! session management, KV cache, paged attention, and policy/witness stores.

use chrono::Utc;
use ruvllm::{
    backends::{DType, DeviceType, GenerateParams, ModelArchitecture, ModelConfig, Quantization},
    error::Result,
    kv_cache::{KvCacheConfig, TwoTierKvCache},
    lora::{AdaptFeedback, MicroLoRA, MicroLoraConfig, TargetModule},
    paged_attention::{PagedAttention, PagedAttentionConfig},
    policy_store::{PolicyEntry, PolicySource, PolicyStore, PolicyType, QuantizationPolicy},
    session::{SessionConfig, SessionManager},
    sona::{LearningLoop, SonaConfig, SonaIntegration, Trajectory},
    types::ModelSize,
    witness_log::{LatencyBreakdown, RoutingDecision, WitnessEntry, WitnessLog},
    RuvLLMConfig, RuvLLMEngine,
};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tempfile::TempDir;
use uuid::Uuid;
/// Create a temporary directory for test storage
|
|
fn create_test_dir() -> TempDir {
|
|
tempfile::tempdir().expect("Failed to create temp dir")
|
|
}
|
|
|
|
/// Create a test RuvLLM configuration
|
|
fn create_test_config(storage_path: &str) -> RuvLLMConfig {
|
|
RuvLLMConfig {
|
|
storage_path: storage_path.to_string(),
|
|
paged_attention: PagedAttentionConfig {
|
|
page_size: 16,
|
|
page_table_capacity: 64,
|
|
num_kv_heads: 4,
|
|
head_dim: 32,
|
|
..Default::default()
|
|
},
|
|
kv_cache: KvCacheConfig {
|
|
tail_length: 32,
|
|
max_tokens: 256,
|
|
num_kv_heads: 4,
|
|
head_dim: 32,
|
|
..Default::default()
|
|
},
|
|
session: SessionConfig::default(),
|
|
sona: SonaConfig::default(),
|
|
max_sessions: 100,
|
|
embedding_dim: 768, // Must match SessionState::from_session default
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
#[ignore] // Requires model download
|
|
async fn test_full_inference_pipeline() {
|
|
// This test would require an actual model
|
|
// let temp_dir = create_test_dir();
|
|
// let config = create_test_config(temp_dir.path().to_str().unwrap());
|
|
// let engine = RuvLLMEngine::new(config).unwrap();
|
|
|
|
// Steps:
|
|
// 1. Load model
|
|
// 2. Create session
|
|
// 3. Generate initial response
|
|
// 4. Apply adaptation based on feedback
|
|
// 5. Generate again (should be different/improved)
|
|
// 6. Verify learning metrics
|
|
}
|
|
|
|
/// Constructing the engine from a minimal test config must succeed.
#[test]
fn test_engine_creation() {
    let dir = create_test_dir();
    let result = RuvLLMEngine::new(create_test_config(dir.path().to_str().unwrap()));
    assert!(result.is_ok(), "Engine creation failed: {:?}", result.err());
}
/// A session created through the engine must be retrievable by id.
#[test]
fn test_session_creation_and_retrieval() {
    let dir = create_test_dir();
    let engine = RuvLLMEngine::new(create_test_config(dir.path().to_str().unwrap())).unwrap();

    // Creation yields a non-empty identifier.
    let created = engine.create_session(Some("user-123")).unwrap();
    assert!(!created.id.is_empty());

    // Lookup by that identifier must hit.
    let fetched = engine.get_session(&created.id).unwrap();
    assert!(fetched.is_some());

    // And return the same session id.
    assert_eq!(fetched.unwrap().id, created.id);
}
/// The engine must track many independent sessions simultaneously.
#[test]
fn test_multiple_sessions() {
    let dir = create_test_dir();
    let engine = RuvLLMEngine::new(create_test_config(dir.path().to_str().unwrap())).unwrap();

    // Open ten sessions, remembering each id.
    let ids: Vec<String> = (0..10)
        .map(|i| {
            engine
                .create_session(Some(&format!("user-{}", i)))
                .unwrap()
                .id
        })
        .collect();

    // Every one of them must still resolve.
    for id in &ids {
        assert!(engine.get_session(id).unwrap().is_some());
    }
}
/// Appending past `max_tokens` must trigger eviction back under the cap.
#[test]
fn test_kv_cache_eviction() {
    // Tiny cache: 2 heads x 8 dims, capped at 10 tokens.
    let cache = TwoTierKvCache::new(KvCacheConfig {
        tail_length: 4,
        max_tokens: 10,
        num_kv_heads: 2,
        head_dim: 8,
        migration_batch: 2,
        ..Default::default()
    });

    // Write twice the capacity.
    for i in 0..20 {
        let k = vec![i as f32; 2 * 8]; // num_kv_heads * head_dim
        let v = vec![i as f32 * 2.0; 2 * 8];
        cache.append(&k, &v).unwrap();
    }

    // Eviction must have kept the population within the cap.
    let stats = cache.stats();
    assert!(
        stats.total_tokens <= 10,
        "Should evict to stay under max: {}",
        stats.total_tokens
    );
}
/// Overflow from the hot tail must migrate into the backing store.
#[test]
fn test_kv_cache_two_tier_storage() {
    let cache = TwoTierKvCache::new(KvCacheConfig {
        tail_length: 4,
        max_tokens: 100,
        num_kv_heads: 2,
        head_dim: 8,
        migration_batch: 2,
        ..Default::default()
    });

    // Ten appends against a 4-token tail forces migration.
    for i in 0..10 {
        let k = vec![i as f32; 2 * 8];
        let v = vec![i as f32 * 2.0; 2 * 8];
        cache.append(&k, &v).unwrap();
    }

    let stats = cache.stats();

    // Max is 100, so nothing was evicted: all ten tokens remain...
    assert_eq!(stats.total_tokens, 10);
    // ...split between a bounded tail and the overflow store.
    assert!(
        stats.tail_tokens <= 4,
        "Tail should be limited: {}",
        stats.tail_tokens
    );
    assert!(
        stats.store_tokens >= 6,
        "Store should have overflow: {}",
        stats.store_tokens
    );
}
/// Attention over cached KV pairs must yield a finite head_dim vector.
#[test]
fn test_kv_cache_attention() {
    let cache = TwoTierKvCache::new(KvCacheConfig {
        tail_length: 8,
        max_tokens: 32,
        num_kv_heads: 1,
        head_dim: 16,
        migration_batch: 4,
        ..Default::default()
    });

    // Populate five positions with deterministic ramps.
    for pos in 0..5 {
        let keys: Vec<f32> = (0..16).map(|d| (pos * 16 + d) as f32 * 0.1).collect();
        let values: Vec<f32> = (0..16).map(|d| (pos * 16 + d) as f32 * 0.2).collect();
        cache.append(&keys, &values).unwrap();
    }

    // Scaled dot-product query with scale = 1/sqrt(head_dim).
    let query: Vec<f32> = (0..16).map(|d| d as f32 * 0.05).collect();
    let scale = 1.0 / 16.0f32.sqrt();
    let output = cache.attend(&query, scale).unwrap();

    // One head_dim-sized vector, all entries finite.
    assert_eq!(output.len(), 16);
    assert!(output.iter().all(|&x| x.is_finite()));
}
/// Allocating and freeing a sequence must be reflected in the stats.
#[test]
fn test_paged_attention_basic() {
    let paged = PagedAttention::new(PagedAttentionConfig {
        page_size: 4,
        page_table_capacity: 16,
        num_kv_heads: 2,
        head_dim: 16,
        ..Default::default()
    });

    // Nothing allocated yet.
    assert_eq!(paged.stats().active_sequences, 0);

    // Reserve room for 8 tokens under one sequence id.
    let seq_id = "seq-1";
    paged.allocate_sequence(seq_id, 8).unwrap();
    assert_eq!(paged.stats().active_sequences, 1);

    // Releasing the sequence returns the count to zero.
    paged.free_sequence(seq_id).unwrap();
    assert_eq!(paged.stats().active_sequences, 0);
}
/// The cache must accept appends from several OS threads at once.
#[test]
fn test_concurrent_kv_cache_access() {
    use std::sync::Arc;
    use std::thread;

    let cache = Arc::new(TwoTierKvCache::new(KvCacheConfig {
        tail_length: 64,
        max_tokens: 256,
        num_kv_heads: 4,
        head_dim: 32,
        migration_batch: 16,
        ..Default::default()
    }));

    // Four writer threads, ten appends each.
    let workers: Vec<_> = (0..4)
        .map(|t| {
            let shared = Arc::clone(&cache);
            thread::spawn(move || {
                for i in 0..10 {
                    let k = vec![(t * 100 + i) as f32; 4 * 32];
                    let v = vec![(t * 100 + i) as f32 * 2.0; 4 * 32];
                    shared.append(&k, &v).unwrap();
                }
            })
        })
        .collect();

    // Join everything; a panic in any worker fails the test here.
    for worker in workers {
        worker.join().unwrap();
    }

    // Writes landed without deadlock or total loss.
    assert!(cache.stats().total_tokens > 0);
}
#[tokio::test]
|
|
async fn test_concurrent_requests() {
|
|
let temp_dir = create_test_dir();
|
|
let config = create_test_config(temp_dir.path().to_str().unwrap());
|
|
let engine = Arc::new(RuvLLMEngine::new(config).unwrap());
|
|
|
|
let mut handles = vec![];
|
|
|
|
// Spawn concurrent session creators
|
|
for i in 0..10 {
|
|
let engine_clone = Arc::clone(&engine);
|
|
let handle = tokio::spawn(async move {
|
|
let session = engine_clone.create_session(Some(&format!("concurrent-user-{}", i)));
|
|
session.is_ok()
|
|
});
|
|
handles.push(handle);
|
|
}
|
|
|
|
// All should succeed
|
|
for handle in handles {
|
|
assert!(handle.await.unwrap());
|
|
}
|
|
}
|
|
|
|
/// A stored policy must be findable through embedding search.
#[test]
fn test_policy_store() {
    let dir = create_test_dir();
    let path = format!("{}/policies", dir.path().to_str().unwrap());
    let store = PolicyStore::new(&path, 64).unwrap();

    // One quantization policy with a uniform 64-dim embedding.
    let entry = PolicyEntry {
        id: Uuid::new_v4(),
        policy_type: PolicyType::Quantization,
        embedding: vec![0.1; 64],
        parameters: serde_json::json!({
            "precision": "q4_k",
            "quality_threshold": 0.9,
        }),
        confidence: 0.85,
        fisher_diagonal: None,
        created_at: Utc::now(),
        last_accessed: Utc::now(),
        source: PolicySource::InstantLoop,
        tags: vec!["quantization".to_string()],
    };
    store.store(entry).unwrap();

    // Searching with the identical embedding must return at least one hit.
    let hits = store.search(&vec![0.1; 64], 5).unwrap();
    assert!(!hits.is_empty());
}
/// Witness entries can be recorded and flushed without error.
///
/// Search results are deliberately not asserted on: the vector index is
/// populated by an async write-back, so an empty result set immediately
/// after `flush` is legal. The unused binding is named `_results` to
/// make that intent explicit rather than relying on the file-wide
/// `allow(unused_variables)`.
#[test]
fn test_witness_log() {
    let temp_dir = create_test_dir();
    let storage_path = format!("{}/witness", temp_dir.path().to_str().unwrap());

    let log = WitnessLog::new(&storage_path, 64).unwrap();

    // Record five entries spread across two session ids.
    for i in 0..5 {
        let routing_decision = RoutingDecision {
            model: ModelSize::Small,
            context_size: 512,
            temperature: 0.7,
            top_p: 0.9,
            confidence: 0.8 + (i as f32 * 0.02),
            model_probs: [0.1, 0.4, 0.3, 0.2],
        };

        let entry = WitnessEntry::new(
            format!("session-{}", i % 2),
            vec![i as f32 * 0.1; 64],
            routing_decision,
        )
        .with_quality(0.85)
        .with_latency(LatencyBreakdown {
            embedding_ms: 5.0,
            retrieval_ms: 2.0,
            routing_ms: 1.0,
            attention_ms: 30.0,
            generation_ms: 62.0,
            total_ms: 100.0 + (i as f32 * 10.0),
        });

        log.record(entry).unwrap();
    }

    // Flush so entries are durably written before searching.
    log.flush().unwrap();

    // Search must not error; an empty result is acceptable because the
    // vector indexing happens via async write-back.
    let query = vec![0.2; 64];
    let _results = log.search(&query, 3).unwrap();
}
/// Twenty adapt/apply rounds should move the LoRA output (unless the
/// initial output was identically near-zero), and the synthetic quality
/// signal must trend upward.
#[test]
fn test_end_to_end_adaptation_flow() {
    let lora = MicroLoRA::new(MicroLoraConfig {
        rank: 2,
        alpha: 4.0,
        dropout: 0.0,
        target_modules: vec![TargetModule::QProj],
        in_features: 64,
        out_features: 64,
        use_bias: false,
        standard_init: true,
        gradient_checkpointing: false,
    });
    let _sona = SonaIntegration::new(SonaConfig::default());

    let input: Vec<f32> = (0..64).map(|i| (i as f32) * 0.01).collect();

    // Snapshot the output before any adaptation.
    let before = lora.forward(&input, &TargetModule::QProj);

    // Drive twenty rounds of feedback with a steadily rising quality.
    let mut quality_history = Vec::new();
    for step in 0..20 {
        // Forward pass (output itself is not needed here).
        let _output = lora.forward(&input, &TargetModule::QProj);

        // Simulated quality climbs linearly with the step count.
        let quality = 0.2 + (step as f32 * 0.03);
        quality_history.push(quality);

        // Feed back and fold the accumulated update into the weights.
        let feedback = AdaptFeedback::from_quality(quality);
        lora.adapt(&input, feedback).unwrap();
        lora.apply_updates(0.01);
    }

    // Output after all adaptation rounds.
    let after = lora.forward(&input, &TargetModule::QProj);

    // Either the output moved, or it started out (near) zero and the
    // adapter had nothing measurable to change.
    let moved = before
        .iter()
        .zip(after.iter())
        .any(|(x, y)| (x - y).abs() > 1e-6);
    let started_zero = before.iter().all(|&x| x.abs() < 1e-6);
    assert!(moved || started_zero);

    // The synthetic quality signal is monotone by construction:
    // compare the mean of the first five vs the last five samples.
    let early: f32 = quality_history[..5].iter().sum::<f32>() / 5.0;
    let late: f32 = quality_history[15..].iter().sum::<f32>() / 5.0;
    assert!(
        late > early,
        "Quality should increase: {} vs {}",
        late,
        early
    );
}
/// Create -> get -> terminate -> get(None) lifecycle on the manager.
#[test]
fn test_session_lifecycle() {
    let manager = SessionManager::new(SessionConfig::default());

    // A fresh session is immediately visible by id.
    let id = manager.create_session(Some("user-1")).unwrap().id;
    assert!(manager.get_session(&id).unwrap().is_some());

    // After termination the id must no longer resolve.
    manager.terminate_session(&id).unwrap();
    assert!(manager.get_session(&id).unwrap().is_none());
}
/// A `LatencyBreakdown` built from a measured duration is well-formed.
#[test]
fn test_latency_measurement() {
    let started = Instant::now();

    // Burn a little CPU so elapsed time is non-trivial; the accumulator
    // is asserted on below so the loop cannot be optimized away.
    let mut acc = 0.0f32;
    for n in 0..10000 {
        acc += (n as f32).sqrt();
    }

    // Single elapsed reading, reused for every component.
    let secs = started.elapsed().as_secs_f32();

    // Split the measured wall time into fixed proportions of total_ms.
    let breakdown = LatencyBreakdown {
        embedding_ms: secs * 100.0,  // 10%
        retrieval_ms: secs * 50.0,   // 5%
        routing_ms: secs * 50.0,     // 5%
        attention_ms: secs * 300.0,  // 30%
        generation_ms: secs * 500.0, // 50%
        total_ms: secs * 1000.0,
    };

    assert!(breakdown.total_ms >= 0.0);
    assert!(acc > 0.0); // keeps the busy-loop observable
}
/// Representative configs across architectures, devices, and dtypes.
#[test]
fn test_model_config_variants() {
    // CPU / F32 / unquantized.
    let cpu_llama = ModelConfig {
        architecture: ModelArchitecture::Llama,
        device: DeviceType::Cpu,
        dtype: DType::F32,
        quantization: None,
        use_flash_attention: false,
        max_sequence_length: 2048,
        ..Default::default()
    };
    // Metal / F16 / Q4 with flash attention.
    let metal_mistral = ModelConfig {
        architecture: ModelArchitecture::Mistral,
        device: DeviceType::Metal,
        dtype: DType::F16,
        quantization: Some(Quantization::Q4),
        use_flash_attention: true,
        max_sequence_length: 4096,
        ..Default::default()
    };
    // CUDA device 0 / BF16 / Q8 with flash attention.
    let cuda_phi = ModelConfig {
        architecture: ModelArchitecture::Phi,
        device: DeviceType::Cuda(0),
        dtype: DType::Bf16,
        quantization: Some(Quantization::Q8),
        use_flash_attention: true,
        max_sequence_length: 8192,
        ..Default::default()
    };

    // Every variant must carry a positive sequence budget.
    for config in [cpu_llama, metal_mistral, cuda_phi] {
        assert!(config.max_sequence_length > 0);
    }
}
/// Struct-literal construction of `GenerateParams` keeps every field.
#[test]
fn test_generate_params_customization() {
    let params = GenerateParams {
        max_tokens: 256,
        temperature: 0.8,
        top_p: 0.95,
        top_k: 50,
        repetition_penalty: 1.2,
        frequency_penalty: 0.0,
        presence_penalty: 0.0,
        stop_sequences: vec!["<|end|>".to_string(), "\n\n".to_string()],
        seed: Some(12345),
    };

    // Spot-check the fields most likely to be dropped or mis-mapped.
    assert_eq!(params.max_tokens, 256);
    assert_eq!(params.stop_sequences.len(), 2);
    assert!(params.seed.is_some());
}
/// The `with_*` builder methods must each set their field.
#[test]
fn test_generate_params_builder() {
    // Start from defaults and layer on each customization in turn.
    let base = GenerateParams::default();
    let params = base
        .with_max_tokens(512)
        .with_temperature(0.5)
        .with_top_p(0.95)
        .with_top_k(50)
        .with_repetition_penalty(1.2)
        .with_seed(42);

    // Each builder call must be reflected in the final struct.
    assert_eq!(params.max_tokens, 512);
    assert_eq!(params.temperature, 0.5);
    assert_eq!(params.top_p, 0.95);
    assert_eq!(params.top_k, 50);
    assert_eq!(params.repetition_penalty, 1.2);
    assert_eq!(params.seed, Some(42));
}
/// Routing decisions across model sizes keep confidence within [0, 1].
#[test]
fn test_routing_decision() {
    let decisions = vec![
        RoutingDecision {
            model: ModelSize::Large,
            context_size: 1024,
            temperature: 0.7,
            top_p: 0.9,
            confidence: 0.95,
            model_probs: [0.05, 0.1, 0.25, 0.6],
        },
        RoutingDecision {
            model: ModelSize::Medium,
            context_size: 512,
            temperature: 0.8,
            top_p: 0.95,
            confidence: 0.88,
            model_probs: [0.1, 0.2, 0.5, 0.2],
        },
        RoutingDecision {
            model: ModelSize::Small,
            context_size: 256,
            temperature: 0.6,
            top_p: 0.9,
            confidence: 0.6,
            model_probs: [0.2, 0.5, 0.2, 0.1],
        },
    ];

    for decision in decisions {
        // Idiomatic range check (clippy::manual_range_contains), with a
        // diagnostic so a failure names the offending value.
        assert!(
            (0.0..=1.0).contains(&decision.confidence),
            "confidence out of range: {}",
            decision.confidence
        );
    }
}
/// Looking up an unknown session is `Ok(None)`, not an error.
#[test]
fn test_error_handling() {
    let dir = create_test_dir();
    let engine = RuvLLMEngine::new(create_test_config(dir.path().to_str().unwrap())).unwrap();

    // A missing session is not an error condition...
    let lookup = engine.get_session("non-existent-session-id");
    assert!(lookup.is_ok());
    // ...it simply resolves to None.
    assert!(lookup.unwrap().is_none());
}
/// Store-tier tokens should not cost more bytes than tail-tier tokens.
#[test]
fn test_memory_efficiency() {
    let cache = TwoTierKvCache::new(KvCacheConfig {
        tail_length: 32,
        max_tokens: 128,
        num_kv_heads: 4,
        head_dim: 64,
        migration_batch: 16,
        ..Default::default()
    });

    // Push 100 identical tokens through the cache to fill both tiers.
    for _ in 0..100 {
        let k = vec![1.0; 4 * 64];
        let v = vec![2.0; 4 * 64];
        cache.append(&k, &v).unwrap();
    }

    let stats = cache.stats();

    // The comparison is only meaningful when both tiers are populated.
    if stats.store_tokens > 0 && stats.tail_tokens > 0 {
        let tail_cost = stats.tail_bytes as f32 / stats.tail_tokens as f32;
        let store_cost = stats.store_bytes as f32 / stats.store_tokens as f32;

        // 10% slack tolerates per-token bookkeeping in a store that may
        // or may not actually quantize.
        assert!(
            store_cost <= tail_cost * 1.1,
            "Store should be more memory efficient: {} vs {} bytes/token",
            store_cost,
            tail_cost
        );
    }
}
/// Recording a trajectory must bump the SONA trajectory counter.
#[test]
fn test_sona_integration_basic() {
    let sona = SonaIntegration::new(SonaConfig {
        embedding_dim: 256,
        ..Default::default()
    });

    // One synthetic trajectory whose embeddings match the configured dim.
    let trajectory = Trajectory {
        request_id: "req-1".to_string(),
        session_id: "test-session".to_string(),
        query_embedding: vec![0.1; 256],
        response_embedding: vec![0.2; 256],
        quality_score: 0.8,
        routing_features: vec![0.7, 0.9, 0.5, 0.5],
        model_index: 1,
        timestamp: Utc::now(),
    };
    sona.record_trajectory(trajectory).unwrap();

    // The stats counter must reflect the recorded trajectory.
    assert!(sona.stats().total_trajectories >= 1);
}
/// The three SONA learning-loop tiers exist and are distinct variants.
///
/// The previous version only constructed the variants and looped over
/// them without asserting anything; comparing `mem::discriminant`
/// pairwise actually verifies they are three separate variants, and
/// works even if `LearningLoop` does not implement `PartialEq`.
#[test]
fn test_sona_learning_loops() {
    use std::mem::discriminant;

    let loops = [
        LearningLoop::Instant,
        LearningLoop::Background,
        LearningLoop::Deep,
    ];

    // Every pair must have a different discriminant.
    for (i, a) in loops.iter().enumerate() {
        for b in loops.iter().skip(i + 1) {
            assert_ne!(discriminant(a), discriminant(b));
        }
    }
}
/// GGUF classification and bytes-per-weight across quantization modes.
#[test]
fn test_quantization_variants() {
    // Q4 / Q8 / Q4K are GGUF formats; F16 is not.
    assert!(Quantization::Q4.is_gguf());
    assert!(Quantization::Q8.is_gguf());
    assert!(Quantization::Q4K.is_gguf());
    assert!(!Quantization::F16.is_gguf());

    // Bytes per weight shrink as precision drops.
    assert_eq!(Quantization::None.bytes_per_weight(), 4.0);
    assert_eq!(Quantization::F16.bytes_per_weight(), 2.0);
    assert_eq!(Quantization::Q8.bytes_per_weight(), 1.0);
    assert_eq!(Quantization::Q4K.bytes_per_weight(), 0.5);
}
/// `DeviceType` variants pattern-match as expected.
#[test]
fn test_device_type_variants() {
    assert!(matches!(DeviceType::Cpu, DeviceType::Cpu));
    assert!(matches!(DeviceType::Metal, DeviceType::Metal));
    // Cuda carries the device ordinal as payload.
    assert!(matches!(DeviceType::Cuda(0), DeviceType::Cuda(idx) if idx == 0));
}
/// Each architecture maps to its expected config-name string.
#[test]
fn test_model_architecture_variants() {
    // Table-driven: (variant, expected config name).
    let cases = [
        (ModelArchitecture::Llama, "llama"),
        (ModelArchitecture::Mistral, "mistral"),
        (ModelArchitecture::Phi, "phi"),
        (ModelArchitecture::Qwen, "qwen2"),
        (ModelArchitecture::Gemma, "gemma"),
    ];

    for (arch, expected) in cases {
        assert_eq!(arch.config_name(), expected);
    }
}
/// `DType` variants construct and match correctly.
#[test]
fn test_dtype_variants() {
    assert!(matches!(DType::F32, DType::F32));
    assert!(matches!(DType::F16, DType::F16));
    assert!(matches!(DType::Bf16, DType::Bf16));
}