Files
wifi-densepose/crates/ruvllm/tests/mistral_backend_test.rs
ruv d803bfe2b1 Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
2026-02-28 14:39:40 -05:00

699 lines
20 KiB
Rust

#![allow(
clippy::all,
unused_imports,
unused_variables,
dead_code,
unused_mut,
unused_assignments,
non_camel_case_types,
clippy::approx_constant,
unexpected_cfgs,
unused_must_use,
unused_parens
)]
//! Integration tests for mistral-rs backend
//!
//! Tests the mistral-rs backend integration including:
//! - Backend creation and configuration
//! - PagedAttention integration
//! - X-LoRA adapter management
//! - ISQ (In-Situ Quantization) configuration
//! - Model loading and generation (requires model files)
//!
//! ## Running Tests
//!
//! ```bash
//! # Run basic tests (no model required). `--test` selects this file's test
//! # target; a bare positional argument would instead act as a test-name filter.
//! cargo test --features mistral-rs --test mistral_backend_test
//!
//! # Run all tests including model-dependent ones
//! cargo test --features mistral-rs --test mistral_backend_test -- --include-ignored
//!
//! # Run with Metal acceleration
//! cargo test --features mistral-rs-metal --test mistral_backend_test
//! ```
#![cfg(feature = "mistral-rs")]
use ruvllm::backends::mistral_backend::{
IsqConfig, IsqMethod, MistralBackend, MistralBackendConfig, PagedAttentionConfigExt,
XLoraConfig, XLoraManager, XLoraMixingMode,
};
use ruvllm::backends::{
DType, DeviceType, GenerateParams, LlmBackend, ModelArchitecture, ModelConfig, Quantization,
};
use std::path::Path;
// ============================================================================
// Backend Creation Tests
// ============================================================================
#[test]
fn test_mistral_backend_creation() {
    // A freshly constructed backend holds no model and therefore no model metadata.
    let fresh = MistralBackend::new().unwrap();
    let has_model = fresh.is_model_loaded();
    assert!(!has_model);
    assert!(fresh.model_info().is_none());
}
#[test]
fn test_mistral_backend_default() {
    // `Default` must produce an empty backend, just like `new()`.
    let via_default = MistralBackend::default();
    let loaded = via_default.is_model_loaded();
    assert!(!loaded);
}
#[test]
fn test_mistral_backend_for_metal() {
    // The Metal convenience constructor must succeed and start without a model.
    match MistralBackend::for_metal() {
        Ok(metal_backend) => assert!(!metal_backend.is_model_loaded()),
        Err(_) => panic!("MistralBackend::for_metal() returned an error"),
    }
}
#[test]
fn test_mistral_backend_for_cuda() {
    // The CUDA convenience constructor (device 0) must succeed and start without a model.
    match MistralBackend::for_cuda(0) {
        Ok(cuda_backend) => assert!(!cuda_backend.is_model_loaded()),
        Err(_) => panic!("MistralBackend::for_cuda(0) returned an error"),
    }
}
#[test]
fn test_mistral_backend_with_custom_config() {
    // Non-default sequence and batch limits are accepted by `with_config`.
    let custom = MistralBackendConfig::default()
        .with_max_seq_len(16384)
        .with_max_batch_size(64);
    let backend = MistralBackend::with_config(custom)
        .expect("with_config should accept custom sequence/batch limits");
    assert!(!backend.is_model_loaded());
}
// ============================================================================
// Configuration Builder Tests
// ============================================================================
#[test]
fn test_mistral_config_builder() {
    // Each builder call should populate its corresponding optional section.
    let cfg = MistralBackendConfig::default()
        .with_paged_attention(16, 4096)
        .with_xlora_adapters(vec!["code", "chat"])
        .with_isq(4);
    assert!(cfg.paged_attention.is_some(), "paged attention section missing");
    assert!(cfg.xlora.is_some(), "X-LoRA section missing");
    assert!(cfg.isq.is_some(), "ISQ section missing");
}
#[test]
fn test_mistral_config_paged_attention() {
    // Explicit block size / page count are stored; prefix caching is expected on
    // and the GPU memory fraction is expected to be 0.9.
    let config = MistralBackendConfig::default().with_paged_attention(32, 8192);
    let Some(paged) = config.paged_attention else {
        panic!("with_paged_attention should set paged_attention");
    };
    assert_eq!(paged.block_size, 32);
    assert_eq!(paged.max_pages, 8192);
    assert!(paged.enable_prefix_caching);
    let fraction_error = (paged.gpu_memory_fraction - 0.9).abs();
    assert!(fraction_error < f32::EPSILON);
}
#[test]
fn test_mistral_config_xlora() {
    // Every adapter name handed to the builder ends up in the X-LoRA config.
    let config = MistralBackendConfig::default().with_xlora_adapters(vec!["code", "chat", "math"]);
    let xlora = config.xlora.unwrap();
    let names = &xlora.adapter_names;
    assert_eq!(names.len(), 3);
    for expected in ["code", "chat", "math"] {
        assert!(names.contains(&expected.to_string()));
    }
}
#[test]
fn test_mistral_config_isq() {
    // `with_isq` sets the bit width and is expected to use AWQ, asymmetric,
    // per-channel quantization.
    let config = MistralBackendConfig::default().with_isq(4);
    let quant = config.isq.expect("with_isq should populate the isq section");
    assert_eq!(quant.bits, 4);
    let is_awq = matches!(quant.method, IsqMethod::AWQ);
    assert!(is_awq);
    assert!(!quant.symmetric);
    assert!(quant.per_channel);
}
#[test]
fn test_mistral_config_chained() {
    // Later builder calls must not clobber sections set earlier in the chain.
    let chained = MistralBackendConfig::default()
        .with_paged_attention(16, 4096)
        .with_xlora_adapters(vec!["adapter1", "adapter2"])
        .with_isq(8)
        .with_max_seq_len(32768)
        .with_max_batch_size(128);

    let sections_present = [
        ("paged_attention", chained.paged_attention.is_some()),
        ("xlora", chained.xlora.is_some()),
        ("isq", chained.isq.is_some()),
    ];
    for (section, present) in sections_present {
        assert!(present, "{} section missing", section);
    }

    assert_eq!(chained.max_seq_len, 32768);
    assert_eq!(chained.max_batch_size, 128);
}
#[test]
fn test_mistral_config_for_metal() {
    // The Metal preset targets Metal with F16 weights and flash attention enabled.
    let metal_cfg = MistralBackendConfig::for_metal();
    let targets_metal = matches!(metal_cfg.device, DeviceType::Metal);
    let uses_f16 = matches!(metal_cfg.dtype, DType::F16);
    assert!(targets_metal);
    assert!(uses_f16);
    assert!(metal_cfg.use_flash_attn);
}
#[test]
fn test_mistral_config_for_cuda() {
    // The CUDA preset records the requested device index, F16 weights and flash attention.
    let cuda_cfg = MistralBackendConfig::for_cuda(1);
    let DeviceType::Cuda(device_index) = cuda_cfg.device else {
        panic!("Expected CUDA device type");
    };
    assert_eq!(device_index, 1);
    assert!(matches!(cuda_cfg.dtype, DType::F16));
    assert!(cuda_cfg.use_flash_attn);
}
// ============================================================================
// PagedAttention Configuration Tests
// ============================================================================
#[test]
fn test_paged_attention_config_default() {
    // Pin every default of the PagedAttention extension config.
    let defaults = PagedAttentionConfigExt::default();
    let approx = |actual: f32, expected: f32| (actual - expected).abs() < f32::EPSILON;

    assert_eq!(defaults.block_size, 16);
    assert_eq!(defaults.max_pages, 4096);
    assert!(approx(defaults.gpu_memory_fraction, 0.9));
    assert!(defaults.enable_prefix_caching);
    assert!(approx(defaults.recomputation_threshold, 0.1));
}
#[test]
fn test_paged_attention_stats() {
    let backend = MistralBackend::new().unwrap();
    // The default config enables PagedAttention, so stats must be reported.
    let Some(stats) = backend.paged_attention_stats() else {
        panic!("default backend should expose PagedAttention stats");
    };
    assert!(stats.total_blocks > 0);
    assert_eq!(stats.active_sequences, 0);
}
#[test]
fn test_paged_attention_disabled() {
    // Clearing the paged_attention section turns the stats off entirely.
    let config = {
        let mut cfg = MistralBackendConfig::default();
        cfg.paged_attention = None;
        cfg
    };
    let backend = MistralBackend::with_config(config).unwrap();
    assert!(backend.paged_attention_stats().is_none());
}
// ============================================================================
// X-LoRA Manager Tests
// ============================================================================
#[test]
fn test_xlora_manager_creation() {
    // A new manager has loaded no adapters and performed no forward passes.
    let manager = XLoraManager::new(XLoraConfig {
        adapter_names: vec!["test".to_string()],
        top_k: 1,
        ..Default::default()
    });
    let snapshot = manager.stats();
    assert_eq!(snapshot.loaded_adapters, 0);
    assert_eq!(snapshot.forward_count, 0);
}
#[test]
fn test_xlora_manager_routing() {
    let manager = XLoraManager::new(XLoraConfig {
        adapter_names: vec!["code".to_string(), "chat".to_string()],
        top_k: 2,
        use_learned_routing: false,
        ..Default::default()
    });

    // No adapters are loaded, so routing selects nothing...
    let selected = manager.route(&[0.1, 0.2, 0.3]);
    assert!(selected.is_empty());

    // ...but the call is still counted as one forward pass.
    assert_eq!(manager.stats().forward_count, 1);
}
#[test]
fn test_xlora_config_defaults() {
    let cfg = XLoraConfig::default();

    // Nothing adapter-specific is configured by default.
    assert!(cfg.adapter_names.is_empty());
    assert!(cfg.base_adapter.is_none());
    assert!(cfg.adapter_scales.is_none());

    // Router shape and selection defaults.
    assert_eq!(cfg.router_hidden_dim, 64);
    assert_eq!(cfg.router_layers, 2);
    assert_eq!(cfg.top_k, 2);
    assert!((cfg.temperature - 1.0).abs() < f32::EPSILON);

    // Learned routing with additive mixing is the default strategy.
    assert!(cfg.use_learned_routing);
    assert!(matches!(cfg.mixing_mode, XLoraMixingMode::Additive));
}
#[test]
fn test_xlora_mixing_modes() {
    // All four mixing modes are constructible and distinguishable by pattern.
    let [additive, concat, gated, attention] = [
        XLoraMixingMode::Additive,
        XLoraMixingMode::Concatenate,
        XLoraMixingMode::Gated,
        XLoraMixingMode::Attention,
    ];
    assert!(matches!(additive, XLoraMixingMode::Additive));
    assert!(matches!(concat, XLoraMixingMode::Concatenate));
    assert!(matches!(gated, XLoraMixingMode::Gated));
    assert!(matches!(attention, XLoraMixingMode::Attention));
}
#[test]
fn test_xlora_stats_from_backend() {
    let backend = MistralBackend::with_config(
        MistralBackendConfig::default().with_xlora_adapters(vec!["code", "chat"]),
    )
    .unwrap();
    let Some(stats) = backend.xlora_stats() else {
        panic!("X-LoRA stats should exist when adapters are configured");
    };
    // Adapters are only named in the config; nothing has been loaded from disk.
    assert_eq!(stats.loaded_adapters, 0);
    assert_eq!(stats.forward_count, 0);
}
#[test]
fn test_xlora_stats_none_when_not_configured() {
    // Without an X-LoRA section the backend reports no X-LoRA stats.
    let config = {
        let mut cfg = MistralBackendConfig::default();
        cfg.xlora = None;
        cfg
    };
    let backend = MistralBackend::with_config(config).unwrap();
    assert!(backend.xlora_stats().is_none());
}
// ============================================================================
// ISQ Configuration Tests
// ============================================================================
#[test]
fn test_isq_config_defaults() {
    // Default ISQ: 4-bit AWQ, asymmetric, per-channel, 128 calibration samples.
    let defaults = IsqConfig::default();
    assert_eq!(defaults.bits, 4);
    assert_eq!(defaults.calibration_samples, 128);
    assert!(matches!(defaults.method, IsqMethod::AWQ));
    assert!(!defaults.symmetric);
    assert!(defaults.per_channel);
}
#[test]
fn test_isq_methods() {
    // Every supported ISQ method is constructible and distinguishable by pattern.
    let [awq, gptq, rtn, smooth] = [
        IsqMethod::AWQ,
        IsqMethod::GPTQ,
        IsqMethod::RTN,
        IsqMethod::SmoothQuant,
    ];
    assert!(matches!(awq, IsqMethod::AWQ));
    assert!(matches!(gptq, IsqMethod::GPTQ));
    assert!(matches!(rtn, IsqMethod::RTN));
    assert!(matches!(smooth, IsqMethod::SmoothQuant));
}
#[test]
fn test_isq_with_different_bits() {
    // The requested bit width is stored verbatim for each supported size.
    for &requested in [2, 4, 8].iter() {
        let stored = MistralBackendConfig::default()
            .with_isq(requested)
            .isq
            .unwrap()
            .bits;
        assert_eq!(stored, requested);
    }
}
// ============================================================================
// Backend Operation Tests (Without Model)
// ============================================================================
#[test]
fn test_generate_requires_loaded_model() {
    // Generation on an empty backend must fail with a "No model loaded" error.
    let backend = MistralBackend::new().unwrap();
    match backend.generate("Hello", GenerateParams::default()) {
        Ok(_) => panic!("generate should fail when no model is loaded"),
        Err(err) => assert!(err.to_string().contains("No model loaded")),
    }
}
#[test]
fn test_generate_stream_requires_loaded_model() {
    // Streaming generation is likewise rejected before a model is loaded.
    let backend = MistralBackend::new().unwrap();
    let rejected = backend
        .generate_stream("Hello", GenerateParams::default())
        .is_err();
    assert!(rejected);
}
#[test]
fn test_embeddings_require_loaded_model() {
    // Embedding extraction needs a loaded model.
    let backend = MistralBackend::new().unwrap();
    let rejected = backend.get_embeddings("Test text").is_err();
    assert!(rejected);
}
#[test]
fn test_tokenizer_none_before_load() {
    // No tokenizer is available until a model has been loaded.
    let empty_backend = MistralBackend::new().unwrap();
    let tokenizer = empty_backend.tokenizer();
    assert!(tokenizer.is_none());
}
#[test]
fn test_model_info_none_before_load() {
    // Model metadata is absent until a model has been loaded.
    let empty_backend = MistralBackend::new().unwrap();
    let info = empty_backend.model_info();
    assert!(info.is_none());
}
#[test]
fn test_unload_model_when_not_loaded() {
    // Unloading an empty backend must not panic and leaves it empty.
    let mut idle = MistralBackend::new().unwrap();
    idle.unload_model();
    let still_loaded = idle.is_model_loaded();
    assert!(!still_loaded);
}
#[test]
fn test_xlora_adapter_operations_require_config() {
    let config = {
        let mut cfg = MistralBackendConfig::default();
        cfg.xlora = None;
        cfg
    };
    let backend = MistralBackend::with_config(config).unwrap();

    // Loading an adapter is rejected when X-LoRA is not configured.
    match backend.load_xlora_adapter("test", Path::new("/nonexistent")) {
        Ok(_) => panic!("load_xlora_adapter should fail without X-LoRA configured"),
        Err(err) => assert!(err.to_string().contains("X-LoRA not configured")),
    }

    // Selecting adapters is rejected as well.
    let selection_rejected = backend.set_xlora_adapters(vec![("test", 1.0)]).is_err();
    assert!(selection_rejected);
}
#[test]
fn test_isq_requires_loaded_model() {
    // ISQ cannot be applied before a model is loaded, even when configured.
    let mut backend =
        MistralBackend::with_config(MistralBackendConfig::default().with_isq(4)).unwrap();
    match backend.apply_isq() {
        Ok(_) => panic!("apply_isq should fail without a loaded model"),
        Err(err) => assert!(err.to_string().contains("No model loaded")),
    }
}
// ============================================================================
// Model Loading Tests (Requires Model Files - Ignored by Default)
// ============================================================================
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_model_loading() {
    let mut backend = MistralBackend::for_metal().unwrap();
    let model_config = ModelConfig {
        architecture: ModelArchitecture::Mistral,
        device: DeviceType::Metal,
        ..Default::default()
    };
    // Note: Replace with actual model path for testing
    if backend
        .load_model("models/test-model.gguf", model_config)
        .is_err()
    {
        return; // Model file not available: nothing to verify.
    }
    // A successful load exposes the model, its metadata and its tokenizer.
    assert!(backend.is_model_loaded());
    assert!(backend.model_info().is_some());
    assert!(backend.tokenizer().is_some());
}
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_generation() {
    let mut backend = MistralBackend::new().unwrap();
    if backend
        .load_model("models/test-model.gguf", ModelConfig::default())
        .is_err()
    {
        return; // Skip if model not available
    }

    let params = GenerateParams {
        max_tokens: 10,
        temperature: 0.7,
        ..Default::default()
    };
    // Once the model is loaded, generation must succeed and produce some text.
    let text = backend
        .generate("Hello", params)
        .unwrap_or_else(|e| panic!("Generation failed: {}", e));
    assert!(!text.is_empty());
}
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_streaming_generation() {
    let mut backend = MistralBackend::new().unwrap();
    if backend
        .load_model("models/test-model.gguf", ModelConfig::default())
        .is_err()
    {
        return; // Skip if model not available.
    }

    // Drain the stream; a loaded model must yield at least one item.
    let tokens: Vec<_> = match backend.generate_stream("Hello", GenerateParams::default()) {
        Ok(stream) => stream.collect(),
        Err(e) => panic!("Streaming generation failed: {}", e),
    };
    assert!(!tokens.is_empty());
}
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_embeddings_extraction() {
    let mut backend = MistralBackend::new().unwrap();
    if backend
        .load_model("models/test-model.gguf", ModelConfig::default())
        .is_err()
    {
        return; // Skip if model not available.
    }

    // Embeddings must be non-empty and contain only finite values.
    let embedding = backend
        .get_embeddings("Test text for embedding")
        .unwrap_or_else(|e| panic!("Embedding extraction failed: {}", e));
    assert!(!embedding.is_empty());
    assert!(embedding.iter().all(|v| v.is_finite()));
}
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_model_unload_and_reload() {
    let mut backend = MistralBackend::new().unwrap();
    // Initial load: skip the test when the model file is not available.
    let load_result = backend.load_model("models/test-model.gguf", ModelConfig::default());
    if load_result.is_err() {
        return;
    }
    assert!(backend.is_model_loaded());
    // Unload must clear both the model and its metadata.
    backend.unload_model();
    assert!(!backend.is_model_loaded());
    assert!(backend.model_info().is_none());
    // Reload: the same file loaded successfully above, so a failure here is a real
    // regression (e.g. unload leaving stale state behind), not a missing fixture.
    // It must fail the test rather than silently pass.
    let reload_result = backend.load_model("models/test-model.gguf", ModelConfig::default());
    assert!(
        reload_result.is_ok(),
        "reloading a previously loadable model failed after unload"
    );
    assert!(backend.is_model_loaded());
    assert!(backend.model_info().is_some());
}
// ============================================================================
// Integration Tests with PagedAttention
// ============================================================================
#[test]
fn test_backend_paged_attention_integration() {
    let backend = MistralBackend::with_config(
        MistralBackendConfig::default().with_paged_attention(16, 4096),
    )
    .unwrap();
    // A configured PagedAttention pool has capacity and no active sequences yet.
    let pool = backend
        .paged_attention_stats()
        .expect("PagedAttention should be configured");
    assert!(pool.total_blocks > 0);
    assert!(pool.free_blocks > 0);
    assert_eq!(pool.active_sequences, 0);
}
#[test]
fn test_backend_xlora_integration() {
    let backend = MistralBackend::with_config(
        MistralBackendConfig::default().with_xlora_adapters(vec!["code", "math", "chat"]),
    )
    .unwrap();
    // X-LoRA is configured, but no adapter has been loaded or used yet.
    let xlora = backend.xlora_stats().expect("X-LoRA should be configured");
    assert_eq!(xlora.loaded_adapters, 0);
    assert!(xlora.adapter_usage.is_empty());
}
// ============================================================================
// Serialization Tests
// ============================================================================
#[test]
fn test_config_serialization() {
    let config = MistralBackendConfig::default()
        .with_paged_attention(32, 8192)
        .with_xlora_adapters(vec!["test"])
        .with_isq(4);
    // Test serialization roundtrip
    let json = serde_json::to_string(&config).unwrap();
    let deserialized: MistralBackendConfig = serde_json::from_str(&json).unwrap();
    assert_eq!(deserialized.max_seq_len, config.max_seq_len);
    assert_eq!(deserialized.max_batch_size, config.max_batch_size);

    // Presence alone does not prove a lossless roundtrip: check that the nested
    // sections kept the values the builder put into them.
    let pa = deserialized
        .paged_attention
        .as_ref()
        .expect("paged_attention lost in roundtrip");
    assert_eq!(pa.block_size, 32);
    assert_eq!(pa.max_pages, 8192);

    let xlora = deserialized.xlora.as_ref().expect("xlora lost in roundtrip");
    assert_eq!(xlora.adapter_names, vec!["test".to_string()]);

    let isq = deserialized.isq.as_ref().expect("isq lost in roundtrip");
    assert_eq!(isq.bits, 4);
    assert!(matches!(isq.method, IsqMethod::AWQ));
}
#[test]
fn test_paged_attention_config_serialization() {
    let config = PagedAttentionConfigExt::default();
    let json = serde_json::to_string(&config).unwrap();
    let deserialized: PagedAttentionConfigExt = serde_json::from_str(&json).unwrap();
    assert_eq!(deserialized.block_size, config.block_size);
    assert_eq!(deserialized.max_pages, config.max_pages);
    assert_eq!(
        deserialized.enable_prefix_caching,
        config.enable_prefix_caching
    );
    // The float fields must survive the roundtrip too, not only the integer/bool ones.
    assert!((deserialized.gpu_memory_fraction - config.gpu_memory_fraction).abs() < f32::EPSILON);
    assert!(
        (deserialized.recomputation_threshold - config.recomputation_threshold).abs()
            < f32::EPSILON
    );
}
#[test]
fn test_xlora_config_serialization() {
    let config = XLoraConfig {
        adapter_names: vec!["a".to_string(), "b".to_string()],
        top_k: 3,
        temperature: 0.5,
        ..Default::default()
    };
    let json = serde_json::to_string(&config).unwrap();
    let deserialized: XLoraConfig = serde_json::from_str(&json).unwrap();

    // Compare the adapter names themselves (content and order), not just the count.
    assert_eq!(deserialized.adapter_names, config.adapter_names);
    assert_eq!(deserialized.top_k, 3);
    assert!((deserialized.temperature - 0.5).abs() < f32::EPSILON);

    // Fields left at their defaults must survive the roundtrip as well.
    assert!(deserialized.base_adapter.is_none());
    assert!(deserialized.adapter_scales.is_none());
    assert_eq!(deserialized.router_hidden_dim, config.router_hidden_dim);
    assert_eq!(deserialized.router_layers, config.router_layers);
    assert_eq!(deserialized.use_learned_routing, config.use_learned_routing);
    assert!(matches!(deserialized.mixing_mode, XLoraMixingMode::Additive));
}
#[test]
fn test_isq_config_serialization() {
    // Use non-default values everywhere so a roundtrip cannot pass by accident.
    let original = IsqConfig {
        bits: 8,
        method: IsqMethod::GPTQ,
        symmetric: true,
        per_channel: false,
        calibration_samples: 256,
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: IsqConfig = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded.bits, 8);
    assert_eq!(decoded.calibration_samples, 256);
    assert!(matches!(decoded.method, IsqMethod::GPTQ));
    assert!(decoded.symmetric);
    assert!(!decoded.per_channel);
}
// ============================================================================
// Edge Case Tests
// ============================================================================
#[test]
fn test_empty_xlora_adapters() {
    // An empty adapter list still creates an (empty) X-LoRA section.
    let section = MistralBackendConfig::default()
        .with_xlora_adapters(Vec::new())
        .xlora
        .expect("with_xlora_adapters should populate the xlora section");
    assert!(section.adapter_names.is_empty());
}
#[test]
fn test_large_page_config() {
    // Large block sizes and page counts are stored without clamping.
    let paged = MistralBackendConfig::default()
        .with_paged_attention(256, 65536)
        .paged_attention
        .expect("with_paged_attention should populate paged_attention");
    assert_eq!(paged.block_size, 256);
    assert_eq!(paged.max_pages, 65536);
}
#[test]
fn test_multiple_backend_instances() {
    // Backends built through different constructors can coexist, each starting empty.
    let backends = [
        MistralBackend::new().unwrap(),
        MistralBackend::for_metal().unwrap(),
        MistralBackend::for_cuda(0).unwrap(),
    ];
    for backend in &backends {
        assert!(!backend.is_model_loaded());
    }
}
#[test]
fn test_generate_params_integration() {
    // Every sampling field can be set explicitly and is read back unchanged.
    let stop_sequences: Vec<String> = ["STOP", "\n\n"].iter().map(|s| s.to_string()).collect();
    let params = GenerateParams {
        max_tokens: 256,
        temperature: 0.8,
        top_p: 0.95,
        top_k: 50,
        repetition_penalty: 1.1,
        frequency_penalty: 0.1,
        presence_penalty: 0.1,
        stop_sequences,
        seed: Some(12345),
    };
    assert_eq!(params.max_tokens, 256);
    let temperature_error = (params.temperature - 0.8).abs();
    assert!(temperature_error < f32::EPSILON);
    assert_eq!(params.stop_sequences.len(), 2);
    assert_eq!(params.seed, Some(12345));
}