git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
699 lines
20 KiB
Rust
699 lines
20 KiB
Rust
// NOTE(review): blanket lint suppression for the whole test crate. The list is
// unchanged; narrowing it (especially `clippy::all`) would let lint regressions
// in these tests surface.
#![allow(
    // rustc lints
    dead_code,
    non_camel_case_types,
    unexpected_cfgs,
    unused_assignments,
    unused_imports,
    unused_must_use,
    unused_mut,
    unused_parens,
    unused_variables,
    // clippy lints
    clippy::all,
    clippy::approx_constant
)]
|
|
//! Integration tests for mistral-rs backend
|
|
//!
|
|
//! Tests the mistral-rs backend integration including:
|
|
//! - Backend creation and configuration
|
|
//! - PagedAttention integration
|
|
//! - X-LoRA adapter management
|
|
//! - ISQ (In-Situ Quantization) configuration
|
|
//! - Model loading and generation (requires a model at `models/test-model.gguf`; these tests are `#[ignore]`d by default)
|
|
//!
|
|
//! ## Running Tests
|
|
//!
|
|
//! ```bash
|
|
//! # Run basic tests (no model required)
|
|
//! cargo test --features mistral-rs mistral_backend
|
|
//!
|
|
//! # Run all tests including model-dependent ones
|
|
//! cargo test --features mistral-rs mistral_backend -- --include-ignored
|
|
//!
|
|
//! # Run with Metal acceleration
|
|
//! cargo test --features mistral-rs-metal mistral_backend
|
|
//! ```
|
|
|
|
// Compile this test crate only when the `mistral-rs` feature is enabled.
#![cfg(feature = "mistral-rs")]
|
|
|
|
use ruvllm::backends::mistral_backend::{
|
|
IsqConfig, IsqMethod, MistralBackend, MistralBackendConfig, PagedAttentionConfigExt,
|
|
XLoraConfig, XLoraManager, XLoraMixingMode,
|
|
};
|
|
use ruvllm::backends::{
|
|
DType, DeviceType, GenerateParams, LlmBackend, ModelArchitecture, ModelConfig, Quantization,
|
|
};
|
|
use std::path::Path;
|
|
|
|
// ============================================================================
|
|
// Backend Creation Tests
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_mistral_backend_creation() {
    // A brand-new backend starts empty: no model, hence no model metadata.
    let fresh = MistralBackend::new().unwrap();
    let has_model = fresh.is_model_loaded();
    assert!(!has_model);
    assert!(fresh.model_info().is_none());
}
|
|
|
|
#[test]
fn test_mistral_backend_default() {
    // `Default` must yield a backend with nothing loaded.
    let fresh: MistralBackend = Default::default();
    assert!(!fresh.is_model_loaded());
}
|
|
|
|
#[test]
fn test_mistral_backend_for_metal() {
    // Constructing a Metal-targeted backend succeeds before any model is loaded.
    let outcome = MistralBackend::for_metal();
    assert!(outcome.is_ok());
    if let Ok(metal_backend) = outcome {
        assert!(!metal_backend.is_model_loaded());
    }
}
|
|
|
|
#[test]
fn test_mistral_backend_for_cuda() {
    // Constructing a CUDA-targeted backend (device 0) succeeds before any model is loaded.
    let outcome = MistralBackend::for_cuda(0);
    assert!(outcome.is_ok());
    if let Ok(cuda_backend) = outcome {
        assert!(!cuda_backend.is_model_loaded());
    }
}
|
|
|
|
#[test]
fn test_mistral_backend_with_custom_config() {
    // Custom sequence-length and batch-size limits are accepted at construction.
    let custom = MistralBackendConfig::default()
        .with_max_seq_len(16384)
        .with_max_batch_size(64);
    let backend = MistralBackend::with_config(custom).unwrap();
    assert!(!backend.is_model_loaded());
}
|
|
|
|
// ============================================================================
|
|
// Configuration Builder Tests
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_mistral_config_builder() {
    // Each builder call switches on its corresponding optional section.
    let built = MistralBackendConfig::default()
        .with_paged_attention(16, 4096)
        .with_xlora_adapters(vec!["code", "chat"])
        .with_isq(4);

    let sections_enabled = [
        built.paged_attention.is_some(),
        built.xlora.is_some(),
        built.isq.is_some(),
    ];
    assert!(sections_enabled.iter().all(|&enabled| enabled));
}
|
|
|
|
#[test]
fn test_mistral_config_paged_attention() {
    // Block size and page count are stored verbatim. The remaining fields are
    // expected to have prefix caching on and a 0.9 GPU memory fraction.
    let paged = MistralBackendConfig::default()
        .with_paged_attention(32, 8192)
        .paged_attention
        .unwrap();

    assert_eq!(paged.block_size, 32);
    assert_eq!(paged.max_pages, 8192);
    assert!(paged.enable_prefix_caching);
    let fraction_error = (paged.gpu_memory_fraction - 0.9).abs();
    assert!(fraction_error < f32::EPSILON);
}
|
|
|
|
#[test]
fn test_mistral_config_xlora() {
    // Every requested adapter name ends up in the X-LoRA section.
    let xlora = MistralBackendConfig::default()
        .with_xlora_adapters(vec!["code", "chat", "math"])
        .xlora
        .unwrap();
    let names = &xlora.adapter_names;

    assert_eq!(names.len(), 3);
    for expected in ["code", "chat", "math"] {
        assert!(names.contains(&expected.to_string()));
    }
}
|
|
|
|
#[test]
fn test_mistral_config_isq() {
    // `with_isq(bits)` records the bit width. The other settings are expected
    // to be AWQ, asymmetric, per-channel.
    let quant = MistralBackendConfig::default().with_isq(4).isq.unwrap();

    assert_eq!(quant.bits, 4);
    assert!(matches!(quant.method, IsqMethod::AWQ));
    assert!(quant.per_channel);
    assert!(!quant.symmetric);
}
|
|
|
|
#[test]
fn test_mistral_config_chained() {
    // Builder calls compose: later calls must not clobber earlier sections.
    let cfg = MistralBackendConfig::default()
        .with_paged_attention(16, 4096)
        .with_xlora_adapters(vec!["adapter1", "adapter2"])
        .with_isq(8);
    let cfg = cfg.with_max_seq_len(32768).with_max_batch_size(128);

    assert!(cfg.paged_attention.is_some());
    assert!(cfg.xlora.is_some());
    assert!(cfg.isq.is_some());
    assert_eq!(cfg.max_seq_len, 32768);
    assert_eq!(cfg.max_batch_size, 128);
}
|
|
|
|
#[test]
fn test_mistral_config_for_metal() {
    // The Metal preset targets the Metal device in F16 with flash attention on.
    let metal = MistralBackendConfig::for_metal();
    let on_metal = matches!(metal.device, DeviceType::Metal);
    let is_f16 = matches!(metal.dtype, DType::F16);

    assert!(on_metal);
    assert!(is_f16);
    assert!(metal.use_flash_attn);
}
|
|
|
|
#[test]
fn test_mistral_config_for_cuda() {
    // The CUDA preset keeps the requested device index and uses F16 + flash attention.
    let cuda = MistralBackendConfig::for_cuda(1);

    match cuda.device {
        DeviceType::Cuda(ordinal) => assert_eq!(ordinal, 1),
        _ => panic!("Expected CUDA device type"),
    }
    assert!(matches!(cuda.dtype, DType::F16));
    assert!(cuda.use_flash_attn);
}
|
|
|
|
// ============================================================================
|
|
// PagedAttention Configuration Tests
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_paged_attention_config_default() {
    // Pin every default of the standalone PagedAttention configuration.
    let defaults = PagedAttentionConfigExt::default();
    let close = |actual: f32, expected: f32| (actual - expected).abs() < f32::EPSILON;

    assert_eq!(defaults.block_size, 16);
    assert_eq!(defaults.max_pages, 4096);
    assert!(close(defaults.gpu_memory_fraction, 0.9));
    assert!(defaults.enable_prefix_caching);
    assert!(close(defaults.recomputation_threshold, 0.1));
}
|
|
|
|
#[test]
fn test_paged_attention_stats() {
    // The default config enables PagedAttention, so stats exist. They show
    // allocated blocks but no active sequences yet.
    let backend = MistralBackend::new().unwrap();
    let maybe_stats = backend.paged_attention_stats();

    assert!(maybe_stats.is_some());
    if let Some(pool) = maybe_stats {
        assert!(pool.total_blocks > 0);
        assert_eq!(pool.active_sequences, 0);
    }
}
|
|
|
|
#[test]
fn test_paged_attention_disabled() {
    // Clearing the PagedAttention section means the backend reports no stats.
    let mut without_paging = MistralBackendConfig::default();
    without_paging.paged_attention = None;

    let backend = MistralBackend::with_config(without_paging).unwrap();
    assert!(backend.paged_attention_stats().is_none());
}
|
|
|
|
// ============================================================================
|
|
// X-LoRA Manager Tests
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_xlora_manager_creation() {
    // Declaring adapter names does not load them: a new manager reports zero
    // loaded adapters and zero forward passes.
    let manager = XLoraManager::new(XLoraConfig {
        top_k: 1,
        adapter_names: vec!["test".to_string()],
        ..Default::default()
    });

    let snapshot = manager.stats();
    assert_eq!(snapshot.loaded_adapters, 0);
    assert_eq!(snapshot.forward_count, 0);
}
|
|
|
|
#[test]
fn test_xlora_manager_routing() {
    let manager = XLoraManager::new(XLoraConfig {
        adapter_names: vec!["code".to_string(), "chat".to_string()],
        top_k: 2,
        use_learned_routing: false,
        ..Default::default()
    });

    // No adapters are loaded, so routing selects nothing...
    let selected = manager.route(&[0.1, 0.2, 0.3]);
    assert!(selected.is_empty());

    // ...but the call still counts as a forward pass.
    assert_eq!(manager.stats().forward_count, 1);
}
|
|
|
|
#[test]
fn test_xlora_config_defaults() {
    // Pin every default of XLoraConfig.
    let defaults = XLoraConfig::default();

    // Out of the box: no adapters, no base adapter, no per-adapter scales.
    assert!(defaults.adapter_names.is_empty());
    assert!(defaults.base_adapter.is_none());
    assert!(defaults.adapter_scales.is_none());

    // Router sizing.
    assert_eq!(defaults.router_hidden_dim, 64);
    assert_eq!(defaults.router_layers, 2);

    // Routing behaviour.
    assert_eq!(defaults.top_k, 2);
    assert!((defaults.temperature - 1.0).abs() < f32::EPSILON);
    assert!(defaults.use_learned_routing);
    assert!(matches!(defaults.mixing_mode, XLoraMixingMode::Additive));
}
|
|
|
|
#[test]
fn test_xlora_mixing_modes() {
    // Every mixing-mode variant is constructible and matches its own pattern.
    let variant_checks = [
        matches!(XLoraMixingMode::Additive, XLoraMixingMode::Additive),
        matches!(XLoraMixingMode::Concatenate, XLoraMixingMode::Concatenate),
        matches!(XLoraMixingMode::Gated, XLoraMixingMode::Gated),
        matches!(XLoraMixingMode::Attention, XLoraMixingMode::Attention),
    ];
    assert!(variant_checks.iter().all(|&hit| hit));
}
|
|
|
|
#[test]
fn test_xlora_stats_from_backend() {
    // Configuring adapter names makes X-LoRA stats available. Nothing is loaded
    // from disk yet and no forward pass has run.
    let backend = MistralBackend::with_config(
        MistralBackendConfig::default().with_xlora_adapters(vec!["code", "chat"]),
    )
    .unwrap();

    let maybe_stats = backend.xlora_stats();
    assert!(maybe_stats.is_some());
    if let Some(report) = maybe_stats {
        assert_eq!(report.loaded_adapters, 0);
        assert_eq!(report.forward_count, 0);
    }
}
|
|
|
|
#[test]
fn test_xlora_stats_none_when_not_configured() {
    // With the X-LoRA section removed, the backend reports no X-LoRA stats.
    let mut cfg = MistralBackendConfig::default();
    cfg.xlora = None;

    let backend = MistralBackend::with_config(cfg).unwrap();
    assert!(backend.xlora_stats().is_none());
}
|
|
|
|
// ============================================================================
|
|
// ISQ Configuration Tests
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_isq_config_defaults() {
    // Pin every default of IsqConfig: 4-bit asymmetric per-channel AWQ,
    // calibrated on 128 samples.
    let defaults = IsqConfig::default();

    assert!(matches!(defaults.method, IsqMethod::AWQ));
    assert_eq!(defaults.bits, 4);
    assert!(!defaults.symmetric);
    assert!(defaults.per_channel);
    assert_eq!(defaults.calibration_samples, 128);
}
|
|
|
|
#[test]
fn test_isq_methods() {
    // Every ISQ method variant is constructible and matches its own pattern.
    let variant_checks = [
        matches!(IsqMethod::AWQ, IsqMethod::AWQ),
        matches!(IsqMethod::GPTQ, IsqMethod::GPTQ),
        matches!(IsqMethod::RTN, IsqMethod::RTN),
        matches!(IsqMethod::SmoothQuant, IsqMethod::SmoothQuant),
    ];
    assert!(variant_checks.iter().all(|&hit| hit));
}
|
|
|
|
#[test]
fn test_isq_with_different_bits() {
    // The requested bit width is stored unchanged for 2, 4 and 8 bits.
    for &bits in &[2, 4, 8] {
        let stored = MistralBackendConfig::default().with_isq(bits).isq.unwrap().bits;
        assert_eq!(stored, bits);
    }
}
|
|
|
|
// ============================================================================
|
|
// Backend Operation Tests (Without Model)
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_generate_requires_loaded_model() {
    // Generating without a model fails with a "No model loaded" error.
    let backend = MistralBackend::new().unwrap();
    let outcome = backend.generate("Hello", GenerateParams::default());
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("No model loaded"));
}
|
|
|
|
#[test]
fn test_generate_stream_requires_loaded_model() {
    // Streaming generation is also rejected while no model is loaded.
    let backend = MistralBackend::new().unwrap();
    let rejected = backend
        .generate_stream("Hello", GenerateParams::default())
        .is_err();
    assert!(rejected);
}
|
|
|
|
#[test]
fn test_embeddings_require_loaded_model() {
    // Embedding extraction is rejected while no model is loaded.
    let backend = MistralBackend::new().unwrap();
    let rejected = backend.get_embeddings("Test text").is_err();
    assert!(rejected);
}
|
|
|
|
#[test]
fn test_tokenizer_none_before_load() {
    // No tokenizer is available until a model has been loaded.
    let fresh = MistralBackend::new().unwrap();
    let tokenizer = fresh.tokenizer();
    assert!(tokenizer.is_none());
}
|
|
|
|
#[test]
fn test_model_info_none_before_load() {
    // No model metadata is available until a model has been loaded.
    let fresh = MistralBackend::new().unwrap();
    let info = fresh.model_info();
    assert!(info.is_none());
}
|
|
|
|
#[test]
fn test_unload_model_when_not_loaded() {
    let mut empty_backend = MistralBackend::new().unwrap();

    // Unloading an empty backend must not panic and leaves it unloaded.
    empty_backend.unload_model();
    assert!(!empty_backend.is_model_loaded());
}
|
|
|
|
#[test]
fn test_xlora_adapter_operations_require_config() {
    // Build a backend with the X-LoRA section explicitly disabled.
    let mut without_xlora = MistralBackendConfig::default();
    without_xlora.xlora = None;
    let backend = MistralBackend::with_config(without_xlora).unwrap();

    // Loading an adapter is rejected with a descriptive error...
    let load_attempt = backend.load_xlora_adapter("test", Path::new("/nonexistent"));
    assert!(load_attempt.is_err());
    let load_error = load_attempt.unwrap_err().to_string();
    assert!(load_error.contains("X-LoRA not configured"));

    // ...and so is selecting active adapters.
    let select_attempt = backend.set_xlora_adapters(vec![("test", 1.0)]);
    assert!(select_attempt.is_err());
}
|
|
|
|
#[test]
fn test_isq_requires_loaded_model() {
    // ISQ is configured, but applying it without a model must fail.
    let mut backend =
        MistralBackend::with_config(MistralBackendConfig::default().with_isq(4)).unwrap();

    let outcome = backend.apply_isq();
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("No model loaded"));
}
|
|
|
|
// ============================================================================
|
|
// Model Loading Tests (Requires Model Files - Ignored by Default)
|
|
// ============================================================================
|
|
|
|
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_model_loading() {
    let mut backend = MistralBackend::for_metal().unwrap();

    // Note: Replace with actual model path for testing
    let result = backend.load_model(
        "models/test-model.gguf",
        ModelConfig {
            architecture: ModelArchitecture::Mistral,
            device: DeviceType::Metal,
            ..Default::default()
        },
    );

    // A missing or unloadable model skips the test instead of failing it.
    // The reason is reported so an explicit `--include-ignored` run does not
    // pass silently without exercising anything (visible with `--nocapture`).
    if let Err(e) = result {
        eprintln!("skipping test_model_loading: model could not be loaded: {}", e);
        return;
    }

    assert!(backend.is_model_loaded());
    assert!(backend.model_info().is_some());
    assert!(backend.tokenizer().is_some());
}
|
|
|
|
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_generation() {
    let mut backend = MistralBackend::new().unwrap();

    // Skip if the model is not available, but say why rather than returning silently.
    if let Err(e) = backend.load_model("models/test-model.gguf", ModelConfig::default()) {
        eprintln!("skipping test_generation: model could not be loaded: {}", e);
        return;
    }

    let params = GenerateParams {
        max_tokens: 10,
        temperature: 0.7,
        ..Default::default()
    };

    match backend.generate("Hello", params) {
        Ok(text) => assert!(!text.is_empty()),
        Err(e) => panic!("Generation failed: {}", e),
    }
}
|
|
|
|
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_streaming_generation() {
    let mut backend = MistralBackend::new().unwrap();

    // Skip if the model is not available, but say why rather than returning silently.
    if let Err(e) = backend.load_model("models/test-model.gguf", ModelConfig::default()) {
        eprintln!(
            "skipping test_streaming_generation: model could not be loaded: {}",
            e
        );
        return;
    }

    match backend.generate_stream("Hello", GenerateParams::default()) {
        Ok(stream) => {
            let tokens: Vec<_> = stream.collect();
            assert!(!tokens.is_empty());
        }
        Err(e) => panic!("Streaming generation failed: {}", e),
    }
}
|
|
|
|
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_embeddings_extraction() {
    let mut backend = MistralBackend::new().unwrap();

    // Skip if the model is not available, but say why rather than returning silently.
    if let Err(e) = backend.load_model("models/test-model.gguf", ModelConfig::default()) {
        eprintln!(
            "skipping test_embeddings_extraction: model could not be loaded: {}",
            e
        );
        return;
    }

    match backend.get_embeddings("Test text for embedding") {
        Ok(emb) => {
            assert!(!emb.is_empty());
            // Embeddings must not contain NaN or infinities.
            assert!(emb.iter().all(|&v| v.is_finite()));
        }
        Err(e) => panic!("Embedding extraction failed: {}", e),
    }
}
|
|
|
|
#[test]
#[ignore = "Requires model file - run with --include-ignored"]
fn test_model_unload_and_reload() {
    const MODEL_PATH: &str = "models/test-model.gguf";
    let mut backend = MistralBackend::new().unwrap();

    // Load model. Skip if it is not available, but say why.
    if let Err(e) = backend.load_model(MODEL_PATH, ModelConfig::default()) {
        eprintln!(
            "skipping test_model_unload_and_reload: model could not be loaded: {}",
            e
        );
        return;
    }
    assert!(backend.is_model_loaded());

    // Unload
    backend.unload_model();
    assert!(!backend.is_model_loaded());
    assert!(backend.model_info().is_none());

    // Reload. The first load proved the model file is usable, so a failure here
    // most likely means `unload_model` left bad state behind. That is a real
    // regression and must fail the test rather than be ignored.
    if let Err(e) = backend.load_model(MODEL_PATH, ModelConfig::default()) {
        panic!("Reload after unload failed: {}", e);
    }
    assert!(backend.is_model_loaded());
}
|
|
|
|
// ============================================================================
|
|
// Integration Tests with PagedAttention
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_backend_paged_attention_integration() {
    // An explicitly configured PagedAttention pool starts with free blocks and
    // no sequences attached.
    let paged_config = MistralBackendConfig::default().with_paged_attention(16, 4096);
    let backend = MistralBackend::with_config(paged_config).unwrap();

    let pool = backend.paged_attention_stats().unwrap();
    assert!(pool.total_blocks > 0);
    assert!(pool.free_blocks > 0);
    assert_eq!(pool.active_sequences, 0);
}
|
|
|
|
#[test]
fn test_backend_xlora_integration() {
    // X-LoRA is configured, but no adapter has been loaded or used yet.
    let adapters = vec!["code", "math", "chat"];
    let backend =
        MistralBackend::with_config(MistralBackendConfig::default().with_xlora_adapters(adapters))
            .unwrap();

    let report = backend.xlora_stats().unwrap();
    assert_eq!(report.loaded_adapters, 0);
    assert!(report.adapter_usage.is_empty());
}
|
|
|
|
// ============================================================================
|
|
// Serialization Tests
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_config_serialization() {
    let config = MistralBackendConfig::default()
        .with_paged_attention(32, 8192)
        .with_xlora_adapters(vec!["test"])
        .with_isq(4);

    // Test serialization roundtrip
    let json = serde_json::to_string(&config).unwrap();
    let deserialized: MistralBackendConfig = serde_json::from_str(&json).unwrap();

    assert_eq!(deserialized.max_seq_len, config.max_seq_len);
    assert_eq!(deserialized.max_batch_size, config.max_batch_size);

    // Check that each optional section survives with its values intact, not
    // merely that it is still present.
    let pa = deserialized
        .paged_attention
        .as_ref()
        .expect("paged_attention lost in JSON round-trip");
    assert_eq!(pa.block_size, 32);
    assert_eq!(pa.max_pages, 8192);

    let xlora = deserialized
        .xlora
        .as_ref()
        .expect("xlora lost in JSON round-trip");
    assert_eq!(xlora.adapter_names, vec!["test".to_string()]);

    let isq = deserialized
        .isq
        .as_ref()
        .expect("isq lost in JSON round-trip");
    assert_eq!(isq.bits, 4);
}
|
|
|
|
#[test]
fn test_paged_attention_config_serialization() {
    let config = PagedAttentionConfigExt::default();

    let json = serde_json::to_string(&config).unwrap();
    let deserialized: PagedAttentionConfigExt = serde_json::from_str(&json).unwrap();

    assert_eq!(deserialized.block_size, config.block_size);
    assert_eq!(deserialized.max_pages, config.max_pages);
    assert_eq!(
        deserialized.enable_prefix_caching,
        config.enable_prefix_caching
    );

    // The f32 fields were previously unchecked. Compare them with a tolerance,
    // matching the other float assertions in this file.
    assert!(
        (deserialized.gpu_memory_fraction - config.gpu_memory_fraction).abs() < f32::EPSILON
    );
    assert!(
        (deserialized.recomputation_threshold - config.recomputation_threshold).abs()
            < f32::EPSILON
    );
}
|
|
|
|
#[test]
fn test_xlora_config_serialization() {
    let config = XLoraConfig {
        adapter_names: vec!["a".to_string(), "b".to_string()],
        top_k: 3,
        temperature: 0.5,
        ..Default::default()
    };

    let json = serde_json::to_string(&config).unwrap();
    let deserialized: XLoraConfig = serde_json::from_str(&json).unwrap();

    // Compare the names themselves, in order, rather than only the count, so a
    // round-trip that reorders or corrupts entries is caught.
    assert_eq!(deserialized.adapter_names, config.adapter_names);
    assert_eq!(deserialized.top_k, 3);
    assert!((deserialized.temperature - 0.5).abs() < f32::EPSILON);
}
|
|
|
|
#[test]
fn test_isq_config_serialization() {
    // Every field of a non-default ISQ config must survive a JSON round-trip.
    let source = IsqConfig {
        method: IsqMethod::GPTQ,
        bits: 8,
        per_channel: false,
        symmetric: true,
        calibration_samples: 256,
    };

    let encoded = serde_json::to_string(&source).unwrap();
    let restored: IsqConfig = serde_json::from_str(&encoded).unwrap();

    assert_eq!(restored.bits, 8);
    assert!(matches!(restored.method, IsqMethod::GPTQ));
    assert!(restored.symmetric);
    assert!(!restored.per_channel);
    assert_eq!(restored.calibration_samples, 256);
}
|
|
|
|
// ============================================================================
|
|
// Edge Case Tests
|
|
// ============================================================================
|
|
|
|
#[test]
fn test_empty_xlora_adapters() {
    // An empty adapter list still enables the X-LoRA section, just with no names.
    let section = MistralBackendConfig::default()
        .with_xlora_adapters(vec![])
        .xlora
        .unwrap();
    assert!(section.adapter_names.is_empty());
}
|
|
|
|
#[test]
fn test_large_page_config() {
    // Large block sizes and page counts are stored without clamping.
    let paged = MistralBackendConfig::default()
        .with_paged_attention(256, 65536)
        .paged_attention
        .unwrap();

    assert_eq!(paged.block_size, 256);
    assert_eq!(paged.max_pages, 65536);
}
|
|
|
|
#[test]
fn test_multiple_backend_instances() {
    // Several independently constructed backends can coexist, each unloaded.
    let default_backend = MistralBackend::new().unwrap();
    let metal_backend = MistralBackend::for_metal().unwrap();
    let cuda_backend = MistralBackend::for_cuda(0).unwrap();

    let loaded_flags = [
        default_backend.is_model_loaded(),
        metal_backend.is_model_loaded(),
        cuda_backend.is_model_loaded(),
    ];
    assert!(loaded_flags.iter().all(|&loaded| !loaded));
}
|
|
|
|
#[test]
fn test_generate_params_integration() {
    // Every sampling knob can be set explicitly and reads back unchanged.
    let params = GenerateParams {
        seed: Some(12345),
        stop_sequences: vec!["STOP".to_string(), "\n\n".to_string()],
        max_tokens: 256,
        temperature: 0.8,
        top_p: 0.95,
        top_k: 50,
        repetition_penalty: 1.1,
        frequency_penalty: 0.1,
        presence_penalty: 0.1,
    };

    assert_eq!(params.max_tokens, 256);
    assert_eq!(params.seed, Some(12345));
    assert_eq!(params.stop_sequences.len(), 2);
    assert!((params.temperature - 0.8).abs() < f32::EPSILON);
}
|