Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
crates/ruvector-core/tests/README.md (new file, 251 lines)
@@ -0,0 +1,251 @@
# Ruvector Core Test Suite

## Overview

This directory contains a comprehensive Test-Driven Development (TDD) test suite following the London School approach. The test suite covers unit tests, integration tests, property-based tests, stress tests, and concurrent access tests.

## Test Files

### 1. `unit_tests.rs` - Unit Tests with Mocking (London School)

Comprehensive unit tests using `mockall` for mocking dependencies (a minimal mocking sketch follows this section):

- **Distance Metric Tests**: Tests for all 4 distance metrics (Euclidean, Cosine, Dot Product, Manhattan)
  - Self-distance verification
  - Symmetry properties
  - Orthogonal and parallel vector cases
  - Dimension mismatch error handling
- **Quantization Tests**: Tests for scalar and binary quantization
  - Round-trip reconstruction accuracy
  - Distance calculation correctness
  - Sign preservation (binary quantization)
  - Hamming distance validation
- **Storage Layer Tests**: Tests for VectorStorage
  - Insert with explicit and auto-generated IDs
  - Metadata handling
  - Dimension validation
  - Batch operations
  - Delete operations
  - Error cases (non-existent vectors, dimension mismatches)
- **VectorDB Tests**: High-level API tests
  - Empty database operations
  - Insert/delete with len() tracking
  - Search functionality
  - Metadata filtering
  - Batch insert operations
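For readers new to the London School style used here, the following is a minimal, self-contained sketch of how `mockall` drives an interaction-based test. The `DistanceFn` trait and its expectation are illustrative placeholders, not the crate's actual internals:

```rust
use mockall::automock;

// Hypothetical collaborator trait; the real suite mocks its own internal traits.
#[automock]
trait DistanceFn {
    fn distance(&self, a: &[f32], b: &[f32]) -> f32;
}

#[test]
fn consumer_calls_injected_distance_exactly_once() {
    let mut mock = MockDistanceFn::new();
    // London School: assert on the interaction (exactly one call), not on state.
    mock.expect_distance().times(1).return_const(0.0f32);
    assert_eq!(mock.distance(&[1.0, 0.0], &[1.0, 0.0]), 0.0);
}
```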
### 2. `integration_tests.rs` - End-to-End Integration Tests

Full workflow tests that verify all components work together:

- **Complete Workflows**: Insert + search + retrieve with metadata
- **Large Batch Operations**: 10K+ vector batch insertions
- **Persistence**: Database save and reload verification
- **Mixed Operations**: Combined insert, delete, and search operations
- **Distance Metrics**: Tests for all 4 metrics end-to-end
- **HNSW Configurations**: Different HNSW parameter combinations
- **Metadata Filtering**: Complex filtering scenarios
- **Error Handling**: Dimension validation, wrong query dimensions

### 3. `property_tests.rs` - Property-Based Tests (proptest)

Mathematical property verification using proptest (a minimal sketch follows this section):

- **Distance Metric Properties**:
  - Self-distance is zero
  - Symmetry: d(a,b) = d(b,a)
  - Triangle inequality: d(a,c) ≤ d(a,b) + d(b,c)
  - Non-negativity
  - Scale invariance (cosine)
  - Translation invariance (Euclidean)
- **Quantization Properties**:
  - Round-trip reconstruction bounds
  - Sign preservation (binary)
  - Self-distance is zero
  - Symmetry
  - Distance bounds
- **Batch Operations**:
  - Consistency between batch and individual operations
- **Dimension Handling**:
  - Mismatch error detection
  - Success on matching dimensions
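As a flavor of what these tests look like, here is a minimal, self-contained proptest sketch that checks the symmetry property for a locally defined Euclidean distance (the real tests exercise the crate's own distance functions):

```rust
use proptest::prelude::*;

// Stand-in distance; the actual suite tests ruvector_core's implementation.
fn euclidean(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
}

proptest! {
    #[test]
    fn euclidean_is_symmetric(
        a in proptest::collection::vec(-100.0f32..100.0, 8),
        b in proptest::collection::vec(-100.0f32..100.0, 8),
    ) {
        prop_assert!((euclidean(&a, &b) - euclidean(&b, &a)).abs() < 1e-3);
    }
}
```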
### 4. `stress_tests.rs` - Scalability and Performance Stress Tests

Tests that push the system to its limits:

- **Million Vector Insertion** (ignored by default): Insert 1M vectors in batches
- **Concurrent Queries**: 10 threads × 100 queries each
- **Concurrent Mixed Operations**: Simultaneous readers and writers
- **Memory Pressure**: Large 2048-dimensional vectors
- **Error Recovery**: Invalid operations don't crash the system
- **Repeated Operations**: Same operation executed many times
- **Extreme Parameters**: k values larger than database size

### 5. `concurrent_tests.rs` - Thread-Safety Tests

Multi-threaded access patterns:

- **Concurrent Reads**: Multiple threads reading simultaneously
- **Concurrent Writes**: Non-overlapping writes from multiple threads
- **Mixed Read/Write**: Concurrent reads and writes
- **Delete and Insert**: Simultaneous deletes and inserts
- **Search and Insert**: Searching while inserting
- **Batch Atomicity**: Verifying batch operations are atomic
- **Read-Write Consistency**: Ensuring no data corruption
- **Metadata Updates**: Concurrent metadata modifications

## Benchmarks

### 6. `benches/quantization_bench.rs` - Quantization Performance

Criterion benchmarks for quantization operations:

- Scalar quantization encode/decode/distance
- Binary quantization encode/decode/distance
- Compression ratio comparisons

### 7. `benches/batch_operations.rs` - Batch Operation Performance

Criterion benchmarks for batch operations:

- Batch insert at various scales (100, 1K, 10K)
- Individual vs batch insert comparison
- Parallel search performance
- Batch delete operations

## Running Tests

### Run All Tests

```bash
cargo test --package ruvector-core
```

### Run Specific Test Suites

```bash
# Unit tests only
cargo test --test unit_tests

# Integration tests only
cargo test --test integration_tests

# Property tests only
cargo test --test property_tests

# Concurrent tests only
cargo test --test concurrent_tests

# Stress tests (including ignored tests)
cargo test --test stress_tests -- --ignored --test-threads=1
```

### Run Benchmarks

```bash
# Distance metrics (existing)
cargo bench --bench distance_metrics

# HNSW search (existing)
cargo bench --bench hnsw_search

# Quantization (new)
cargo bench --bench quantization_bench

# Batch operations (new)
cargo bench --bench batch_operations
```

### Generate Coverage Report

```bash
# Install tarpaulin if not already installed
cargo install cargo-tarpaulin

# Generate HTML coverage report
cargo tarpaulin --out Html --output-dir target/coverage

# Open coverage report
open target/coverage/index.html
```

## Test Coverage Goals

- **Target**: 90%+ code coverage
- **Focus Areas**:
  - Distance calculations
  - Index operations
  - Storage layer
  - Error handling paths
  - Edge cases

## Known Issues

As of the current implementation, there are pre-existing compilation errors in the codebase that prevent some tests from running:

1. **HNSW Index**: `DataId::new` construction issues in `src/index/hnsw.rs`
2. **AgenticDB**: Missing `Encode`/`Decode` trait implementations for `ReflexionEpisode`

These issues exist in the main codebase and need to be fixed before the full test suite can execute.

## Test Organization

Tests are organized by purpose and scope:

1. **Unit Tests**: Test individual components in isolation with mocking
2. **Integration Tests**: Test component interactions and workflows
3. **Property Tests**: Test mathematical properties and invariants
4. **Stress Tests**: Test performance limits and edge cases
5. **Concurrent Tests**: Test thread-safety and concurrent access patterns

## Dependencies

Test dependencies (in `Cargo.toml`):

```toml
[dev-dependencies]
criterion = { workspace = true }
proptest = { workspace = true }
mockall = { workspace = true }
tempfile = "3.13"
tracing-subscriber = { workspace = true }
```

## Contributing

When adding new tests:

1. Follow the existing structure and naming conventions
2. Add tests to the appropriate file (unit, integration, property, etc.)
3. Document the test purpose clearly
4. Ensure tests are deterministic and don't depend on timing
5. Use `tempdir()` for database paths in tests
6. Clean up resources properly

## CI/CD Integration

Recommended CI pipeline:

```yaml
test:
  script:
    - cargo test --all-features
    - cargo tarpaulin --out Xml
  coverage: '/\d+\.\d+% coverage/'

bench:
  script:
    - cargo bench --no-run
```

## Performance Expectations

Based on stress tests:

- **Insert**: ~10K vectors/second (batch mode)
- **Search**: ~1K queries/second (k=10, HNSW)
- **Concurrent**: 10+ threads without performance degradation
- **Memory**: ~4KB per 384-dim vector (uncompressed)
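For context on the memory figure: the raw payload of a 384-dimension `f32` vector is 384 × 4 = 1,536 bytes, so the ~4KB estimate presumably also counts HNSW graph links, IDs, and metadata overhead.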
## Additional Resources

- [Mockall Documentation](https://docs.rs/mockall/)
- [Proptest Guide](https://altsysrq.github.io/proptest-book/)
- [Criterion.rs Guide](https://bheisler.github.io/criterion.rs/book/)
- [Cargo Tarpaulin](https://github.com/xd009642/tarpaulin)
crates/ruvector-core/tests/advanced_features_integration.rs (new file, 550 lines)
@@ -0,0 +1,550 @@
//! Integration tests for advanced features
//!
//! Tests Enhanced PQ, Filtered Search, MMR, Hybrid Search, and Conformal Prediction
//! across multiple vector dimensions (128D, 384D, 768D)

use ruvector_core::advanced_features::*;
use ruvector_core::types::{DistanceMetric, SearchResult};
use ruvector_core::{Result, RuvectorError};
use std::collections::HashMap;

// Helper function to generate random vectors
fn generate_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
    use rand::Rng;
    let mut rng = rand::thread_rng();

    (0..count)
        .map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
        .collect()
}

// Helper function to normalize vectors
fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter().map(|x| x / norm).collect()
    } else {
        v.to_vec()
    }
}
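// Product quantization in brief: the D-dimensional vector is split into
// `num_subspaces` contiguous chunks, each chunk is k-means-clustered into a
// `codebook_size`-entry codebook, and a vector is then stored as one code index
// per subspace. With 8 subspaces and 256 centroids, a 128D f32 vector
// (512 bytes) is encoded in 8 one-byte codes plus the shared codebooks.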
#[test]
fn test_enhanced_pq_128d() {
    let dimensions = 128;
    let num_vectors = 1000;

    let config = PQConfig {
        num_subspaces: 8,
        codebook_size: 256,
        num_iterations: 10,
        metric: DistanceMetric::Euclidean,
    };

    let mut pq = EnhancedPQ::new(dimensions, config).unwrap();

    // Generate training data
    let training_data = generate_vectors(num_vectors, dimensions);
    pq.train(&training_data).unwrap();

    // Test encoding and search
    let query = normalize_vector(&generate_vectors(1, dimensions)[0]);

    // Add quantized vectors
    for (i, vector) in training_data.iter().enumerate() {
        pq.add_quantized(format!("vec_{}", i), vector).unwrap();
    }

    // Perform search
    let results = pq.search(&query, 10).unwrap();
    assert_eq!(results.len(), 10);

    // Test compression ratio: 128D f32 (512 bytes) encodes to 8 one-byte
    // codes, so the ratio should comfortably clear this lower bound.
    let compression_ratio = pq.compression_ratio();
    assert!(compression_ratio >= 8.0);

    println!(
        "✓ Enhanced PQ 128D: compression ratio = {:.1}x",
        compression_ratio
    );
}
#[test]
fn test_enhanced_pq_384d() {
    let dimensions = 384;
    let num_vectors = 500;

    let config = PQConfig {
        num_subspaces: 8,
        codebook_size: 256,
        num_iterations: 10,
        metric: DistanceMetric::Cosine,
    };

    let mut pq = EnhancedPQ::new(dimensions, config).unwrap();

    // Generate training data
    let training_data: Vec<Vec<f32>> = generate_vectors(num_vectors, dimensions)
        .into_iter()
        .map(|v| normalize_vector(&v))
        .collect();

    pq.train(&training_data).unwrap();

    // Test reconstruction
    let test_vector = &training_data[0];
    let codes = pq.encode(test_vector).unwrap();
    let reconstructed = pq.reconstruct(&codes).unwrap();

    assert_eq!(reconstructed.len(), dimensions);

    // Calculate reconstruction error
    let error: f32 = test_vector
        .iter()
        .zip(&reconstructed)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        .sqrt();

    println!("✓ Enhanced PQ 384D: reconstruction error = {:.4}", error);
    assert!(error < 5.0); // Reasonable reconstruction error
}

#[test]
fn test_enhanced_pq_768d() {
    let dimensions = 768;
    let num_vectors = 300; // Increased to ensure we have enough vectors for search

    let config = PQConfig {
        num_subspaces: 16,
        codebook_size: 256,
        num_iterations: 10,
        metric: DistanceMetric::Euclidean,
    };

    let mut pq = EnhancedPQ::new(dimensions, config).unwrap();

    let training_data = generate_vectors(num_vectors, dimensions);
    pq.train(&training_data).unwrap();

    // Test lookup table creation
    let query = generate_vectors(1, dimensions)[0].clone();
    let lookup_table = pq.create_lookup_table(&query).unwrap();

    assert_eq!(lookup_table.tables.len(), 16);
    assert_eq!(lookup_table.tables[0].len(), 256);

    println!("✓ Enhanced PQ 768D: lookup table created successfully");
}

#[test]
fn test_filtered_search_pre_filter() {
    use serde_json::json;

    // Create metadata store
    let mut metadata_store = HashMap::new();
    for i in 0..100 {
        let mut metadata = HashMap::new();
        metadata.insert(
            "category".to_string(),
            json!(if i % 3 == 0 { "A" } else { "B" }),
        );
        metadata.insert("price".to_string(), json!(i as f32 * 10.0));
        metadata_store.insert(format!("vec_{}", i), metadata);
    }

    // Create filter: category == "A" AND price < 500
    let filter = FilterExpression::And(vec![
        FilterExpression::Eq("category".to_string(), json!("A")),
        FilterExpression::Lt("price".to_string(), json!(500.0)),
    ]);

    let search = FilteredSearch::new(filter, FilterStrategy::PreFilter, metadata_store);

    // Test pre-filtering
    let filtered_ids = search.get_filtered_ids();
    assert!(!filtered_ids.is_empty());
    assert!(filtered_ids.len() < 50); // Should be selective

    println!(
        "✓ Filtered Search (Pre-filter): {} matching documents",
        filtered_ids.len()
    );
}

#[test]
fn test_filtered_search_auto_strategy() {
    use serde_json::json;

    let mut metadata_store = HashMap::new();
    for i in 0..1000 {
        let mut metadata = HashMap::new();
        metadata.insert("id".to_string(), json!(i));
        metadata_store.insert(format!("vec_{}", i), metadata);
    }

    // Highly selective filter (should choose pre-filter)
    let selective_filter = FilterExpression::Eq("id".to_string(), json!(42));
    let search1 = FilteredSearch::new(
        selective_filter,
        FilterStrategy::Auto,
        metadata_store.clone(),
    );
    assert_eq!(search1.auto_select_strategy(), FilterStrategy::PreFilter);

    // Less selective filter (should choose post-filter)
    let broad_filter = FilterExpression::Gte("id".to_string(), json!(0));
    let search2 = FilteredSearch::new(broad_filter, FilterStrategy::Auto, metadata_store);
    assert_eq!(search2.auto_select_strategy(), FilterStrategy::PostFilter);

    println!("✓ Filtered Search: automatic strategy selection working");
}
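// MMR (Maximal Marginal Relevance) rescoring in brief: each candidate is scored
// as lambda * relevance(query, doc) - (1 - lambda) * max_similarity(doc,
// already selected), so lambda = 1.0 is pure relevance and 0.0 pure diversity.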
#[test]
fn test_mmr_diversity_128d() {
    let dimensions = 128;

    let config = MMRConfig {
        lambda: 0.5, // Balance relevance and diversity
        metric: DistanceMetric::Cosine,
        fetch_multiplier: 2.0,
    };

    let mmr = MMRSearch::new(config).unwrap();

    // Create query and candidates
    let query = normalize_vector(&generate_vectors(1, dimensions)[0]);
    let candidates: Vec<SearchResult> = (0..20)
        .map(|i| SearchResult {
            id: format!("doc_{}", i),
            score: i as f32 * 0.05,
            vector: Some(normalize_vector(&generate_vectors(1, dimensions)[0])),
            metadata: None,
        })
        .collect();

    // Rerank using MMR
    let results = mmr.rerank(&query, candidates, 10).unwrap();

    assert_eq!(results.len(), 10);
    // Results should be diverse (not just top-10 by relevance)

    println!("✓ MMR 128D: diversified {} results", results.len());
}

#[test]
fn test_mmr_lambda_variations() {
    let dimensions = 64;

    // Test with pure relevance (lambda = 1.0)
    let config_relevance = MMRConfig {
        lambda: 1.0,
        metric: DistanceMetric::Euclidean,
        fetch_multiplier: 2.0,
    };
    let mmr_relevance = MMRSearch::new(config_relevance).unwrap();

    // Test with pure diversity (lambda = 0.0)
    let config_diversity = MMRConfig {
        lambda: 0.0,
        metric: DistanceMetric::Euclidean,
        fetch_multiplier: 2.0,
    };
    let mmr_diversity = MMRSearch::new(config_diversity).unwrap();

    let query = generate_vectors(1, dimensions)[0].clone();
    let candidates: Vec<SearchResult> = (0..10)
        .map(|i| SearchResult {
            id: format!("doc_{}", i),
            score: i as f32 * 0.1,
            vector: Some(generate_vectors(1, dimensions)[0].clone()),
            metadata: None,
        })
        .collect();

    let results_relevance = mmr_relevance.rerank(&query, candidates.clone(), 5).unwrap();
    let results_diversity = mmr_diversity.rerank(&query, candidates, 5).unwrap();

    assert_eq!(results_relevance.len(), 5);
    assert_eq!(results_diversity.len(), 5);

    println!("✓ MMR: lambda variations tested successfully");
}

#[test]
fn test_hybrid_search_basic() {
    let config = HybridConfig {
        vector_weight: 0.7,
        keyword_weight: 0.3,
        normalization: NormalizationStrategy::MinMax,
    };

    let mut hybrid = HybridSearch::new(config);

    // Index documents
    hybrid.index_document("doc1".to_string(), "rust programming language".to_string());
    hybrid.index_document("doc2".to_string(), "python machine learning".to_string());
    hybrid.index_document("doc3".to_string(), "rust systems programming".to_string());
    hybrid.finalize_indexing();

    // Test BM25 scoring
    let score = hybrid.bm25.score(
        "rust programming",
        &"doc1".to_string(),
        "rust programming language",
    );
    assert!(score > 0.0);

    println!(
        "✓ Hybrid Search: indexed {} documents",
        hybrid.doc_texts.len()
    );
}
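// BM25's two knobs are k1 (term-frequency saturation) and b (document-length
// normalization); the 1.5 and 0.75 passed below are the conventional defaults
// for those parameters.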
#[test]
fn test_hybrid_search_keyword_matching() {
    let mut bm25 = BM25::new(1.5, 0.75);

    bm25.index_document("doc1".to_string(), "vector database with HNSW indexing");
    bm25.index_document("doc2".to_string(), "relational database management system");
    bm25.index_document("doc3".to_string(), "vector search and similarity matching");
    bm25.build_idf();

    // Test candidate retrieval
    let candidates = bm25.get_candidate_docs("vector database");
    assert!(candidates.contains(&"doc1".to_string()));
    assert!(candidates.contains(&"doc3".to_string()));

    // Test scoring
    let score1 = bm25.score(
        "vector database",
        &"doc1".to_string(),
        "vector database with HNSW indexing",
    );
    let score2 = bm25.score(
        "vector database",
        &"doc2".to_string(),
        "relational database management system",
    );

    assert!(score1 > score2); // doc1 matches better

    println!(
        "✓ Hybrid Search (BM25): {} candidate documents",
        candidates.len()
    );
}
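// Split conformal prediction in brief: nonconformity scores are collected on a
// held-out calibration set, and the (1 - alpha) empirical quantile of those
// scores becomes the inclusion threshold, yielding prediction sets with
// roughly (1 - alpha) coverage (so alpha = 0.1 targets ~90%).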
#[test]
fn test_conformal_prediction_128d() {
    let dimensions = 128;

    let config = ConformalConfig {
        alpha: 0.1, // 90% coverage
        calibration_fraction: 0.2,
        nonconformity_measure: NonconformityMeasure::Distance,
    };

    let mut predictor = ConformalPredictor::new(config).unwrap();

    // Create calibration data
    let calibration_queries = generate_vectors(10, dimensions);
    let true_neighbors: Vec<Vec<String>> = (0..10)
        .map(|i| vec![format!("vec_{}", i), format!("vec_{}", i + 1)])
        .collect();

    // Mock search function
    let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
        Ok((0..k)
            .map(|i| SearchResult {
                id: format!("vec_{}", i),
                score: i as f32 * 0.1,
                vector: Some(vec![0.0; dimensions]),
                metadata: None,
            })
            .collect())
    };

    // Calibrate
    predictor
        .calibrate(&calibration_queries, &true_neighbors, search_fn)
        .unwrap();

    assert!(predictor.threshold.is_some());
    assert!(!predictor.calibration_scores.is_empty());

    // Make prediction
    let query = generate_vectors(1, dimensions)[0].clone();
    let prediction_set = predictor.predict(&query, search_fn).unwrap();

    assert_eq!(prediction_set.confidence, 0.9);
    assert!(!prediction_set.results.is_empty());

    println!(
        "✓ Conformal Prediction 128D: prediction set size = {}",
        prediction_set.results.len()
    );
}

#[test]
fn test_conformal_prediction_384d() {
    let dimensions = 384;

    let config = ConformalConfig {
        alpha: 0.05, // 95% coverage
        calibration_fraction: 0.2,
        nonconformity_measure: NonconformityMeasure::NormalizedDistance,
    };

    let mut predictor = ConformalPredictor::new(config).unwrap();

    let calibration_queries = generate_vectors(5, dimensions);
    let true_neighbors: Vec<Vec<String>> = (0..5).map(|i| vec![format!("vec_{}", i)]).collect();

    let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
        Ok((0..k)
            .map(|i| SearchResult {
                id: format!("vec_{}", i),
                score: 0.1 + (i as f32 * 0.05),
                vector: Some(vec![0.0; dimensions]),
                metadata: None,
            })
            .collect())
    };

    predictor
        .calibrate(&calibration_queries, &true_neighbors, search_fn)
        .unwrap();

    // Test calibration statistics
    let stats = predictor.get_statistics().unwrap();
    assert_eq!(stats.num_samples, 5);
    assert!(stats.mean > 0.0);

    println!("✓ Conformal Prediction 384D: calibration stats computed");
}

#[test]
fn test_conformal_prediction_adaptive_k() {
    let dimensions = 256;

    let config = ConformalConfig {
        alpha: 0.1,
        calibration_fraction: 0.2,
        nonconformity_measure: NonconformityMeasure::InverseRank,
    };

    let mut predictor = ConformalPredictor::new(config).unwrap();

    let calibration_queries = generate_vectors(8, dimensions);
    let true_neighbors: Vec<Vec<String>> = (0..8).map(|i| vec![format!("vec_{}", i)]).collect();

    let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
        Ok((0..k)
            .map(|i| SearchResult {
                id: format!("vec_{}", i),
                score: i as f32 * 0.08,
                vector: Some(vec![0.0; dimensions]),
                metadata: None,
            })
            .collect())
    };

    predictor
        .calibrate(&calibration_queries, &true_neighbors, search_fn)
        .unwrap();

    // Test adaptive top-k
    let query = generate_vectors(1, dimensions)[0].clone();
    let adaptive_k = predictor.adaptive_top_k(&query, search_fn).unwrap();

    assert!(adaptive_k > 0);
    println!("✓ Conformal Prediction: adaptive k = {}", adaptive_k);
}

#[test]
fn test_all_features_integration() {
    // Test that all features can work together
    let dimensions = 128;

    // 1. Enhanced PQ
    let pq_config = PQConfig {
        num_subspaces: 4,
        codebook_size: 16,
        num_iterations: 5,
        metric: DistanceMetric::Euclidean,
    };
    let mut pq = EnhancedPQ::new(dimensions, pq_config).unwrap();
    let training_data = generate_vectors(50, dimensions);
    pq.train(&training_data).unwrap();

    // 2. MMR
    let mmr_config = MMRConfig::default();
    let mmr = MMRSearch::new(mmr_config).unwrap();

    // 3. Hybrid Search
    let hybrid_config = HybridConfig::default();
    let mut hybrid = HybridSearch::new(hybrid_config);
    hybrid.index_document("doc1".to_string(), "test document".to_string());
    hybrid.finalize_indexing();

    // 4. Conformal Prediction
    let cp_config = ConformalConfig::default();
    let predictor = ConformalPredictor::new(cp_config).unwrap();

    println!("✓ All features integrated successfully");
}

#[test]
fn test_pq_recall_384d() {
    let dimensions = 384;
    let num_vectors = 500;
    let k = 10;

    let config = PQConfig {
        num_subspaces: 8,
        codebook_size: 256,
        num_iterations: 15,
        metric: DistanceMetric::Euclidean,
    };

    let mut pq = EnhancedPQ::new(dimensions, config).unwrap();

    // Generate and train
    let vectors = generate_vectors(num_vectors, dimensions);
    pq.train(&vectors).unwrap();

    // Add vectors
    for (i, vector) in vectors.iter().enumerate() {
        pq.add_quantized(format!("vec_{}", i), vector).unwrap();
    }

    // Test search
    let query = &vectors[0]; // Use first vector as query
    let results = pq.search(query, k).unwrap();

    // Verify we got results
    assert!(!results.is_empty(), "Search should return results");
    assert_eq!(results.len(), k, "Should return k results");

    // The first result should be among the top candidates (PQ is approximate).
    // Due to quantization, the exact match might not be at position 0,
    // but the distance should be reasonably small relative to random vectors.
    let min_distance = results
        .iter()
        .map(|(_, d)| *d)
        .fold(f32::INFINITY, f32::min);

    // In high dimensions, PQ distances vary based on quantization quality.
    // Check that we get reasonable results (the top result should be closer than random).
    assert!(
        min_distance < 50.0,
        "Minimum distance {} should be reasonable for quantized search",
        min_distance
    );

    println!(
        "✓ PQ 384D Recall Test: top-{} results retrieved, min distance = {:.4}",
        results.len(),
        min_distance
    );
}
crates/ruvector-core/tests/concurrent_tests.rs (new file, 440 lines)
@@ -0,0 +1,440 @@
//! Concurrent access tests with multiple threads
//!
//! These tests verify thread-safety and correct behavior under concurrent access.

use ruvector_core::types::{DbOptions, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use std::thread;
use tempfile::tempdir;
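// All tests below share one VectorDB behind an Arc and rely on the database's
// internal synchronization; no external locking is added around db calls.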
#[test]
fn test_concurrent_reads() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_reads.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial data
    for i in 0..100 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .unwrap();
    }

    // Spawn multiple reader threads
    let num_threads = 10;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..50 {
                let id = format!("vec_{}", (thread_id * 10 + i) % 100);
                let result = db_clone.get(&id).unwrap();
                assert!(result.is_some(), "Failed to get {}", id);
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }
}

#[test]
fn test_concurrent_writes_no_collision() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_writes.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Spawn multiple writer threads with non-overlapping IDs
    let num_threads = 10;
    let vectors_per_thread = 20;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..vectors_per_thread {
                let id = format!("thread_{}_{}", thread_id, i);
                db_clone
                    .insert(VectorEntry {
                        id: Some(id.clone()),
                        vector: vec![thread_id as f32; 32],
                        metadata: None,
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify all vectors were inserted
    assert_eq!(db.len().unwrap(), num_threads * vectors_per_thread);
}

#[test]
fn test_concurrent_delete_and_insert() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_delete_insert.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 16;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial data
    for i in 0..100 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: vec![i as f32; 16],
            metadata: None,
        })
        .unwrap();
    }

    let num_threads = 5;
    let mut handles = vec![];

    // Deleter threads
    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..10 {
                let id = format!("vec_{}", thread_id * 10 + i);
                db_clone.delete(&id).unwrap();
            }
        });

        handles.push(handle);
    }

    // Inserter threads
    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..10 {
                let id = format!("new_{}_{}", thread_id, i);
                db_clone
                    .insert(VectorEntry {
                        id: Some(id),
                        vector: vec![(thread_id * 100 + i) as f32; 16],
                        metadata: None,
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify database is in a consistent state
    let final_len = db.len().unwrap();
    assert_eq!(final_len, 100); // 100 original - 50 deleted + 50 inserted
}

#[test]
fn test_concurrent_search_and_insert() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_search_insert.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 64;
    options.hnsw_config = Some(HnswConfig::default());

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial data
    for i in 0..100 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
            metadata: None,
        })
        .unwrap();
    }

    let num_search_threads = 5;
    let num_insert_threads = 2;
    let mut handles = vec![];

    // Search threads
    for search_id in 0..num_search_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..20 {
                let query: Vec<f32> = (0..64)
                    .map(|j| ((search_id * 10 + i + j) as f32) * 0.01)
                    .collect();
                let results = db_clone
                    .search(SearchQuery {
                        vector: query,
                        k: 5,
                        filter: None,
                        ef_search: None,
                    })
                    .unwrap();

                // Should always return some results (at least from initial data)
                assert!(results.len() > 0);
            }
        });

        handles.push(handle);
    }

    // Insert threads
    for insert_id in 0..num_insert_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..50 {
                db_clone
                    .insert(VectorEntry {
                        id: Some(format!("new_{}_{}", insert_id, i)),
                        vector: (0..64)
                            .map(|j| ((insert_id * 1000 + i + j) as f32) * 0.01)
                            .collect(),
                        metadata: None,
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify final state
    assert_eq!(db.len().unwrap(), 200); // 100 initial + 100 new
}

#[test]
fn test_atomicity_of_batch_insert() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("atomic_batch.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 16;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Track successful insertions
    let inserted_ids = Arc::new(Mutex::new(HashSet::new()));

    let num_threads = 5;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);
        let ids_clone = Arc::clone(&inserted_ids);

        let handle = thread::spawn(move || {
            for batch_idx in 0..10 {
                let vectors: Vec<VectorEntry> = (0..10)
                    .map(|i| {
                        let id = format!("t{}_b{}_v{}", thread_id, batch_idx, i);
                        VectorEntry {
                            id: Some(id.clone()),
                            vector: vec![(thread_id * 100 + batch_idx * 10 + i) as f32; 16],
                            metadata: None,
                        }
                    })
                    .collect();

                let ids = db_clone.insert_batch(vectors).unwrap();

                let mut lock = ids_clone.lock().unwrap();
                for id in ids {
                    lock.insert(id);
                }
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify all insertions were recorded
    let total_inserted = inserted_ids.lock().unwrap().len();
    assert_eq!(total_inserted, num_threads * 10 * 10); // threads * batches * vectors_per_batch
    assert_eq!(db.len().unwrap(), total_inserted);
}
#[test]
fn test_read_write_consistency() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("consistency.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial vector
    db.insert(VectorEntry {
        id: Some("test".to_string()),
        vector: vec![1.0; 32],
        metadata: None,
    })
    .unwrap();

    let num_threads = 10;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for _ in 0..100 {
                // Read
                let entry = db_clone.get("test").unwrap();
                assert!(entry.is_some());

                // Verify vector is consistent
                let vector = entry.unwrap().vector;
                assert_eq!(vector.len(), 32);

                // Loose in-flight check: concurrent writers race with this
                // read, so only require values consistent with one of the
                // known writes; the strict all-equal check happens after join.
                let first_val = vector[0];
                assert!(vector
                    .iter()
                    .all(|&v| v == first_val || (first_val == 1.0 || v == (thread_id as f32))));

                // Write (update) - this creates a race condition intentionally
                if thread_id % 2 == 0 {
                    let _ = db_clone.insert(VectorEntry {
                        id: Some("test".to_string()),
                        vector: vec![thread_id as f32; 32],
                        metadata: None,
                    });
                }
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify database is still consistent
    let final_entry = db.get("test").unwrap();
    assert!(final_entry.is_some());

    let vector = final_entry.unwrap().vector;
    assert_eq!(vector.len(), 32);

    // Check no corruption (all values should be the same)
    let first_val = vector[0];
    assert!(vector.iter().all(|&v| v == first_val));
}
#[test]
fn test_concurrent_metadata_updates() {
    use std::collections::HashMap;

    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("metadata.db").to_string_lossy().to_string();
    options.dimensions = 16;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial vectors
    for i in 0..50 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: vec![i as f32; 16],
            metadata: None,
        })
        .unwrap();
    }

    let num_threads = 5;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..10 {
                let mut metadata = HashMap::new();
                metadata.insert(format!("thread_{}", thread_id), serde_json::json!(i));

                // Update vector with metadata
                let id = format!("vec_{}", i * 5 + thread_id);
                db_clone
                    .insert(VectorEntry {
                        id: Some(id.clone()),
                        vector: vec![thread_id as f32; 16],
                        metadata: Some(metadata),
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify some vectors have metadata
    let entry = db.get("vec_0").unwrap();
    assert!(entry.is_some());
}
crates/ruvector-core/tests/embeddings_test.rs (new file, 307 lines)
@@ -0,0 +1,307 @@
//! Integration tests for embedding providers

use ruvector_core::embeddings::{ApiEmbedding, EmbeddingProvider, HashEmbedding};
use ruvector_core::{types::DbOptions, AgenticDB};
use std::sync::Arc;
use tempfile::tempdir;
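// HashEmbedding is a deterministic placeholder provider: identical text always
// maps to the same unit-length vector, but the geometry carries no semantic
// meaning, unlike the API and Candle providers exercised further below.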
#[test]
fn test_hash_embedding_provider() {
    let provider = HashEmbedding::new(128);

    // Test basic embedding
    let emb1 = provider.embed("hello world").unwrap();
    assert_eq!(emb1.len(), 128);

    // Test consistency
    let emb2 = provider.embed("hello world").unwrap();
    assert_eq!(emb1, emb2, "Same text should produce same embedding");

    // Test different text produces different embeddings
    let emb3 = provider.embed("goodbye world").unwrap();
    assert_ne!(
        emb1, emb3,
        "Different text should produce different embeddings"
    );

    // Test normalization
    let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!(
        (norm - 1.0).abs() < 1e-5,
        "Embedding should be normalized to unit length"
    );

    // Test provider info
    assert_eq!(provider.dimensions(), 128);
    assert!(provider.name().contains("Hash"));
}

#[test]
fn test_agenticdb_with_hash_embeddings() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 128;

    // Create AgenticDB with default hash embeddings
    let db = AgenticDB::new(options).unwrap();

    assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");

    // Test storing a reflexion episode
    let episode_id = db
        .store_episode(
            "Solve a math problem".to_string(),
            vec!["read problem".to_string(), "calculate".to_string()],
            vec!["got answer 42".to_string()],
            "Should have shown intermediate steps".to_string(),
        )
        .unwrap();

    // Test retrieving similar episodes
    let episodes = db
        .retrieve_similar_episodes("math problem solving", 5)
        .unwrap();
    assert!(!episodes.is_empty());
    assert_eq!(episodes[0].id, episode_id);
}

#[test]
fn test_agenticdb_with_custom_hash_provider() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 256;

    // Create custom hash provider
    let provider = Arc::new(HashEmbedding::new(256));

    // Create AgenticDB with custom provider
    let db = AgenticDB::with_embedding_provider(options, provider).unwrap();

    assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");

    // Test creating a skill
    let mut params = std::collections::HashMap::new();
    params.insert("input".to_string(), "string".to_string());

    let skill_id = db
        .create_skill(
            "Parse JSON".to_string(),
            "Parse JSON from string".to_string(),
            params,
            vec!["json.parse()".to_string()],
        )
        .unwrap();

    // Search for skills
    let skills = db.search_skills("parse json data", 5).unwrap();
    assert!(!skills.is_empty());
    assert_eq!(skills[0].id, skill_id);
}

#[test]
fn test_dimension_mismatch_validation() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 128;

    // Try to create with mismatched dimensions
    let provider = Arc::new(HashEmbedding::new(256)); // Different from options

    let result = AgenticDB::with_embedding_provider(options, provider);
    assert!(result.is_err(), "Should fail when dimensions don't match");

    if let Err(err) = result {
        assert!(
            err.to_string().contains("do not match"),
            "Error should mention dimension mismatch"
        );
    }
}

#[test]
fn test_api_embedding_provider_construction() {
    // Test OpenAI provider construction
    let openai_small = ApiEmbedding::openai("sk-test", "text-embedding-3-small");
    assert_eq!(openai_small.dimensions(), 1536);
    assert_eq!(openai_small.name(), "ApiEmbedding");

    let openai_large = ApiEmbedding::openai("sk-test", "text-embedding-3-large");
    assert_eq!(openai_large.dimensions(), 3072);

    // Test Cohere provider construction
    let cohere = ApiEmbedding::cohere("co-test", "embed-english-v3.0");
    assert_eq!(cohere.dimensions(), 1024);

    // Test Voyage provider construction
    let voyage = ApiEmbedding::voyage("vo-test", "voyage-2");
    assert_eq!(voyage.dimensions(), 1024);

    let voyage_large = ApiEmbedding::voyage("vo-test", "voyage-large-2");
    assert_eq!(voyage_large.dimensions(), 1536);
}

#[test]
#[ignore] // Requires API key and network access
fn test_api_embedding_openai() {
    let api_key = std::env::var("OPENAI_API_KEY")
        .expect("OPENAI_API_KEY environment variable required for this test");

    let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

    let embedding = provider.embed("hello world").unwrap();
    assert_eq!(embedding.len(), 1536);

    // Check that embeddings are different for different texts
    let embedding2 = provider.embed("goodbye world").unwrap();
    assert_ne!(embedding, embedding2);
}

#[test]
#[ignore] // Requires API key and network access
fn test_agenticdb_with_openai_embeddings() {
    let api_key = std::env::var("OPENAI_API_KEY")
        .expect("OPENAI_API_KEY environment variable required for this test");

    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 1536; // OpenAI text-embedding-3-small dimensions

    let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
    let db = AgenticDB::with_embedding_provider(options, provider).unwrap();

    assert_eq!(db.embedding_provider_name(), "ApiEmbedding");

    // Test with real semantic embeddings
    let _episode1_id = db
        .store_episode(
            "Solve calculus problem".to_string(),
            vec![
                "identify function".to_string(),
                "take derivative".to_string(),
            ],
            vec!["computed derivative".to_string()],
            "Should explain chain rule application".to_string(),
        )
        .unwrap();

    let _episode2_id = db
        .store_episode(
            "Solve algebra problem".to_string(),
            vec!["simplify equation".to_string(), "solve for x".to_string()],
            vec!["found x = 5".to_string()],
            "Should show all steps".to_string(),
        )
        .unwrap();

    // Search with semantic query - should find the calculus episode first
    let episodes = db
        .retrieve_similar_episodes("derivative calculation", 2)
        .unwrap();
    assert!(!episodes.is_empty());

    // With real embeddings, "derivative" should match calculus better than algebra
    println!(
        "Found episodes: {:?}",
        episodes.iter().map(|e| &e.task).collect::<Vec<_>>()
    );
}

#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_candle_embedding_provider() {
    use ruvector_core::CandleEmbedding;

    let provider =
        CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap();

    assert_eq!(provider.dimensions(), 384);
    assert_eq!(provider.name(), "CandleEmbedding (transformer)");

    let embedding = provider.embed("hello world").unwrap();
    assert_eq!(embedding.len(), 384);

    // Check normalization
    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 1e-3, "Embedding should be normalized");

    // Test semantic similarity
    let emb_dog = provider.embed("dog").unwrap();
    let emb_cat = provider.embed("cat").unwrap();
    let emb_car = provider.embed("car").unwrap();

    // Cosine similarity
    let similarity_dog_cat: f32 = emb_dog.iter().zip(emb_cat.iter()).map(|(a, b)| a * b).sum();

    let similarity_dog_car: f32 = emb_dog.iter().zip(emb_car.iter()).map(|(a, b)| a * b).sum();

    // "dog" and "cat" should be more similar than "dog" and "car"
    assert!(
        similarity_dog_cat > similarity_dog_car,
        "Semantic embeddings should show dog-cat more similar than dog-car"
    );
}

#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_agenticdb_with_candle_embeddings() {
    use ruvector_core::CandleEmbedding;

    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 384;

    let provider = Arc::new(
        CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap(),
    );

    let db = AgenticDB::with_embedding_provider(options, provider).unwrap();

    assert_eq!(
        db.embedding_provider_name(),
        "CandleEmbedding (transformer)"
    );

    // Test with real semantic embeddings
    let skill1_id = db
        .create_skill(
            "File I/O".to_string(),
            "Read and write files to disk".to_string(),
            std::collections::HashMap::new(),
            vec![
                "open()".to_string(),
                "read()".to_string(),
                "write()".to_string(),
            ],
        )
        .unwrap();

    let skill2_id = db
        .create_skill(
            "Network I/O".to_string(),
            "Send and receive data over network".to_string(),
            std::collections::HashMap::new(),
            vec![
                "connect()".to_string(),
                "send()".to_string(),
                "recv()".to_string(),
            ],
        )
        .unwrap();

    // Search with semantic query
    let skills = db.search_skills("reading files from storage", 2).unwrap();
    assert!(!skills.is_empty());

    // With real embeddings, file I/O should match better
    println!(
        "Found skills: {:?}",
        skills.iter().map(|s| &s.name).collect::<Vec<_>>()
    );
}
crates/ruvector-core/tests/hnsw_integration_test.rs (new file, 495 lines)
@@ -0,0 +1,495 @@
//! Comprehensive HNSW integration tests with different index sizes

use ruvector_core::index::hnsw::HnswIndex;
use ruvector_core::index::VectorIndex;
use ruvector_core::types::{DistanceMetric, HnswConfig};
use ruvector_core::Result;

fn generate_random_vectors(count: usize, dimensions: usize, seed: u64) -> Vec<Vec<f32>> {
    use rand::{Rng, SeedableRng};
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

    (0..count)
        .map(|_| {
            (0..dimensions)
                .map(|_| rng.gen::<f32>() * 2.0 - 1.0)
                .collect()
        })
        .collect()
}

fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter().map(|x| x / norm).collect()
    } else {
        v.to_vec()
    }
}
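// recall@k: the fraction of the exact (brute-force) top-k neighbors that the
// approximate index also returned; 1.0 means the ANN result matched exactly.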
fn calculate_recall(ground_truth: &[String], results: &[String]) -> f32 {
    let gt_set: std::collections::HashSet<_> = ground_truth.iter().collect();
    let found = results.iter().filter(|id| gt_set.contains(id)).count();
    found as f32 / ground_truth.len() as f32
}

fn brute_force_search(
    query: &[f32],
    vectors: &[(String, Vec<f32>)],
    k: usize,
    metric: DistanceMetric,
) -> Vec<String> {
    use ruvector_core::distance::distance;

    let mut distances: Vec<_> = vectors
        .iter()
        .map(|(id, v)| {
            let dist = distance(query, v, metric).unwrap();
            (id.clone(), dist)
        })
        .collect();

    distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    distances.into_iter().take(k).map(|(id, _)| id).collect()
}
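// HNSW tuning in brief: `m` bounds the graph degree per node, `ef_construction`
// is the candidate-list width while building, and `ef_search` the width at
// query time; larger values raise recall at the cost of memory and latency.
// The configs below trade these off per dataset size.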
#[test]
|
||||
fn test_hnsw_100_vectors() -> Result<()> {
|
||||
let dimensions = 128;
|
||||
let num_vectors = 100;
|
||||
let k = 10;
|
||||
|
||||
let config = HnswConfig {
|
||||
m: 16,
|
||||
ef_construction: 100,
|
||||
ef_search: 200,
|
||||
max_elements: 1000,
|
||||
};
|
||||
|
||||
let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;
|
||||
|
||||
// Generate and insert vectors
|
||||
let vectors = generate_random_vectors(num_vectors, dimensions, 42);
|
||||
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
|
||||
|
||||
for (i, vector) in normalized_vectors.iter().enumerate() {
|
||||
index.add(format!("vec_{}", i), vector.clone())?;
|
||||
}
|
||||
|
||||
assert_eq!(index.len(), num_vectors);
|
||||
|
||||
// Test search accuracy with multiple queries
|
||||
let num_queries = 10;
|
||||
let mut total_recall = 0.0;
|
||||
|
||||
for i in 0..num_queries {
|
||||
let query_idx = i * (num_vectors / num_queries);
|
||||
let query = &normalized_vectors[query_idx];
|
||||
|
||||
// Get HNSW results
|
||||
let results = index.search(query, k)?;
|
||||
let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();
|
||||
|
||||
// Get ground truth with brute force
|
||||
let vectors_with_ids: Vec<_> = normalized_vectors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
|
||||
.collect();
|
||||
|
||||
let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
|
||||
|
||||
let recall = calculate_recall(&ground_truth, &result_ids);
|
||||
total_recall += recall;
|
||||
}
|
||||
|
||||
let avg_recall = total_recall / num_queries as f32;
|
||||
println!(
|
||||
"100 vectors - Average recall@{}: {:.2}%",
|
||||
k,
|
||||
avg_recall * 100.0
|
||||
);
|
||||
|
||||
// For small datasets, we expect very high recall
|
||||
assert!(
|
||||
avg_recall >= 0.90,
|
||||
"Recall should be at least 90% for 100 vectors"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}

#[test]
fn test_hnsw_1k_vectors() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 1000;
    let k = 10;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 200,
        max_elements: 10000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    // Generate and insert vectors
    let vectors = generate_random_vectors(num_vectors, dimensions, 12345);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    // Use batch insert for better performance
    let entries: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    index.add_batch(entries)?;
    assert_eq!(index.len(), num_vectors);

    // Test search accuracy
    let num_queries = 20;
    let mut total_recall = 0.0;

    for i in 0..num_queries {
        let query_idx = i * (num_vectors / num_queries);
        let query = &normalized_vectors[query_idx];

        let results = index.search(query, k)?;
        let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

        let vectors_with_ids: Vec<_> = normalized_vectors
            .iter()
            .enumerate()
            .map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
            .collect();

        let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
        let recall = calculate_recall(&ground_truth, &result_ids);
        total_recall += recall;
    }

    let avg_recall = total_recall / num_queries as f32;
    println!(
        "1K vectors - Average recall@{}: {:.2}%",
        k,
        avg_recall * 100.0
    );

    // Should achieve at least 95% recall with ef_search=200
    assert!(
        avg_recall >= 0.95,
        "Recall should be at least 95% for 1K vectors with ef_search=200"
    );

    Ok(())
}

#[test]
fn test_hnsw_10k_vectors() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 10000;
    let k = 10;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 200,
        max_elements: 100000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    println!("Generating {} vectors...", num_vectors);
    let vectors = generate_random_vectors(num_vectors, dimensions, 98765);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    println!("Inserting vectors in batches...");
    // Insert in batches for better performance
    let batch_size = 1000;
    for batch_start in (0..num_vectors).step_by(batch_size) {
        let batch_end = (batch_start + batch_size).min(num_vectors);
        let entries: Vec<_> = normalized_vectors[batch_start..batch_end]
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", batch_start + i), v.clone()))
            .collect();

        index.add_batch(entries)?;
    }

    assert_eq!(index.len(), num_vectors);
    println!("Index built with {} vectors", index.len());

    // Prepare all vectors for ground truth computation
    let all_vectors: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    // Test search accuracy with a sample of queries
    let num_queries = 20; // Reduced for faster testing
    let mut total_recall = 0.0;

    println!("Running {} queries...", num_queries);
    for i in 0..num_queries {
        let query_idx = i * (num_vectors / num_queries);
        let query = &normalized_vectors[query_idx];

        let results = index.search(query, k)?;
        let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

        // Compare against all vectors for accurate ground truth
        let ground_truth = brute_force_search(query, &all_vectors, k, DistanceMetric::Cosine);
        let recall = calculate_recall(&ground_truth, &result_ids);
        total_recall += recall;
    }

    let avg_recall = total_recall / num_queries as f32;
    println!(
        "10K vectors - Average recall@{}: {:.2}%",
        k,
        avg_recall * 100.0
    );

    // With ef_search=200 and m=32, we should achieve good recall
    assert!(
        avg_recall >= 0.70,
        "Recall should be at least 70% for 10K vectors, got {:.2}%",
        avg_recall * 100.0
    );

    Ok(())
}

#[test]
fn test_hnsw_ef_search_tuning() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 500;
    let k = 10;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 50, // Start with lower ef_search
        max_elements: 10000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    let vectors = generate_random_vectors(num_vectors, dimensions, 54321);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    let entries: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    index.add_batch(entries)?;

    // Test different ef_search values
    let ef_values = vec![50, 100, 200, 500];

    for ef in ef_values {
        let mut total_recall = 0.0;
        let num_queries = 10;

        for i in 0..num_queries {
            let query_idx = i * 50;
            let query = &normalized_vectors[query_idx];

            let results = index.search_with_ef(query, k, ef)?;
            let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

            let vectors_with_ids: Vec<_> = normalized_vectors
                .iter()
                .enumerate()
                .map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
                .collect();

            let ground_truth =
                brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
            let recall = calculate_recall(&ground_truth, &result_ids);
            total_recall += recall;
        }

        let avg_recall = total_recall / num_queries as f32;
        println!(
            "ef_search={} - Average recall@{}: {:.2}%",
            ef,
            k,
            avg_recall * 100.0
        );
    }

    // Verify that ef_search=200 achieves at least 95% recall
    let mut total_recall = 0.0;
    let num_queries = 10;

    for i in 0..num_queries {
        let query_idx = i * 50;
        let query = &normalized_vectors[query_idx];

        let results = index.search_with_ef(query, k, 200)?;
        let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

        let vectors_with_ids: Vec<_> = normalized_vectors
            .iter()
            .enumerate()
            .map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
            .collect();

        let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
        let recall = calculate_recall(&ground_truth, &result_ids);
        total_recall += recall;
    }

    let avg_recall = total_recall / num_queries as f32;
    assert!(
        avg_recall >= 0.95,
        "ef_search=200 should achieve at least 95% recall"
    );

    Ok(())
}

#[test]
fn test_hnsw_serialization_large() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 500;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 100,
        max_elements: 10000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    let vectors = generate_random_vectors(num_vectors, dimensions, 11111);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    let entries: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    index.add_batch(entries)?;

    // Serialize
    println!("Serializing index with {} vectors...", num_vectors);
    let bytes = index.serialize()?;
    println!(
        "Serialized size: {} bytes ({:.2} KB)",
        bytes.len(),
        bytes.len() as f32 / 1024.0
    );

    // Deserialize
    println!("Deserializing index...");
    let restored_index = HnswIndex::deserialize(&bytes)?;

    assert_eq!(restored_index.len(), num_vectors);

    // Test that search works on restored index
    let query = &normalized_vectors[0];
    let original_results = index.search(query, 10)?;
    let restored_results = restored_index.search(query, 10)?;

    // Results should be identical
    assert_eq!(original_results.len(), restored_results.len());

    println!("Serialization test passed!");

    Ok(())
}
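
// The round trip above stays in memory; a sketch of persisting the same
// bytes to disk with std::fs (hypothetical helpers, not used by the tests):
#[allow(dead_code)]
fn write_index_bytes(path: &std::path::Path, bytes: &[u8]) -> std::io::Result<()> {
    std::fs::write(path, bytes)
}

#[allow(dead_code)]
fn read_index_bytes(path: &std::path::Path) -> std::io::Result<Vec<u8>> {
    std::fs::read(path)
}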

#[test]
fn test_hnsw_different_metrics() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 200;
    let k = 5;

    // Note: DotProduct can produce negative distances on normalized vectors,
    // which causes issues with the underlying hnsw_rs library.
    // We test Cosine and Euclidean, which are the most commonly used metrics.
    let metrics = vec![DistanceMetric::Cosine, DistanceMetric::Euclidean];

    for metric in metrics {
        println!("Testing metric: {:?}", metric);

        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 100,
            max_elements: 1000,
        };

        let mut index = HnswIndex::new(dimensions, metric, config)?;

        let vectors = generate_random_vectors(num_vectors, dimensions, 99999);
        let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

        for (i, vector) in normalized_vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), vector.clone())?;
        }

        // Test search
        let query = &normalized_vectors[0];
        let results = index.search(query, k)?;

        assert!(!results.is_empty());
        println!("  Found {} results for metric {:?}", results.len(), metric);
    }

    Ok(())
}

#[test]
fn test_hnsw_parallel_batch_insert() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 2000;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 100,
        max_elements: 10000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    let vectors = generate_random_vectors(num_vectors, dimensions, 77777);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    let entries: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    // Time the batch insert
    let start = std::time::Instant::now();
    index.add_batch(entries)?;
    let duration = start.elapsed();

    println!("Batch inserted {} vectors in {:?}", num_vectors, duration);
    println!(
        "Throughput: {:.0} vectors/sec",
        num_vectors as f64 / duration.as_secs_f64()
    );

    assert_eq!(index.len(), num_vectors);

    // Verify search still works
    let query = &normalized_vectors[0];
    let results = index.search(query, 10)?;
    assert!(!results.is_empty());

    Ok(())
}
453
crates/ruvector-core/tests/integration_tests.rs
Normal file
@@ -0,0 +1,453 @@
//! Integration tests for end-to-end workflows
//!
//! These tests verify that all components work together correctly.

use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
use std::collections::HashMap;
use tempfile::tempdir;

// ============================================================================
// End-to-End Workflow Tests
// ============================================================================

#[test]
fn test_complete_insert_search_workflow() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 128;
    options.distance_metric = DistanceMetric::Cosine;
    options.hnsw_config = Some(HnswConfig {
        m: 16,
        ef_construction: 100,
        ef_search: 50,
        max_elements: 100_000,
    });

    let db = VectorDB::new(options).unwrap();

    // Insert training data
    let vectors: Vec<VectorEntry> = (0..100)
        .map(|i| {
            let mut metadata = HashMap::new();
            metadata.insert("index".to_string(), serde_json::json!(i));

            VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: (0..128).map(|j| ((i + j) as f32) * 0.01).collect(),
                metadata: Some(metadata),
            }
        })
        .collect();

    let ids = db.insert_batch(vectors).unwrap();
    assert_eq!(ids.len(), 100);

    // Search for similar vectors
    let query: Vec<f32> = (0..128).map(|j| (j as f32) * 0.01).collect();
    let results = db
        .search(SearchQuery {
            vector: query,
            k: 10,
            filter: None,
            ef_search: Some(100),
        })
        .unwrap();

    assert_eq!(results.len(), 10);
    assert!(results[0].vector.is_some());
    assert!(results[0].metadata.is_some());
}

#[test]
fn test_batch_operations_10k_vectors() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 384;
    options.distance_metric = DistanceMetric::Euclidean;
    options.hnsw_config = Some(HnswConfig::default());

    let db = VectorDB::new(options).unwrap();

    // Generate 10K vectors
    println!("Generating 10K vectors...");
    let vectors: Vec<VectorEntry> = (0..10_000)
        .map(|i| VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..384).map(|j| ((i + j) as f32) * 0.001).collect(),
            metadata: None,
        })
        .collect();

    // Batch insert
    println!("Batch inserting 10K vectors...");
    let start = std::time::Instant::now();
    let ids = db.insert_batch(vectors).unwrap();
    let duration = start.elapsed();
    println!("Batch insert took: {:?}", duration);

    assert_eq!(ids.len(), 10_000);
    assert_eq!(db.len().unwrap(), 10_000);

    // Perform multiple searches
    println!("Performing searches...");
    for i in 0..10 {
        let query: Vec<f32> = (0..384).map(|j| ((i * 100 + j) as f32) * 0.001).collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: None,
            })
            .unwrap();

        assert_eq!(results.len(), 10);
    }
}

#[test]
fn test_persistence_and_reload() {
    let dir = tempdir().unwrap();
    let db_path = dir
        .path()
        .join("persistent.db")
        .to_string_lossy()
        .to_string();

    // Create and populate database
    {
        let mut options = DbOptions::default();
        options.storage_path = db_path.clone();
        options.dimensions = 3;
        options.hnsw_config = None; // Use flat index for simpler persistence test

        let db = VectorDB::new(options).unwrap();

        for i in 0..10 {
            db.insert(VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: vec![i as f32, (i * 2) as f32, (i * 3) as f32],
                metadata: None,
            })
            .unwrap();
        }

        assert_eq!(db.len().unwrap(), 10);
    }

    // Reload database
    {
        let mut options = DbOptions::default();
        options.storage_path = db_path.clone();
        options.dimensions = 3;
        options.hnsw_config = None;

        let db = VectorDB::new(options).unwrap();

        // Verify data persisted
        assert_eq!(db.len().unwrap(), 10);

        let entry = db.get("vec_5").unwrap().unwrap();
        assert_eq!(entry.vector, vec![5.0, 10.0, 15.0]);
    }
}

#[test]
fn test_mixed_operations_workflow() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 64;

    let db = VectorDB::new(options).unwrap();

    // Insert initial batch
    let initial: Vec<VectorEntry> = (0..50)
        .map(|i| VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .collect();

    db.insert_batch(initial).unwrap();
    assert_eq!(db.len().unwrap(), 50);

    // Delete some vectors
    for i in 0..10 {
        db.delete(&format!("vec_{}", i)).unwrap();
    }
    assert_eq!(db.len().unwrap(), 40);

    // Insert more individual vectors
    for i in 50..60 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .unwrap();
    }
    assert_eq!(db.len().unwrap(), 50);

    // Search
    let query: Vec<f32> = (0..64).map(|j| (j as f32) * 0.1).collect();
    let results = db
        .search(SearchQuery {
            vector: query,
            k: 20,
            filter: None,
            ef_search: None,
        })
        .unwrap();

    assert!(!results.is_empty());
}

// ============================================================================
// Different Distance Metrics
// ============================================================================

#[test]
fn test_all_distance_metrics() {
    let metrics = vec![
        DistanceMetric::Euclidean,
        DistanceMetric::Cosine,
        DistanceMetric::DotProduct,
        DistanceMetric::Manhattan,
    ];

    for metric in metrics {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 32;
        options.distance_metric = metric;
        options.hnsw_config = None;

        let db = VectorDB::new(options).unwrap();

        // Insert test vectors
        for i in 0..20 {
            db.insert(VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
                metadata: None,
            })
            .unwrap();
        }

        // Search
        let query: Vec<f32> = (0..32).map(|j| (j as f32) * 0.1).collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 5,
                filter: None,
                ef_search: None,
            })
            .unwrap();

        assert_eq!(results.len(), 5, "Failed for metric {:?}", metric);

        // Verify scores are in ascending order (lower is better for distance)
        for i in 0..results.len() - 1 {
            assert!(
                results[i].score <= results[i + 1].score,
                "Results not sorted for metric {:?}: {} > {}",
                metric,
                results[i].score,
                results[i + 1].score
            );
        }
    }
}
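
// A hypothetical helper capturing the ordering check above: with distance
// scores, results are best-first exactly when the scores are non-decreasing.
// Not wired into the tests; shown only to make the invariant explicit.
#[allow(dead_code)]
fn assert_scores_ascending(scores: &[f32]) {
    assert!(
        scores.windows(2).all(|w| w[0] <= w[1]),
        "scores must be non-decreasing (lower distance = better match)"
    );
}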

// ============================================================================
// HNSW Configuration Tests
// ============================================================================

#[test]
fn test_hnsw_different_configurations() {
    let configs = vec![
        HnswConfig {
            m: 8,
            ef_construction: 50,
            ef_search: 50,
            max_elements: 1000,
        },
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 100,
            max_elements: 1000,
        },
        HnswConfig {
            m: 32,
            ef_construction: 200,
            ef_search: 200,
            max_elements: 1000,
        },
    ];

    for config in configs {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 64;
        options.hnsw_config = Some(config.clone());

        let db = VectorDB::new(options).unwrap();

        // Insert vectors
        let vectors: Vec<VectorEntry> = (0..100)
            .map(|i| VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
                metadata: None,
            })
            .collect();

        db.insert_batch(vectors).unwrap();

        // Search with different ef_search values
        let query: Vec<f32> = (0..64).map(|j| (j as f32) * 0.01).collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: Some(config.ef_search),
            })
            .unwrap();

        assert_eq!(results.len(), 10, "Failed for config m={}", config.m);
    }
}

// ============================================================================
// Metadata Filtering Tests
// ============================================================================

#[test]
fn test_complex_metadata_filtering() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 16;
    options.hnsw_config = None;

    let db = VectorDB::new(options).unwrap();

    // Insert vectors with different categories and values
    for i in 0..50 {
        let mut metadata = HashMap::new();
        metadata.insert("category".to_string(), serde_json::json!(i % 3));
        metadata.insert("value".to_string(), serde_json::json!(i / 10));

        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..16).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: Some(metadata),
        })
        .unwrap();
    }

    // Search with single filter
    let mut filter1 = HashMap::new();
    filter1.insert("category".to_string(), serde_json::json!(0));

    let query: Vec<f32> = (0..16).map(|j| (j as f32) * 0.1).collect();
    let results1 = db
        .search(SearchQuery {
            vector: query.clone(),
            k: 100,
            filter: Some(filter1),
            ef_search: None,
        })
        .unwrap();

    // Should only get vectors where i % 3 == 0
    for result in &results1 {
        let meta = result.metadata.as_ref().unwrap();
        assert_eq!(meta.get("category").unwrap(), &serde_json::json!(0));
    }

    // Search with different filter
    let mut filter2 = HashMap::new();
    filter2.insert("value".to_string(), serde_json::json!(2));

    let results2 = db
        .search(SearchQuery {
            vector: query,
            k: 100,
            filter: Some(filter2),
            ef_search: None,
        })
        .unwrap();

    // Should only get vectors where i / 10 == 2 (i.e., i in 20..30)
    for result in &results2 {
        let meta = result.metadata.as_ref().unwrap();
        assert_eq!(meta.get("value").unwrap(), &serde_json::json!(2));
    }
}

// ============================================================================
// Error Handling Tests
// ============================================================================

#[test]
fn test_dimension_validation() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 64;

    let db = VectorDB::new(options).unwrap();

    // Try to insert vector with wrong dimensions
    let result = db.insert(VectorEntry {
        id: None,
        vector: vec![1.0, 2.0, 3.0], // Only 3 dimensions, should be 64
        metadata: None,
    });

    assert!(result.is_err());
}

#[test]
fn test_search_with_wrong_dimension() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 64;
    options.hnsw_config = None;

    let db = VectorDB::new(options).unwrap();

    // Insert some vectors
    db.insert(VectorEntry {
        id: Some("v1".to_string()),
        vector: (0..64).map(|i| i as f32).collect(),
        metadata: None,
    })
    .unwrap();

    // Try to search with wrong dimension query
    // Note: This might not error in the current implementation, but should be validated
    let query = vec![1.0, 2.0, 3.0]; // Wrong dimension
    let result = db.search(SearchQuery {
        vector: query,
        k: 10,
        filter: None,
        ef_search: None,
    });

    // Depending on implementation, this might error or return empty results
    // The important thing is it doesn't panic
    let _ = result;
}
345
crates/ruvector-core/tests/property_tests.rs
Normal file
@@ -0,0 +1,345 @@
//! Property-based tests using proptest
//!
//! These tests verify mathematical properties and invariants that should hold
//! for all inputs within a given domain.

use proptest::prelude::*;
use ruvector_core::distance::*;
use ruvector_core::quantization::*;
use ruvector_core::types::DistanceMetric;

// ============================================================================
// Distance Metric Properties
// ============================================================================

// Strategy to generate valid vectors with bounded values to prevent overflow.
// Using a range that won't overflow when squared: sqrt(f32::MAX) ≈ 1.84e19.
// We use a more conservative range for numerical stability in distance calculations.
fn vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
    prop::collection::vec(-1000.0f32..1000.0f32, dim)
}

// Strategy for normalized vectors (for cosine similarity)
fn normalized_vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
    vector_strategy(dim).prop_map(move |v| {
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|x| x / norm).collect()
        } else {
            vec![1.0 / (dim as f32).sqrt(); dim]
        }
    })
}

proptest! {
    // Property: Distance to self is zero
    #[test]
    fn test_euclidean_self_distance_zero(v in vector_strategy(128)) {
        let dist = euclidean_distance(&v, &v);
        prop_assert!(dist < 0.001, "Distance to self should be ~0, got {}", dist);
    }

    // Property: Euclidean distance is symmetric
    #[test]
    fn test_euclidean_symmetry(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let dist_ab = euclidean_distance(&a, &b);
        let dist_ba = euclidean_distance(&b, &a);
        prop_assert!((dist_ab - dist_ba).abs() < 0.001, "Distance should be symmetric");
    }

    // Property: Triangle inequality for Euclidean distance
    #[test]
    fn test_euclidean_triangle_inequality(
        a in vector_strategy(32),
        b in vector_strategy(32),
        c in vector_strategy(32)
    ) {
        let dist_ab = euclidean_distance(&a, &b);
        let dist_bc = euclidean_distance(&b, &c);
        let dist_ac = euclidean_distance(&a, &c);

        // d(a,c) <= d(a,b) + d(b,c)
        prop_assert!(
            dist_ac <= dist_ab + dist_bc + 0.01, // Small epsilon for floating point
            "Triangle inequality violated: {} > {} + {}",
            dist_ac, dist_ab, dist_bc
        );
    }

    // Property: Non-negativity of Euclidean distance
    #[test]
    fn test_euclidean_non_negative(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let dist = euclidean_distance(&a, &b);
        prop_assert!(dist >= 0.0, "Distance must be non-negative, got {}", dist);
    }

    // Property: Cosine distance symmetry
    #[test]
    fn test_cosine_symmetry(
        a in normalized_vector_strategy(64),
        b in normalized_vector_strategy(64)
    ) {
        let dist_ab = cosine_distance(&a, &b);
        let dist_ba = cosine_distance(&b, &a);
        prop_assert!((dist_ab - dist_ba).abs() < 0.01, "Cosine distance should be symmetric");
    }

    // Property: Cosine distance to self is zero
    #[test]
    fn test_cosine_self_distance(v in normalized_vector_strategy(64)) {
        let dist = cosine_distance(&v, &v);
        prop_assert!(dist < 0.01, "Cosine distance to self should be ~0, got {}", dist);
    }

    // Property: Manhattan distance symmetry
    #[test]
    fn test_manhattan_symmetry(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let dist_ab = manhattan_distance(&a, &b);
        let dist_ba = manhattan_distance(&b, &a);
        prop_assert!((dist_ab - dist_ba).abs() < 0.001);
    }

    // Property: Manhattan distance non-negativity
    #[test]
    fn test_manhattan_non_negative(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let dist = manhattan_distance(&a, &b);
        prop_assert!(dist >= 0.0, "Manhattan distance must be non-negative");
    }

    // Property: Dot product symmetry
    #[test]
    fn test_dot_product_symmetry(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let dist_ab = dot_product_distance(&a, &b);
        let dist_ba = dot_product_distance(&b, &a);
        prop_assert!((dist_ab - dist_ba).abs() < 0.01);
    }
}
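
// An extra invariant as a sketch: the Euclidean distance from v to the origin
// should equal the L2 norm of v. The block also shows how proptest's default
// case count (256) can be raised with a module-level config for more thorough
// (slower) runs.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(512))]

    #[test]
    fn test_euclidean_distance_to_origin_is_norm(v in vector_strategy(16)) {
        let origin = vec![0.0f32; 16];
        let dist = euclidean_distance(&v, &origin);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        prop_assert!(
            (dist - norm).abs() <= 0.01 * norm.max(1.0),
            "d(v, 0) should equal ||v||: {} vs {}",
            dist, norm
        );
    }
}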

// ============================================================================
// Quantization Round-Trip Properties
// ============================================================================

proptest! {
    // Property: Scalar quantization round-trip preserves approximate values
    #[test]
    fn test_scalar_quantization_roundtrip(
        v in prop::collection::vec(0.0f32..100.0f32, 64)
    ) {
        let quantized = ScalarQuantized::quantize(&v);
        let reconstructed = quantized.reconstruct();

        prop_assert_eq!(v.len(), reconstructed.len());

        // Check reconstruction error is bounded
        for (orig, recon) in v.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            let relative_error = if *orig != 0.0 {
                error / orig.abs()
            } else {
                error
            };
            prop_assert!(
                relative_error < 0.5 || error < 1.0,
                "Reconstruction error too large: {} vs {}",
                orig, recon
            );
        }
    }

    // Property: Binary quantization preserves signs
    #[test]
    fn test_binary_quantization_sign_preservation(
        v in prop::collection::vec(-10.0f32..10.0f32, 64)
            .prop_filter("No zeros", |v| v.iter().all(|x| *x != 0.0))
    ) {
        let quantized = BinaryQuantized::quantize(&v);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in v.iter().zip(reconstructed.iter()) {
            prop_assert_eq!(
                orig.signum(),
                *recon,
                "Sign not preserved for {} -> {}",
                orig, recon
            );
        }
    }

    // Property: Binary quantization distance to self is zero
    #[test]
    fn test_binary_quantization_self_distance(
        v in prop::collection::vec(-10.0f32..10.0f32, 64)
    ) {
        let quantized = BinaryQuantized::quantize(&v);
        let dist = quantized.distance(&quantized);
        prop_assert_eq!(dist, 0.0, "Distance to self should be 0");
    }

    // Property: Binary quantization distance is symmetric
    #[test]
    fn test_binary_quantization_symmetry(
        a in prop::collection::vec(-10.0f32..10.0f32, 64),
        b in prop::collection::vec(-10.0f32..10.0f32, 64)
    ) {
        let qa = BinaryQuantized::quantize(&a);
        let qb = BinaryQuantized::quantize(&b);

        let dist_ab = qa.distance(&qb);
        let dist_ba = qb.distance(&qa);

        prop_assert_eq!(dist_ab, dist_ba, "Distance should be symmetric");
    }

    // Property: Binary quantization distance is bounded
    #[test]
    fn test_binary_quantization_distance_bounded(
        a in prop::collection::vec(-10.0f32..10.0f32, 64),
        b in prop::collection::vec(-10.0f32..10.0f32, 64)
    ) {
        let qa = BinaryQuantized::quantize(&a);
        let qb = BinaryQuantized::quantize(&b);

        let dist = qa.distance(&qb);

        // Hamming distance for 64 bits should be in [0, 64]
        prop_assert!(dist >= 0.0 && dist <= 64.0, "Distance {} out of bounds", dist);
    }

    // Property: Scalar quantization distance is non-negative
    #[test]
    fn test_scalar_quantization_distance_non_negative(
        a in prop::collection::vec(0.0f32..100.0f32, 64),
        b in prop::collection::vec(0.0f32..100.0f32, 64)
    ) {
        let qa = ScalarQuantized::quantize(&a);
        let qb = ScalarQuantized::quantize(&b);

        let dist = qa.distance(&qb);
        prop_assert!(dist >= 0.0, "Distance must be non-negative");
    }
}
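
// The error bounds above are consistent with a standard min/max affine
// scheme; a sketch of such a round trip (an assumption about the
// implementation, not taken from it): x maps to a u8 step via
// (x - min) / step and back via min + q * step, so the worst-case error is
// about one step, (max - min) / 255.
#[allow(dead_code)]
fn scalar_roundtrip_sketch(v: &[f32]) -> Vec<f32> {
    let min = v.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = v.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let step = if max > min { (max - min) / 255.0 } else { 1.0 };
    v.iter()
        .map(|&x| {
            let q = ((x - min) / step).round() as u8; // quantize to 8 bits
            min + q as f32 * step // reconstruct
        })
        .collect()
}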

// ============================================================================
// Vector Operation Properties
// ============================================================================

proptest! {
    // Property: Scaling preserves direction for cosine similarity
    #[test]
    fn test_cosine_scale_invariance(
        v in normalized_vector_strategy(32),
        scale in 0.1f32..10.0f32
    ) {
        let scaled: Vec<f32> = v.iter().map(|x| x * scale).collect();
        let dist_original = cosine_distance(&v, &v);
        let dist_scaled = cosine_distance(&v, &scaled);

        // Cosine distance should be approximately the same (scale invariant)
        prop_assert!(
            (dist_original - dist_scaled).abs() < 0.1,
            "Cosine distance should be scale invariant: {} vs {}",
            dist_original, dist_scaled
        );
    }

    // Property: Adding the same vector to both preserves distance
    #[test]
    fn test_euclidean_translation_invariance(
        a in vector_strategy(32),
        b in vector_strategy(32),
        offset in vector_strategy(32)
    ) {
        let a_offset: Vec<f32> = a.iter().zip(&offset).map(|(x, o)| x + o).collect();
        let b_offset: Vec<f32> = b.iter().zip(&offset).map(|(x, o)| x + o).collect();

        let dist_original = euclidean_distance(&a, &b);
        let dist_offset = euclidean_distance(&a_offset, &b_offset);

        prop_assert!(
            (dist_original - dist_offset).abs() < 0.01,
            "Euclidean distance should be translation invariant"
        );
    }
}

// ============================================================================
// Batch Operations Properties
// ============================================================================

proptest! {
    // Property: Batch distance calculation consistency
    #[test]
    fn test_batch_distances_consistency(
        query in vector_strategy(32),
        vectors in prop::collection::vec(vector_strategy(32), 10..20)
    ) {
        // Calculate distances in batch
        let batch_dists = batch_distances(&query, &vectors, DistanceMetric::Euclidean).unwrap();

        // Calculate distances individually
        let individual_dists: Vec<f32> = vectors.iter()
            .map(|v| euclidean_distance(&query, v))
            .collect();

        prop_assert_eq!(batch_dists.len(), individual_dists.len());

        for (batch, individual) in batch_dists.iter().zip(individual_dists.iter()) {
            prop_assert!(
                (batch - individual).abs() < 0.01,
                "Batch and individual distances should match: {} vs {}",
                batch, individual
            );
        }
    }
}

// ============================================================================
// Dimension Handling Properties
// ============================================================================

proptest! {
    // Property: Distance calculation fails on dimension mismatch
    #[test]
    fn test_dimension_mismatch_error(
        dim1 in 1usize..100,
        dim2 in 1usize..100
    ) {
        prop_assume!(dim1 != dim2); // Only test when dimensions differ

        let a = vec![1.0f32; dim1];
        let b = vec![1.0f32; dim2];

        let result = distance(&a, &b, DistanceMetric::Euclidean);
        prop_assert!(result.is_err(), "Should error on dimension mismatch");
    }

    // Property: Distance calculation succeeds on matching dimensions
    #[test]
    fn test_dimension_match_success(
        dim in 1usize..200,
        a in prop::collection::vec(any::<f32>().prop_filter("Must be finite", |x| x.is_finite()), 200),
        b in prop::collection::vec(any::<f32>().prop_filter("Must be finite", |x| x.is_finite()), 200)
    ) {
        // Trim both inputs to the same number of dimensions so the generated
        // values are actually exercised (dim < 200, so the slices are valid)
        let a_trimmed = &a[..dim];
        let b_trimmed = &b[..dim];

        let result = distance(a_trimmed, b_trimmed, DistanceMetric::Euclidean);
        prop_assert!(result.is_ok(), "Should succeed on matching dimensions");
    }
}
486
crates/ruvector-core/tests/stress_tests.rs
Normal file
@@ -0,0 +1,486 @@
//! Stress tests for scalability, concurrency, and resilience
//!
//! These tests push the system to its limits to verify robustness.

use ruvector_core::types::{DbOptions, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
use std::sync::{Arc, Barrier};
use std::thread;
use tempfile::tempdir;

// ============================================================================
// Large-Scale Insertion Tests
// ============================================================================

#[test]
#[ignore] // Run with: cargo test --test stress_tests -- --ignored --test-threads=1
fn test_million_vector_insertion() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("million.db").to_string_lossy().to_string();
    options.dimensions = 128;
    options.hnsw_config = Some(HnswConfig {
        m: 16,
        ef_construction: 100,
        ef_search: 50,
        max_elements: 2_000_000,
    });

    let db = VectorDB::new(options).unwrap();

    println!("Starting million-vector insertion test...");
    let batch_size = 10_000;
    let num_batches = 100; // Total: 1M vectors

    for batch_idx in 0..num_batches {
        println!("Inserting batch {}/{}...", batch_idx + 1, num_batches);

        let vectors: Vec<VectorEntry> = (0..batch_size)
            .map(|i| {
                let global_idx = batch_idx * batch_size + i;
                VectorEntry {
                    id: Some(format!("vec_{}", global_idx)),
                    vector: (0..128)
                        .map(|j| ((global_idx + j) as f32) * 0.0001)
                        .collect(),
                    metadata: None,
                }
            })
            .collect();

        let start = std::time::Instant::now();
        db.insert_batch(vectors).unwrap();
        let duration = start.elapsed();
        println!("Batch {} took: {:?}", batch_idx + 1, duration);
    }

    println!("Final database size: {}", db.len().unwrap());
    assert_eq!(db.len().unwrap(), 1_000_000);

    // Perform some searches to verify functionality
    println!("Testing search on 1M vectors...");
    for i in 0..10 {
        let query: Vec<f32> = (0..128)
            .map(|j| ((i * 10000 + j) as f32) * 0.0001)
            .collect();
        let start = std::time::Instant::now();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: Some(50),
            })
            .unwrap();
        let duration = start.elapsed();

        println!(
            "Search {} took: {:?}, found {} results",
            i + 1,
            duration,
            results.len()
        );
        assert_eq!(results.len(), 10);
    }
}

// ============================================================================
// Concurrent Query Tests
// ============================================================================

#[test]
fn test_concurrent_queries() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 64;
    options.hnsw_config = Some(HnswConfig::default());

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert test data
    println!("Inserting test data...");
    let vectors: Vec<VectorEntry> = (0..1000)
        .map(|i| VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
            metadata: None,
        })
        .collect();

    db.insert_batch(vectors).unwrap();

    // Spawn multiple threads doing concurrent searches
    println!("Starting 10 concurrent query threads...");
    let num_threads = 10;
    let queries_per_thread = 100;

    let barrier = Arc::new(Barrier::new(num_threads));
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);
        let barrier_clone = Arc::clone(&barrier);

        let handle = thread::spawn(move || {
            // Wait for all threads to be ready
            barrier_clone.wait();

            let start = std::time::Instant::now();

            for i in 0..queries_per_thread {
                let query: Vec<f32> = (0..64)
                    .map(|j| ((thread_id * 1000 + i + j) as f32) * 0.01)
                    .collect();

                let results = db_clone
                    .search(SearchQuery {
                        vector: query,
                        k: 10,
                        filter: None,
                        ef_search: None,
                    })
                    .unwrap();

                assert_eq!(results.len(), 10);
            }

            let duration = start.elapsed();
            println!(
                "Thread {} completed {} queries in {:?}",
                thread_id, queries_per_thread, duration
            );
            duration
        });

        handles.push(handle);
    }

    // Wait for all threads and collect results
    let mut total_duration = std::time::Duration::ZERO;
    for handle in handles {
        let duration = handle.join().unwrap();
        total_duration += duration;
    }

    let total_queries = num_threads * queries_per_thread;
    println!("Total queries: {}", total_queries);
    println!(
        "Average duration per thread: {:?}",
        total_duration / num_threads as u32
    );
}
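
// A sketch of deriving aggregate throughput from the timings collected above
// (hypothetical helper, not called by the test): with threads running
// concurrently, wall-clock QPS is total queries divided by the slowest
// thread's duration, not by the sum of all per-thread durations.
#[allow(dead_code)]
fn aggregate_qps(total_queries: usize, per_thread: &[std::time::Duration]) -> f64 {
    let wall_clock = per_thread.iter().max().copied().unwrap_or_default();
    if wall_clock.is_zero() {
        return 0.0;
    }
    total_queries as f64 / wall_clock.as_secs_f64()
}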

#[test]
fn test_concurrent_inserts_and_queries() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("mixed_concurrent.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;
    options.hnsw_config = Some(HnswConfig::default());

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Initial data
    let initial: Vec<VectorEntry> = (0..100)
        .map(|i| VectorEntry {
            id: Some(format!("initial_{}", i)),
            vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .collect();

    db.insert_batch(initial).unwrap();

    // Spawn reader and writer threads
    let num_readers = 5;
    let num_writers = 2;
    let barrier = Arc::new(Barrier::new(num_readers + num_writers));
    let mut handles = vec![];

    // Reader threads
    for reader_id in 0..num_readers {
        let db_clone = Arc::clone(&db);
        let barrier_clone = Arc::clone(&barrier);

        let handle = thread::spawn(move || {
            barrier_clone.wait();

            for i in 0..50 {
                let query: Vec<f32> = (0..32)
                    .map(|j| ((reader_id * 100 + i + j) as f32) * 0.1)
                    .collect();
                let results = db_clone
                    .search(SearchQuery {
                        vector: query,
                        k: 5,
                        filter: None,
                        ef_search: None,
                    })
                    .unwrap();

                assert!(!results.is_empty() && results.len() <= 5);
            }

            println!("Reader {} completed", reader_id);
        });

        handles.push(handle);
    }

    // Writer threads
    for writer_id in 0..num_writers {
        let db_clone = Arc::clone(&db);
        let barrier_clone = Arc::clone(&barrier);

        let handle = thread::spawn(move || {
            barrier_clone.wait();

            for i in 0..20 {
                let entry = VectorEntry {
                    id: Some(format!("writer_{}_{}", writer_id, i)),
                    vector: (0..32)
                        .map(|j| ((writer_id * 1000 + i + j) as f32) * 0.1)
                        .collect(),
                    metadata: None,
                };

                db_clone.insert(entry).unwrap();
            }

            println!("Writer {} completed", writer_id);
        });

        handles.push(handle);
    }

    // Wait for all threads
    for handle in handles {
        handle.join().unwrap();
    }

    // Verify final state
    let final_len = db.len().unwrap();
    println!("Final database size: {}", final_len);
    assert!(final_len >= 100); // At least the initial data should remain
}

// ============================================================================
// Memory Pressure Tests
// ============================================================================

#[test]
#[ignore] // Run with: cargo test --test stress_tests -- --ignored
fn test_memory_pressure_large_vectors() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("large_vectors.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 2048; // Very large vectors
    options.hnsw_config = Some(HnswConfig {
        m: 8,
        ef_construction: 50,
        ef_search: 50,
        max_elements: 100_000,
    });

    let db = VectorDB::new(options).unwrap();

    println!("Testing with large 2048-dimensional vectors...");
    let num_vectors = 10_000;
    let batch_size = 1000;

    for batch_idx in 0..(num_vectors / batch_size) {
        let vectors: Vec<VectorEntry> = (0..batch_size)
            .map(|i| {
                let global_idx = batch_idx * batch_size + i;
                VectorEntry {
                    id: Some(format!("vec_{}", global_idx)),
                    vector: (0..2048)
                        .map(|j| ((global_idx + j) as f32) * 0.0001)
                        .collect(),
                    metadata: None,
                }
            })
            .collect();

        db.insert_batch(vectors).unwrap();
        println!(
            "Inserted batch {}/{}",
            batch_idx + 1,
            num_vectors / batch_size
        );
    }

    println!("Database size: {}", db.len().unwrap());
    assert_eq!(db.len().unwrap(), num_vectors);

    // Perform searches
    for i in 0..5 {
        let query: Vec<f32> = (0..2048)
            .map(|j| ((i * 1000 + j) as f32) * 0.0001)
            .collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: None,
            })
            .unwrap();

        assert_eq!(results.len(), 10);
    }
}

// ============================================================================
// Error Recovery Tests
// ============================================================================

#[test]
fn test_invalid_operations_dont_crash() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 32;

    let db = VectorDB::new(options).unwrap();

    // Try various invalid operations

    // 1. Delete a non-existent vector
    let _ = db.delete("nonexistent");

    // 2. Get a non-existent vector
    let _ = db.get("nonexistent");

    // 3. Search with k=0
    let result = db.search(SearchQuery {
        vector: vec![0.0; 32],
        k: 0,
        filter: None,
        ef_search: None,
    });
    // Should either return empty results or error gracefully
    let _ = result;

    // 4. Insert and immediately delete in rapid succession
    for i in 0..100 {
        let id = db
            .insert(VectorEntry {
                id: Some(format!("temp_{}", i)),
                vector: vec![1.0; 32],
                metadata: None,
            })
            .unwrap();

        db.delete(&id).unwrap();
    }

    // The database should still be functional
    db.insert(VectorEntry {
        id: Some("final".to_string()),
        vector: vec![1.0; 32],
        metadata: None,
    })
    .unwrap();

    assert!(db.get("final").unwrap().is_some());
}

#[test]
fn test_repeated_operations() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 16;
    options.hnsw_config = None;

    let db = VectorDB::new(options).unwrap();

    // Insert the same ID multiple times (should replace or error)
    for _ in 0..10 {
        let _ = db.insert(VectorEntry {
            id: Some("same_id".to_string()),
            vector: vec![1.0; 16],
            metadata: None,
        });
    }

    // Delete the same ID multiple times
    for _ in 0..5 {
        let _ = db.delete("same_id");
    }

    // Search repeatedly with the same query
    let query = vec![1.0; 16];
    for _ in 0..100 {
        let _ = db.search(SearchQuery {
            vector: query.clone(),
            k: 10,
            filter: None,
            ef_search: None,
        });
    }
}

// ============================================================================
// Extreme Parameter Tests
// ============================================================================

#[test]
fn test_extreme_k_values() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 16;
    options.hnsw_config = None;

    let db = VectorDB::new(options).unwrap();

    // Insert some vectors
    for i in 0..10 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: vec![i as f32; 16],
            metadata: None,
        })
        .unwrap();
    }

    // Search with k larger than the database size
    let results = db
        .search(SearchQuery {
            vector: vec![1.0; 16],
            k: 1000,
            filter: None,
            ef_search: None,
        })
        .unwrap();

    // Should return at most 10 results
    assert!(results.len() <= 10);

    // Search with k=1
    let results = db
        .search(SearchQuery {
            vector: vec![1.0; 16],
            k: 1,
            filter: None,
            ef_search: None,
        })
        .unwrap();

    assert_eq!(results.len(), 1);
}
772
crates/ruvector-core/tests/test_memory_pool.rs
Normal file
@@ -0,0 +1,772 @@
//! Memory Pool and Allocation Tests
//!
//! This module tests the arena allocator and cache-optimized storage
//! for correct memory management, eviction, and performance characteristics.

use ruvector_core::arena::{Arena, ArenaVec};
use ruvector_core::cache_optimized::SoAVectorStorage;
use std::sync::{Arc, Barrier};
use std::thread;

// ============================================================================
// Arena Allocator Tests
// ============================================================================

mod arena_tests {
    use super::*;

    #[test]
    fn test_arena_basic_allocation() {
        let arena = Arena::new(1024);
        let mut vec: ArenaVec<f32> = arena.alloc_vec(10);

        assert_eq!(vec.capacity(), 10);
        assert_eq!(vec.len(), 0);
        assert!(vec.is_empty());

        vec.push(1.0);
        vec.push(2.0);
        vec.push(3.0);

        assert_eq!(vec.len(), 3);
        assert!(!vec.is_empty());
        assert_eq!(vec[0], 1.0);
        assert_eq!(vec[1], 2.0);
        assert_eq!(vec[2], 3.0);
    }

    #[test]
    fn test_arena_multiple_allocations() {
        let arena = Arena::new(4096);

        let vec1: ArenaVec<f32> = arena.alloc_vec(100);
        let vec2: ArenaVec<f64> = arena.alloc_vec(50);
        let vec3: ArenaVec<u32> = arena.alloc_vec(200);
        let vec4: ArenaVec<i64> = arena.alloc_vec(75);

        assert_eq!(vec1.capacity(), 100);
        assert_eq!(vec2.capacity(), 50);
        assert_eq!(vec3.capacity(), 200);
        assert_eq!(vec4.capacity(), 75);
    }

    #[test]
    fn test_arena_different_types() {
        let arena = Arena::new(2048);

        // Allocate different types
        let mut floats: ArenaVec<f32> = arena.alloc_vec(10);
        let mut doubles: ArenaVec<f64> = arena.alloc_vec(10);
        let mut ints: ArenaVec<i32> = arena.alloc_vec(10);
        let mut bytes: ArenaVec<u8> = arena.alloc_vec(10);

        // Push values
        for i in 0..10 {
            floats.push(i as f32);
            doubles.push(i as f64);
            ints.push(i);
            bytes.push(i as u8);
        }

        // Verify
        for i in 0..10 {
            assert_eq!(floats[i], i as f32);
            assert_eq!(doubles[i], i as f64);
            assert_eq!(ints[i], i as i32);
            assert_eq!(bytes[i], i as u8);
        }
    }

    #[test]
    fn test_arena_reset() {
        let arena = Arena::new(4096);

        // First allocation cycle
        {
            let mut vec1: ArenaVec<f32> = arena.alloc_vec(100);
            let mut vec2: ArenaVec<f32> = arena.alloc_vec(100);

            for i in 0..50 {
                vec1.push(i as f32);
                vec2.push(i as f32 * 2.0);
            }
        }

        let used_before = arena.used_bytes();
        assert!(used_before > 0, "Should have used some bytes");

        arena.reset();

        let used_after = arena.used_bytes();
        assert_eq!(used_after, 0, "Reset should set used bytes to 0");

        // Allocated bytes should remain (memory is reused, not freed)
        let allocated = arena.allocated_bytes();
        assert!(allocated > 0, "Allocated bytes should remain after reset");

        // Second allocation cycle - should reuse memory
        let mut vec3: ArenaVec<f32> = arena.alloc_vec(50);
        for i in 0..50 {
            vec3.push(i as f32);
        }

        // Memory was reused
        assert!(
            arena.allocated_bytes() == allocated,
            "Should reuse existing allocation"
        );
    }
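
    // A sketch of the usage pattern this test exercises (hypothetical
    // function, not called anywhere, and assuming `reset` may be called once
    // no ArenaVec borrows remain): one arena serves as per-query scratch
    // space, and a single reset reclaims everything between queries instead
    // of many individual frees.
    #[allow(dead_code)]
    fn per_query_scratch_sketch(arena: &Arena, dims: usize) {
        {
            let mut scratch: ArenaVec<f32> = arena.alloc_vec(dims);
            for i in 0..dims {
                scratch.push(i as f32);
            }
            // ... scratch is used only within this scope ...
        }
        arena.reset(); // one bump-pointer reset, ready for the next query
    }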

    #[test]
    fn test_arena_chunk_growth() {
        // Small initial chunk size to force growth
        let arena = Arena::new(64);

        // Allocate more than fits in one chunk
        let vec1: ArenaVec<f32> = arena.alloc_vec(100);
        let vec2: ArenaVec<f32> = arena.alloc_vec(100);
        let vec3: ArenaVec<f32> = arena.alloc_vec(100);

        assert_eq!(vec1.capacity(), 100);
        assert_eq!(vec2.capacity(), 100);
        assert_eq!(vec3.capacity(), 100);

        // Should have allocated multiple chunks
        let allocated = arena.allocated_bytes();
        assert!(allocated > 64 * 3, "Should have grown beyond initial chunk");
    }

    #[test]
    fn test_arena_as_slice() {
        let arena = Arena::new(1024);
        let mut vec: ArenaVec<f32> = arena.alloc_vec(10);

        for i in 0..5 {
            vec.push((i * 10) as f32);
        }

        let slice = vec.as_slice();
        assert_eq!(slice, &[0.0, 10.0, 20.0, 30.0, 40.0]);
    }

    #[test]
    fn test_arena_as_mut_slice() {
        let arena = Arena::new(1024);
        let mut vec: ArenaVec<f32> = arena.alloc_vec(10);

        for i in 0..5 {
            vec.push((i * 10) as f32);
        }

        {
            let slice = vec.as_mut_slice();
            slice[0] = 100.0;
            slice[4] = 500.0;
        }

        assert_eq!(vec[0], 100.0);
        assert_eq!(vec[4], 500.0);
    }

    #[test]
    fn test_arena_deref() {
        let arena = Arena::new(1024);
        let mut vec: ArenaVec<f32> = arena.alloc_vec(10);

        vec.push(1.0);
        vec.push(2.0);
        vec.push(3.0);

        // Test Deref trait (can use slice methods)
        assert_eq!(vec.len(), 3);
        assert_eq!(vec.iter().sum::<f32>(), 6.0);
    }

    #[test]
    fn test_arena_large_allocation() {
        let arena = Arena::new(1024);

        // Allocate something larger than the chunk size
        let large_vec: ArenaVec<f32> = arena.alloc_vec(10000);
        assert_eq!(large_vec.capacity(), 10000);

        // Should have grown to accommodate
        assert!(arena.allocated_bytes() >= 10000 * std::mem::size_of::<f32>());
    }

    #[test]
    fn test_arena_statistics() {
        let arena = Arena::new(1024);

        let initial_allocated = arena.allocated_bytes();
        let initial_used = arena.used_bytes();

        assert_eq!(initial_allocated, 0);
        assert_eq!(initial_used, 0);

        let _vec: ArenaVec<f32> = arena.alloc_vec(100);

        assert!(arena.allocated_bytes() > 0);
        assert!(arena.used_bytes() > 0);
    }

    #[test]
    #[should_panic(expected = "ArenaVec capacity exceeded")]
    fn test_arena_capacity_exceeded() {
        let arena = Arena::new(1024);
        let mut vec: ArenaVec<f32> = arena.alloc_vec(5);

        // Push more than capacity
        for i in 0..10 {
            vec.push(i as f32);
        }
    }

    #[test]
    fn test_arena_with_default_chunk_size() {
        let arena = Arena::with_default_chunk_size();

        // Default is 1MB
        let _vec: ArenaVec<f32> = arena.alloc_vec(1000);
        assert!(arena.allocated_bytes() >= 1024 * 1024);
    }
}

// ============================================================================
// Cache-Optimized Storage (SoA) Tests
// ============================================================================

mod soa_tests {
    use super::*;

    #[test]
    fn test_soa_basic_operations() {
        let mut storage = SoAVectorStorage::new(3, 4);

        assert_eq!(storage.len(), 0);
        assert!(storage.is_empty());
        assert_eq!(storage.dimensions(), 3);

        storage.push(&[1.0, 2.0, 3.0]);
        storage.push(&[4.0, 5.0, 6.0]);

        assert_eq!(storage.len(), 2);
        assert!(!storage.is_empty());
    }

    #[test]
    fn test_soa_get_vector() {
        let mut storage = SoAVectorStorage::new(4, 8);

        storage.push(&[1.0, 2.0, 3.0, 4.0]);
        storage.push(&[5.0, 6.0, 7.0, 8.0]);
        storage.push(&[9.0, 10.0, 11.0, 12.0]);

        let mut output = vec![0.0; 4];

        storage.get(0, &mut output);
        assert_eq!(output, vec![1.0, 2.0, 3.0, 4.0]);

        storage.get(1, &mut output);
        assert_eq!(output, vec![5.0, 6.0, 7.0, 8.0]);

        storage.get(2, &mut output);
        assert_eq!(output, vec![9.0, 10.0, 11.0, 12.0]);
    }
|
||||
|
||||
#[test]
|
||||
fn test_soa_dimension_slice() {
|
||||
let mut storage = SoAVectorStorage::new(3, 8);
|
||||
|
||||
storage.push(&[1.0, 10.0, 100.0]);
|
||||
storage.push(&[2.0, 20.0, 200.0]);
|
||||
storage.push(&[3.0, 30.0, 300.0]);
|
||||
storage.push(&[4.0, 40.0, 400.0]);
|
||||
|
||||
// Dimension 0: all first elements
|
||||
let dim0 = storage.dimension_slice(0);
|
||||
assert_eq!(dim0, &[1.0, 2.0, 3.0, 4.0]);
|
||||
|
||||
// Dimension 1: all second elements
|
||||
let dim1 = storage.dimension_slice(1);
|
||||
assert_eq!(dim1, &[10.0, 20.0, 30.0, 40.0]);
|
||||
|
||||
// Dimension 2: all third elements
|
||||
let dim2 = storage.dimension_slice(2);
|
||||
assert_eq!(dim2, &[100.0, 200.0, 300.0, 400.0]);
|
||||
}
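
    // The dimension_slice behaviour above follows from the SoA layout: each
    // dimension is stored contiguously across all vectors. A sketch of the
    // two index schemes (the `d * capacity + i` layout is an assumption for
    // illustration, not lifted from the SoAVectorStorage source):
    #[allow(dead_code)]
    fn aos_index(i: usize, d: usize, dims: usize) -> usize {
        i * dims + d // array-of-structs: one vector's values sit together
    }

    #[allow(dead_code)]
    fn soa_index(i: usize, d: usize, capacity: usize) -> usize {
        d * capacity + i // struct-of-arrays: one dimension's values sit together
    }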

    #[test]
    fn test_soa_dimension_slice_mut() {
        let mut storage = SoAVectorStorage::new(3, 8);

        storage.push(&[1.0, 2.0, 3.0]);
        storage.push(&[4.0, 5.0, 6.0]);

        // Modify dimension 0
        {
            let dim0 = storage.dimension_slice_mut(0);
            dim0[0] = 100.0;
            dim0[1] = 400.0;
        }

        let mut output = vec![0.0; 3];
        storage.get(0, &mut output);
        assert_eq!(output, vec![100.0, 2.0, 3.0]);

        storage.get(1, &mut output);
        assert_eq!(output, vec![400.0, 5.0, 6.0]);
    }

    #[test]
    fn test_soa_auto_growth() {
        // Start with small capacity
        let mut storage = SoAVectorStorage::new(4, 2);

        // Push more vectors than initial capacity
        for i in 0..100 {
            storage.push(&[i as f32, (i * 2) as f32, (i * 3) as f32, (i * 4) as f32]);
        }

        assert_eq!(storage.len(), 100);

        // Verify all values are correct
        let mut output = vec![0.0; 4];
        for i in 0..100 {
            storage.get(i, &mut output);
            assert_eq!(
                output,
                vec![i as f32, (i * 2) as f32, (i * 3) as f32, (i * 4) as f32]
            );
        }
    }

    #[test]
    fn test_soa_batch_euclidean_distances() {
        let mut storage = SoAVectorStorage::new(3, 4);

        // Add orthogonal unit vectors
        storage.push(&[1.0, 0.0, 0.0]);
        storage.push(&[0.0, 1.0, 0.0]);
        storage.push(&[0.0, 0.0, 1.0]);

        let query = vec![1.0, 0.0, 0.0];
        let mut distances = vec![0.0; 3];

        storage.batch_euclidean_distances(&query, &mut distances);

        // Distance to itself should be 0
        assert!(distances[0] < 0.001, "Distance to self should be ~0");

        // Distance to orthogonal vectors should be sqrt(2)
        let sqrt2 = (2.0_f32).sqrt();
        assert!(
            (distances[1] - sqrt2).abs() < 0.01,
            "Expected sqrt(2), got {}",
            distances[1]
        );
        assert!(
            (distances[2] - sqrt2).abs() < 0.01,
            "Expected sqrt(2), got {}",
            distances[2]
        );
    }
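
    // Over an SoA layout, batch Euclidean distance can accumulate one
    // dimension at a time across every vector, streaming each dimension slice
    // linearly through the cache. A scalar sketch of that loop order (a
    // plausible shape for batch_euclidean_distances, not its actual code):
    #[allow(dead_code)]
    fn batch_euclidean_sketch(query: &[f32], dim_slices: &[&[f32]], out: &mut [f32]) {
        out.iter_mut().for_each(|d| *d = 0.0);
        for (d, slice) in dim_slices.iter().enumerate() {
            let q = query[d];
            for (i, &x) in slice.iter().enumerate() {
                let diff = x - q;
                out[i] += diff * diff; // accumulate squared differences per vector
            }
        }
        for d in out.iter_mut() {
            *d = d.sqrt();
        }
    }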

    #[test]
    fn test_soa_batch_distances_large() {
        let dim = 128;
        let num_vectors = 1000;

        let mut storage = SoAVectorStorage::new(dim, 16);

        // Add random-ish vectors
        for i in 0..num_vectors {
            let vec: Vec<f32> = (0..dim)
                .map(|j| ((i * dim + j) % 100) as f32 * 0.01)
                .collect();
            storage.push(&vec);
        }

        let query: Vec<f32> = (0..dim).map(|j| (j % 50) as f32 * 0.02).collect();
        let mut distances = vec![0.0; num_vectors];

        storage.batch_euclidean_distances(&query, &mut distances);

        // Verify all distances are non-negative and finite
        for (i, &dist) in distances.iter().enumerate() {
            assert!(
                dist >= 0.0 && dist.is_finite(),
                "Distance {} is invalid: {}",
                i,
                dist
            );
        }
    }

    #[test]
    fn test_soa_common_embedding_dimensions() {
        // Test common embedding dimensions
        for dim in [128, 256, 384, 512, 768, 1024, 1536] {
            let mut storage = SoAVectorStorage::new(dim, 4);

            let vec: Vec<f32> = (0..dim).map(|i| i as f32 * 0.001).collect();
            storage.push(&vec);

            let mut output = vec![0.0; dim];
            storage.get(0, &mut output);

            assert_eq!(output, vec);
        }
    }

    #[test]
    #[should_panic(expected = "dimensions must be between")]
    fn test_soa_zero_dimensions() {
        let _ = SoAVectorStorage::new(0, 4);
    }

    #[test]
    #[should_panic]
    fn test_soa_wrong_vector_length() {
        let mut storage = SoAVectorStorage::new(3, 4);
        storage.push(&[1.0, 2.0]); // Wrong dimension
    }

    #[test]
    #[should_panic]
    fn test_soa_get_out_of_bounds() {
        let storage = SoAVectorStorage::new(3, 4);
        let mut output = vec![0.0; 3];
        storage.get(0, &mut output); // No vectors added
    }

    #[test]
    #[should_panic]
    fn test_soa_dimension_slice_out_of_bounds() {
        let mut storage = SoAVectorStorage::new(3, 4);
        storage.push(&[1.0, 2.0, 3.0]);
        let _ = storage.dimension_slice(5); // Invalid dimension
    }
}

// ============================================================================
// Memory Pressure Tests
// ============================================================================

mod memory_pressure_tests {
    use super::*;

    #[test]
    fn test_arena_many_small_allocations() {
        let arena = Arena::new(1024 * 1024); // 1MB

        // Many small allocations
        for _ in 0..10000 {
            let _vec: ArenaVec<f32> = arena.alloc_vec(10);
        }

        // Should handle without issues
        assert!(arena.allocated_bytes() > 0);
    }

    #[test]
    fn test_arena_alternating_sizes() {
        let arena = Arena::new(4096);

        for i in 0..100 {
            let size = if i % 2 == 0 { 10 } else { 1000 };
            let _vec: ArenaVec<f32> = arena.alloc_vec(size);
        }
    }

    #[test]
    fn test_soa_large_capacity() {
        let mut storage = SoAVectorStorage::new(128, 10000);

        for i in 0..10000 {
            let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.0001).collect();
            storage.push(&vec);
        }

        assert_eq!(storage.len(), 10000);

        // Verify random access
        let mut output = vec![0.0; 128];
        storage.get(5000, &mut output);
        assert!((output[0] - (5000 * 128) as f32 * 0.0001).abs() < 0.0001);
    }

    #[test]
    fn test_soa_batch_operations_under_pressure() {
        let dim = 512;
        let num_vectors = 5000;

        let mut storage = SoAVectorStorage::new(dim, 128);

        for i in 0..num_vectors {
            let vec: Vec<f32> = (0..dim).map(|j| ((i + j) % 1000) as f32 * 0.001).collect();
            storage.push(&vec);
        }

        // Perform batch distance calculations
        let query: Vec<f32> = (0..dim).map(|j| (j % 500) as f32 * 0.002).collect();
        let mut distances = vec![0.0; num_vectors];

        storage.batch_euclidean_distances(&query, &mut distances);

        // All distances should be valid
        for dist in &distances {
            assert!(dist.is_finite() && *dist >= 0.0);
        }
    }
}

// ============================================================================
// Concurrent Access Tests
// ============================================================================

mod concurrent_tests {
    use super::*;

    #[test]
    fn test_soa_concurrent_reads() {
        // Create and populate storage
        let mut storage = SoAVectorStorage::new(64, 16);

        for i in 0..1000 {
            let vec: Vec<f32> = (0..64).map(|j| (i * 64 + j) as f32 * 0.01).collect();
            storage.push(&vec);
        }

        let storage = Arc::new(storage);
        let num_threads = 8;
        let barrier = Arc::new(Barrier::new(num_threads));
        let mut handles = vec![];

        for thread_id in 0..num_threads {
            let storage_clone = Arc::clone(&storage);
            let barrier_clone = Arc::clone(&barrier);

            let handle = thread::spawn(move || {
                barrier_clone.wait();

                // Each thread performs many reads
                for i in 0..100 {
                    let idx = (thread_id * 100 + i) % 1000;

                    // Read dimension slices
                    let dim_slice = storage_clone.dimension_slice(idx % 64);
                    assert!(!dim_slice.is_empty());
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().expect("Thread panicked");
        }
    }
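
    // Note on safety: the reads above need no lock. The populated storage is
    // moved into an `Arc`, and every thread only calls `&self` methods, which
    // the `Send`/`Sync` bounds checked by `thread::spawn` make safe to share.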

    #[test]
    fn test_soa_concurrent_batch_distances() {
        let mut storage = SoAVectorStorage::new(32, 16);

        for i in 0..500 {
            let vec: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect();
            storage.push(&vec);
        }

        let storage = Arc::new(storage);
        let num_threads = 4;
        let barrier = Arc::new(Barrier::new(num_threads));
        let mut handles = vec![];

        for thread_id in 0..num_threads {
            let storage_clone = Arc::clone(&storage);
            let barrier_clone = Arc::clone(&barrier);

            let handle = thread::spawn(move || {
                barrier_clone.wait();

                for i in 0..50 {
                    let query: Vec<f32> = (0..32)
                        .map(|j| ((thread_id * 50 + i) * 32 + j) as f32 * 0.01)
                        .collect();
                    let mut distances = vec![0.0; 500];

                    storage_clone.batch_euclidean_distances(&query, &mut distances);

                    // Verify results
                    for dist in &distances {
                        assert!(dist.is_finite());
                    }
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().expect("Thread panicked");
        }
    }
}

// ============================================================================
// Edge Cases
// ============================================================================

mod edge_cases {
    use super::*;

    #[test]
    fn test_soa_single_vector() {
        let mut storage = SoAVectorStorage::new(3, 1);
        storage.push(&[1.0, 2.0, 3.0]);

        assert_eq!(storage.len(), 1);

        let mut output = vec![0.0; 3];
        storage.get(0, &mut output);
        assert_eq!(output, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_soa_single_dimension() {
        let mut storage = SoAVectorStorage::new(1, 4);

        storage.push(&[1.0]);
        storage.push(&[2.0]);
        storage.push(&[3.0]);

        let dim0 = storage.dimension_slice(0);
        assert_eq!(dim0, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_arena_exact_capacity() {
        let arena = Arena::new(1024);
        let mut vec: ArenaVec<f32> = arena.alloc_vec(5);

        // Fill to exactly capacity
        for i in 0..5 {
            vec.push(i as f32);
        }

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.capacity(), 5);
    }

    #[test]
    fn test_soa_zeros() {
        let mut storage = SoAVectorStorage::new(4, 4);

        storage.push(&[0.0, 0.0, 0.0, 0.0]);
        storage.push(&[0.0, 0.0, 0.0, 0.0]);

        let query = vec![0.0; 4];
        let mut distances = vec![0.0; 2];

        storage.batch_euclidean_distances(&query, &mut distances);

        assert!(distances[0] < 1e-6);
        assert!(distances[1] < 1e-6);
    }

    #[test]
    fn test_soa_negative_values() {
        let mut storage = SoAVectorStorage::new(3, 4);

        storage.push(&[-1.0, -2.0, -3.0]);
        storage.push(&[-4.0, -5.0, -6.0]);

        let mut output = vec![0.0; 3];
        storage.get(0, &mut output);
        assert_eq!(output, vec![-1.0, -2.0, -3.0]);
    }
}

// ============================================================================
// Performance Characteristics Tests
// ============================================================================

mod performance_tests {
    use super::*;

    #[test]
    fn test_arena_allocation_performance() {
        // This test verifies that arena allocation is efficient
        let arena = Arena::new(1024 * 1024); // 1MB

        let start = std::time::Instant::now();

        for _ in 0..100000 {
            let _vec: ArenaVec<f32> = arena.alloc_vec(10);
        }

        let duration = start.elapsed();

        // Should complete quickly (< 1 second for 100k allocations)
        assert!(
            duration.as_millis() < 1000,
            "Arena allocation took too long: {:?}",
            duration
        );
    }

    #[test]
    fn test_soa_dimension_access_pattern() {
        let mut storage = SoAVectorStorage::new(128, 16);

        for i in 0..1000 {
            let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32).collect();
            storage.push(&vec);
        }

        // Test dimension-wise access (this should be cache-efficient)
        let start = std::time::Instant::now();

        for dim in 0..128 {
            let slice = storage.dimension_slice(dim);
            let _sum: f32 = slice.iter().sum();
        }

        let duration = start.elapsed();

        // Dimension-wise access should be fast due to cache locality
        assert!(
            duration.as_millis() < 100,
            "Dimension access took too long: {:?}",
            duration
        );
    }

    #[test]
    fn test_soa_batch_distance_performance() {
        let mut storage = SoAVectorStorage::new(128, 128);

        for i in 0..1000 {
            let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.001).collect();
            storage.push(&vec);
        }

        let query: Vec<f32> = (0..128).map(|j| j as f32 * 0.001).collect();
        let mut distances = vec![0.0; 1000];

        let start = std::time::Instant::now();

        for _ in 0..100 {
            storage.batch_euclidean_distances(&query, &mut distances);
        }

        let duration = start.elapsed();

        // 100 batch operations on 1000 vectors should be fast
        assert!(
            duration.as_millis() < 500,
            "Batch distance took too long: {:?}",
            duration
        );
    }
}

767
crates/ruvector-core/tests/test_quantization.rs
Normal file
@@ -0,0 +1,767 @@
//! Quantization Accuracy Tests
//!
//! This module provides comprehensive tests for quantization techniques,
//! verifying accuracy, compression ratios, and distance calculations.

use ruvector_core::quantization::*;

// ============================================================================
// Scalar Quantization Tests
// ============================================================================

mod scalar_quantization_tests {
    use super::*;

    #[test]
    fn test_scalar_quantization_basic() {
        let vector = vec![0.0, 0.5, 1.0, 1.5, 2.0];
        let quantized = ScalarQuantized::quantize(&vector);

        assert_eq!(quantized.data.len(), 5);
        assert!(quantized.scale > 0.0, "Scale should be positive");
    }

    #[test]
    fn test_scalar_quantization_min_max() {
        let vector = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let quantized = ScalarQuantized::quantize(&vector);

        // Min should be -10.0
        assert!((quantized.min - (-10.0)).abs() < 0.001);

        // Scale should map the range of 20 onto 255 steps
        let expected_scale = 20.0 / 255.0;
        assert!(
            (quantized.scale - expected_scale).abs() < 0.001,
            "Scale mismatch: expected {}, got {}",
            expected_scale,
            quantized.scale
        );
    }
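
    // The fields asserted above describe an affine scheme: encode with
    // q = round((x - min) / scale) where scale = (max - min) / 255, decode
    // with x' = min + q * scale, so the worst-case error is about scale / 2.
    // A sketch of that round trip (hypothetical free functions mirroring the
    // tested fields, not the crate's own code):
    #[allow(dead_code)]
    fn sq_encode(x: f32, min: f32, scale: f32) -> u8 {
        ((x - min) / scale).round().clamp(0.0, 255.0) as u8
    }

    #[allow(dead_code)]
    fn sq_decode(q: u8, min: f32, scale: f32) -> f32 {
        min + q as f32 * scale
    }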

    #[test]
    fn test_scalar_quantization_reconstruction_accuracy() {
        let test_vectors = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 0.25, 0.5, 0.75, 1.0],
            vec![-100.0, 0.0, 100.0],
            vec![0.001, 0.002, 0.003, 0.004, 0.005],
        ];

        for vector in test_vectors {
            let quantized = ScalarQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();

            assert_eq!(vector.len(), reconstructed.len());

            // Calculate max error based on range
            let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
            let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let max_allowed_error = (max - min) / 128.0; // Allow ~2 quantization steps of error

            for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
                let error = (orig - recon).abs();
                assert!(
                    error <= max_allowed_error,
                    "Reconstruction error {} exceeds max {} for value {}",
                    error,
                    max_allowed_error,
                    orig
                );
            }
        }
    }

    #[test]
    fn test_scalar_quantization_constant_values() {
        let constant = vec![5.0, 5.0, 5.0, 5.0, 5.0];
        let quantized = ScalarQuantized::quantize(&constant);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in constant.iter().zip(reconstructed.iter()) {
            assert!(
                (orig - recon).abs() < 0.1,
                "Constant value reconstruction failed"
            );
        }
    }

    #[test]
    fn test_scalar_quantization_distance_self() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = ScalarQuantized::quantize(&vector);

        let distance = quantized.distance(&quantized);
        assert!(
            distance < 0.001,
            "Distance to self should be ~0, got {}",
            distance
        );
    }

    #[test]
    fn test_scalar_quantization_distance_symmetry() {
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let v2 = vec![5.0, 4.0, 3.0, 2.0, 1.0];

        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);

        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);

        assert!(
            (dist_ab - dist_ba).abs() < 0.1,
            "Distance not symmetric: {} vs {}",
            dist_ab,
            dist_ba
        );
    }

    #[test]
    fn test_scalar_quantization_distance_triangle_inequality() {
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![0.0, 0.0, 1.0, 0.0];

        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);
        let q3 = ScalarQuantized::quantize(&v3);

        let d12 = q1.distance(&q2);
        let d23 = q2.distance(&q3);
        let d13 = q1.distance(&q3);

        // Triangle inequality: d(1,3) <= d(1,2) + d(2,3)
        // Allow some slack for quantization errors
        assert!(
            d13 <= d12 + d23 + 0.5,
            "Triangle inequality violated: {} > {} + {}",
            d13,
            d12,
            d23
        );
    }

    #[test]
    fn test_scalar_quantization_common_embedding_sizes() {
        for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
            let vector: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
            let quantized = ScalarQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();

            assert_eq!(quantized.data.len(), dim);
            assert_eq!(reconstructed.len(), dim);

            // Verify compression ratio (~4x for f32 -> u8)
            let original_size = dim * std::mem::size_of::<f32>();
            let quantized_size = quantized.data.len() + std::mem::size_of::<f32>() * 2; // data + min + scale
            assert!(
                quantized_size < original_size,
                "No compression achieved for dim {}",
                dim
            );
        }
    }

    #[test]
    fn test_scalar_quantization_extreme_values() {
        // Test with large values
        let large = vec![1e10, 2e10, 3e10];
        let quantized = ScalarQuantized::quantize(&large);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in large.iter().zip(reconstructed.iter()) {
            let relative_error = (orig - recon).abs() / orig.abs();
            assert!(
                relative_error < 0.02,
                "Large value reconstruction error too high: {}",
                relative_error
            );
        }

        // Test with small values
        let small = vec![1e-5, 2e-5, 3e-5, 4e-5, 5e-5];
        let quantized = ScalarQuantized::quantize(&small);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in small.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            let range = 4e-5;
            assert!(
                error < range / 100.0,
                "Small value reconstruction error too high: {}",
                error
            );
        }
    }

    #[test]
    fn test_scalar_quantization_negative_values() {
        let negative = vec![-5.0, -4.0, -3.0, -2.0, -1.0];
        let quantized = ScalarQuantized::quantize(&negative);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in negative.iter().zip(reconstructed.iter()) {
            assert!(
                (orig - recon).abs() < 0.1,
                "Negative value reconstruction failed: {} vs {}",
                orig,
                recon
            );
        }
    }
}

// ============================================================================
// Binary Quantization Tests
// ============================================================================

mod binary_quantization_tests {
    use super::*;

    #[test]
    fn test_binary_quantization_basic() {
        let vector = vec![1.0, -1.0, 0.5, -0.5, 0.1];
        let quantized = BinaryQuantized::quantize(&vector);

        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits.len(), 1); // 5 bits fit in 1 byte
    }

    #[test]
    fn test_binary_quantization_packing() {
        // Test byte packing
        for dim in 1..=32 {
            let vector: Vec<f32> = (0..dim)
                .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
                .collect();
            let quantized = BinaryQuantized::quantize(&vector);

            let expected_bytes = (dim + 7) / 8;
            assert_eq!(
                quantized.bits.len(),
                expected_bytes,
                "Wrong byte count for dim {}",
                dim
            );
            assert_eq!(quantized.dimensions, dim);
        }
    }
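
    // The ceil(dim / 8) byte counts come from packing one sign bit per
    // element. A sketch of that packing; the bit order within a byte is an
    // assumption for illustration:
    #[allow(dead_code)]
    fn pack_signs_sketch(v: &[f32]) -> Vec<u8> {
        let mut bits = vec![0u8; (v.len() + 7) / 8];
        for (i, &x) in v.iter().enumerate() {
            if x > 0.0 {
                bits[i / 8] |= 1 << (i % 8); // positive => bit set, else left 0
            }
        }
        bits
    }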

    #[test]
    fn test_binary_quantization_sign_preservation() {
        let test_vectors = vec![
            vec![1.0, -1.0, 2.0, -2.0],
            vec![0.001, -0.001, 100.0, -100.0],
            vec![f32::MAX / 2.0, f32::MIN / 2.0],
        ];

        for vector in test_vectors {
            let quantized = BinaryQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();

            for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
                if *orig > 0.0 {
                    assert_eq!(*recon, 1.0, "Positive value should reconstruct to 1.0");
                } else if *orig < 0.0 {
                    assert_eq!(*recon, -1.0, "Negative value should reconstruct to -1.0");
                }
            }
        }
    }

    #[test]
    fn test_binary_quantization_zero_handling() {
        let vector = vec![0.0, 0.0, 0.0, 0.0];
        let quantized = BinaryQuantized::quantize(&vector);
        let reconstructed = quantized.reconstruct();

        // Zero maps to the negative bit (0), which reconstructs to -1.0
        for val in reconstructed {
            assert_eq!(val, -1.0);
        }
    }

    #[test]
    fn test_binary_quantization_hamming_distance() {
        // Test specific Hamming distance cases
        let cases = vec![
            // (v1, v2, expected_distance)
            (vec![1.0, 1.0, 1.0, 1.0], vec![1.0, 1.0, 1.0, 1.0], 0.0), // identical
            (vec![1.0, 1.0, 1.0, 1.0], vec![-1.0, -1.0, -1.0, -1.0], 4.0), // opposite
            (vec![1.0, 1.0, -1.0, -1.0], vec![1.0, -1.0, -1.0, 1.0], 2.0), // 2 bits differ
            (vec![1.0, -1.0, 1.0, -1.0], vec![-1.0, 1.0, -1.0, 1.0], 4.0), // all differ
        ];

        for (v1, v2, expected) in cases {
            let q1 = BinaryQuantized::quantize(&v1);
            let q2 = BinaryQuantized::quantize(&v2);

            let distance = q1.distance(&q2);
            assert!(
                (distance - expected).abs() < 0.001,
                "Hamming distance mismatch: expected {}, got {}",
                expected,
                distance
            );
        }
    }
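
    // Hamming distance over packed sign bits reduces to XOR plus popcount on
    // each byte, which is what the cases above pin down. A scalar sketch:
    #[allow(dead_code)]
    fn hamming_sketch(a: &[u8], b: &[u8]) -> u32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x ^ y).count_ones()) // count differing bits per byte
            .sum()
    }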

    #[test]
    fn test_binary_quantization_distance_symmetry() {
        let v1 = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let v2 = vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0];

        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);

        let d12 = q1.distance(&q2);
        let d21 = q2.distance(&q1);

        assert_eq!(d12, d21, "Binary distance should be symmetric");
    }

    #[test]
    fn test_binary_quantization_distance_bounds() {
        for dim in [8, 16, 32, 64, 128, 256] {
            let v1: Vec<f32> = (0..dim)
                .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
                .collect();
            let v2: Vec<f32> = (0..dim)
                .map(|i| if i % 3 == 0 { 1.0 } else { -1.0 })
                .collect();

            let q1 = BinaryQuantized::quantize(&v1);
            let q2 = BinaryQuantized::quantize(&v2);

            let distance = q1.distance(&q2);

            // Distance should be in [0, dim]
            assert!(
                distance >= 0.0 && distance <= dim as f32,
                "Distance {} out of bounds [0, {}]",
                distance,
                dim
            );
        }
    }

    #[test]
    fn test_binary_quantization_compression_ratio() {
        for dim in [128, 256, 512, 1024] {
            let vector: Vec<f32> = (0..dim)
                .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
                .collect();
            let quantized = BinaryQuantized::quantize(&vector);

            // f32 to 1 bit = theoretical 32x compression for data only
            // Actual ratio depends on overhead but should be significant
            let original_data_size = dim * std::mem::size_of::<f32>();
            let quantized_data_size = quantized.bits.len();

            let data_compression_ratio = original_data_size as f32 / quantized_data_size as f32;
            assert!(
                data_compression_ratio >= 31.0,
                "Data compression ratio {} less than expected ~32x for dim {}",
                data_compression_ratio,
                dim
            );

            // Verify bits.len() is correct: ceil(dim / 8)
            assert_eq!(quantized.bits.len(), (dim + 7) / 8);
        }
    }

    #[test]
    fn test_binary_quantization_common_embedding_sizes() {
        for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
            let vector: Vec<f32> = (0..dim).map(|i| (i as f32 - dim as f32 / 2.0)).collect();
            let quantized = BinaryQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();

            assert_eq!(reconstructed.len(), dim);

            // Check all values are +1 or -1
            for val in &reconstructed {
                assert!(*val == 1.0 || *val == -1.0);
            }
        }
    }
}

// ============================================================================
// Product Quantization Tests
// ============================================================================

mod product_quantization_tests {
    use super::*;

    #[test]
    fn test_product_quantization_training() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect())
            .collect();

        let num_subspaces = 4;
        let codebook_size = 16;

        let pq = ProductQuantized::train(&vectors, num_subspaces, codebook_size, 10).unwrap();

        assert_eq!(pq.codebooks.len(), num_subspaces);
        for codebook in &pq.codebooks {
            assert_eq!(codebook.len(), codebook_size);
        }
    }

    #[test]
    fn test_product_quantization_encode() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect())
            .collect();

        let num_subspaces = 4;
        let codebook_size = 16;

        let pq = ProductQuantized::train(&vectors, num_subspaces, codebook_size, 10).unwrap();

        let test_vector: Vec<f32> = (0..32).map(|i| i as f32 * 0.02).collect();
        let codes = pq.encode(&test_vector);

        assert_eq!(codes.len(), num_subspaces);
        for code in &codes {
            assert!(*code < codebook_size as u8);
        }
    }
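
    // PQ encoding splits a vector into equal subspaces and stores, for each
    // one, the index of the nearest codebook centroid. A sketch of that
    // assignment step, assuming the usual Euclidean nearest-centroid rule of
    // k-means codebooks:
    #[allow(dead_code)]
    fn nearest_centroid(sub: &[f32], codebook: &[Vec<f32>]) -> u8 {
        let mut best = 0usize;
        let mut best_dist = f32::INFINITY;
        for (idx, centroid) in codebook.iter().enumerate() {
            let dist: f32 = sub
                .iter()
                .zip(centroid.iter())
                .map(|(x, y)| (x - y) * (x - y))
                .sum();
            if dist < best_dist {
                best_dist = dist;
                best = idx;
            }
        }
        best as u8 // valid because codebook sizes are capped at 256
    }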

    #[test]
    fn test_product_quantization_empty_input_error() {
        let result = ProductQuantized::train(&[], 4, 16, 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_product_quantization_codebook_size_limit() {
        let vectors: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..16).map(|j| (i * 16 + j) as f32).collect())
            .collect();

        // Codebook size > 256 should error
        let result = ProductQuantized::train(&vectors, 4, 300, 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_product_quantization_various_subspaces() {
        let dim = 64;
        let vectors: Vec<Vec<f32>> = (0..200)
            .map(|i| (0..dim).map(|j| (i * dim + j) as f32 * 0.001).collect())
            .collect();

        for num_subspaces in [1, 2, 4, 8, 16] {
            let pq = ProductQuantized::train(&vectors, num_subspaces, 16, 5).unwrap();

            assert_eq!(pq.codebooks.len(), num_subspaces);

            let subspace_dim = dim / num_subspaces;
            for codebook in &pq.codebooks {
                for centroid in codebook {
                    assert_eq!(centroid.len(), subspace_dim);
                }
            }
        }
    }
}

// ============================================================================
// Comparative Tests
// ============================================================================

mod comparative_tests {
    use super::*;

    #[test]
    fn test_scalar_vs_binary_reconstruction() {
        let vector = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0];

        let scalar = ScalarQuantized::quantize(&vector);
        let binary = BinaryQuantized::quantize(&vector);

        let scalar_recon = scalar.reconstruct();
        let binary_recon = binary.reconstruct();

        // Scalar should have better accuracy
        let scalar_error: f32 = vector
            .iter()
            .zip(scalar_recon.iter())
            .map(|(o, r)| (o - r).abs())
            .sum::<f32>()
            / vector.len() as f32;

        // Binary only preserves sign
        for (orig, recon) in vector.iter().zip(binary_recon.iter()) {
            assert_eq!(orig.signum(), recon.signum());
        }

        // Scalar error should be small
        assert!(
            scalar_error < 0.5,
            "Scalar reconstruction error {} too high",
            scalar_error
        );
    }

    #[test]
    fn test_quantization_preserves_relative_ordering() {
        // Test that vectors closest in original space are also closest in quantized space
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.9, 0.1, 0.0, 0.0]; // close to v1
        let v3 = vec![0.0, 0.0, 0.0, 1.0]; // far from v1

        // For scalar quantization
        let q1_s = ScalarQuantized::quantize(&v1);
        let q2_s = ScalarQuantized::quantize(&v2);
        let q3_s = ScalarQuantized::quantize(&v3);

        let d12_s = q1_s.distance(&q2_s);
        let d13_s = q1_s.distance(&q3_s);

        // v2 should be closer to v1 than v3
        assert!(
            d12_s < d13_s,
            "Scalar: v2 should be closer to v1 than v3: {} vs {}",
            d12_s,
            d13_s
        );

        // For binary quantization
        let q1_b = BinaryQuantized::quantize(&v1);
        let q2_b = BinaryQuantized::quantize(&v2);
        let q3_b = BinaryQuantized::quantize(&v3);

        let d12_b = q1_b.distance(&q2_b);
        let d13_b = q1_b.distance(&q3_b);

        // Same relative ordering should hold
        assert!(
            d12_b <= d13_b,
            "Binary: v2 should be at most as far as v3: {} vs {}",
            d12_b,
            d13_b
        );
    }

    #[test]
    fn test_compression_ratios() {
        let dim = 512;
        let vector: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01).collect();

        // Original size
        let original_size = dim * std::mem::size_of::<f32>(); // 2048 bytes

        // Scalar quantization: u8 per element + 2 floats for min/scale
        let scalar = ScalarQuantized::quantize(&vector);
        let scalar_size = scalar.data.len() + 2 * std::mem::size_of::<f32>(); // ~520 bytes
        let scalar_ratio = original_size as f32 / scalar_size as f32;

        // Binary quantization: 1 bit per element + usize for dimensions
        let binary = BinaryQuantized::quantize(&vector);
        let binary_size = binary.bits.len() + std::mem::size_of::<usize>(); // ~72 bytes
        let binary_ratio = original_size as f32 / binary_size as f32;

        println!("Original: {} bytes", original_size);
        println!(
            "Scalar: {} bytes ({:.1}x compression)",
            scalar_size, scalar_ratio
        );
        println!(
            "Binary: {} bytes ({:.1}x compression)",
            binary_size, binary_ratio
        );

        // Verify expected ratios
        assert!(scalar_ratio > 3.5, "Scalar should achieve ~4x compression");
        assert!(
            binary_ratio > 25.0,
            "Binary should achieve ~32x compression"
        );
    }
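
    // Worked numbers for the 512-dim case above: original 512 * 4 = 2048
    // bytes; scalar 512 + 8 = 520 bytes (~3.9x); binary 512 / 8 + 8 = 72
    // bytes (~28.4x). The asserted thresholds (3.5x and 25x) leave headroom
    // for that fixed per-struct overhead.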
}

// ============================================================================
// Edge Cases and Error Handling
// ============================================================================

mod edge_cases {
    use super::*;

    #[test]
    fn test_single_element_vector() {
        let vector = vec![42.0];

        let scalar = ScalarQuantized::quantize(&vector);
        let binary = BinaryQuantized::quantize(&vector);

        assert_eq!(scalar.data.len(), 1);
        assert_eq!(binary.bits.len(), 1);
        assert_eq!(binary.dimensions, 1);
    }

    #[test]
    fn test_large_vector() {
        let dim = 8192;
        let vector: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();

        let scalar = ScalarQuantized::quantize(&vector);
        let binary = BinaryQuantized::quantize(&vector);

        assert_eq!(scalar.data.len(), dim);
        assert_eq!(binary.dimensions, dim);
    }

    #[test]
    fn test_all_positive() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let binary = BinaryQuantized::quantize(&vector);
        let reconstructed = binary.reconstruct();

        // All values should reconstruct to 1.0
        for val in reconstructed {
            assert_eq!(val, 1.0);
        }
    }

    #[test]
    fn test_all_negative() {
        let vector = vec![-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0];
        let binary = BinaryQuantized::quantize(&vector);
        let reconstructed = binary.reconstruct();

        // All values should reconstruct to -1.0
        for val in reconstructed {
            assert_eq!(val, -1.0);
        }
    }

    #[test]
    fn test_alternating_pattern() {
        let vector: Vec<f32> = (0..100)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();

        let binary = BinaryQuantized::quantize(&vector);
        let reconstructed = binary.reconstruct();

        for (i, val) in reconstructed.iter().enumerate() {
            let expected = if i % 2 == 0 { 1.0 } else { -1.0 };
            assert_eq!(*val, expected);
        }
    }

    #[test]
    fn test_quantization_deterministic() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // Quantize multiple times - should get same result
        let q1 = ScalarQuantized::quantize(&vector);
        let q2 = ScalarQuantized::quantize(&vector);

        assert_eq!(q1.data, q2.data);
        assert_eq!(q1.min, q2.min);
        assert_eq!(q1.scale, q2.scale);
    }
}

// ============================================================================
// Performance Characteristic Tests
// ============================================================================

mod performance_tests {
    use super::*;

    #[test]
    fn test_scalar_quantization_speed() {
        let vector: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();

        let start = std::time::Instant::now();

        for _ in 0..10000 {
            let _ = ScalarQuantized::quantize(&vector);
        }

        let duration = start.elapsed();
        let ops_per_sec = 10000.0 / duration.as_secs_f64();

        println!(
            "Scalar quantization: {:.0} ops/sec for 1024-dim vectors",
            ops_per_sec
        );

        // Should be fast
        assert!(
            duration.as_millis() < 5000,
            "Scalar quantization too slow: {:?}",
            duration
        );
    }

    #[test]
    fn test_binary_quantization_speed() {
        let vector: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();

        let start = std::time::Instant::now();

        for _ in 0..10000 {
            let _ = BinaryQuantized::quantize(&vector);
        }

        let duration = start.elapsed();
        let ops_per_sec = 10000.0 / duration.as_secs_f64();

        println!(
            "Binary quantization: {:.0} ops/sec for 1024-dim vectors",
            ops_per_sec
        );

        // Should be fast
        assert!(
            duration.as_millis() < 5000,
            "Binary quantization too slow: {:?}",
            duration
        );
    }

    #[test]
    fn test_distance_calculation_speed() {
        let v1: Vec<f32> = (0..512).map(|i| i as f32 * 0.01).collect();
        let v2: Vec<f32> = (0..512).map(|i| (i as f32 * 0.01) + 0.5).collect();

        let q1_s = ScalarQuantized::quantize(&v1);
        let q2_s = ScalarQuantized::quantize(&v2);

        let q1_b = BinaryQuantized::quantize(&v1);
        let q2_b = BinaryQuantized::quantize(&v2);

        // Scalar distance
        let start = std::time::Instant::now();
        for _ in 0..100000 {
            let _ = q1_s.distance(&q2_s);
        }
        let scalar_duration = start.elapsed();

        // Binary distance (Hamming)
        let start = std::time::Instant::now();
        for _ in 0..100000 {
            let _ = q1_b.distance(&q2_b);
        }
        let binary_duration = start.elapsed();

        println!("Scalar distance: {:?} for 100k ops", scalar_duration);
        println!("Binary distance: {:?} for 100k ops", binary_duration);

        // Binary should be faster (just XOR and popcount)
        // But both should be fast
        assert!(scalar_duration.as_millis() < 1000);
        assert!(binary_duration.as_millis() < 1000);
    }
}

552
crates/ruvector-core/tests/test_simd_correctness.rs
Normal file
@@ -0,0 +1,552 @@
//! SIMD Correctness Tests
//!
//! This module verifies that SIMD implementations produce identical results
//! to scalar fallback implementations across various input sizes and edge cases.

use ruvector_core::simd_intrinsics::*;

// ============================================================================
// Helper Functions for Scalar Computations (Ground Truth)
// ============================================================================

fn scalar_euclidean(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}

fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn scalar_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a > f32::EPSILON && norm_b > f32::EPSILON {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}

fn scalar_manhattan(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
}

// ============================================================================
// Euclidean Distance Tests
// ============================================================================

#[test]
fn test_euclidean_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let simd_result = euclidean_distance_simd(&a, &b);
    let scalar_result = scalar_euclidean(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-5,
        "Euclidean mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

#[test]
fn test_euclidean_simd_vs_scalar_exact_simd_width() {
    // Test with exact AVX2 width (8 floats)
    let a: Vec<f32> = (0..8).map(|i| i as f32).collect();
    let b: Vec<f32> = (0..8).map(|i| (i + 1) as f32).collect();

    let simd_result = euclidean_distance_simd(&a, &b);
    let scalar_result = scalar_euclidean(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-5,
        "8-element Euclidean mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

#[test]
fn test_euclidean_simd_vs_scalar_non_aligned() {
    // Test with non-SIMD-aligned sizes
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65, 127, 129] {
        let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();

        let simd_result = euclidean_distance_simd(&a, &b);
        let scalar_result = scalar_euclidean(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 0.01,
            "Size {} Euclidean mismatch: SIMD={}, scalar={}",
            size,
            simd_result,
            scalar_result
        );
    }
}
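
// Non-multiples of 8 exercise the SIMD tail path: a typical kernel handles
// floor(n / 8) * 8 elements with 8-wide vector ops and finishes the remainder
// with scalar code. A portable sketch of that split (illustrative only, not
// the intrinsics implementation under test):
#[allow(dead_code)]
fn euclidean_chunked_sketch(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let main_len = a.len() / 8 * 8;
    let mut sum = 0.0f32;
    for (ca, cb) in a[..main_len]
        .chunks_exact(8)
        .zip(b[..main_len].chunks_exact(8))
    {
        // These 8 lanes would be one vector op; written scalar here.
        for (x, y) in ca.iter().zip(cb.iter()) {
            let d = x - y;
            sum += d * d;
        }
    }
    // Scalar tail for the leftover 0..=7 elements.
    for (x, y) in a[main_len..].iter().zip(b[main_len..].iter()) {
        let d = x - y;
        sum += d * d;
    }
    sum.sqrt()
}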

#[test]
fn test_euclidean_simd_vs_scalar_common_embedding_sizes() {
    // Test common embedding dimensions
    for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
        let a: Vec<f32> = (0..dim).map(|i| ((i % 100) as f32) * 0.01).collect();
        let b: Vec<f32> = (0..dim).map(|i| (((i + 50) % 100) as f32) * 0.01).collect();

        let simd_result = euclidean_distance_simd(&a, &b);
        let scalar_result = scalar_euclidean(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 0.1,
            "Dim {} Euclidean mismatch: SIMD={}, scalar={}",
            dim,
            simd_result,
            scalar_result
        );
    }
}

#[test]
fn test_euclidean_simd_identical_vectors() {
    let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = euclidean_distance_simd(&v, &v);
    assert!(
        result < 1e-6,
        "Distance to self should be ~0, got {}",
        result
    );
}

#[test]
fn test_euclidean_simd_zero_vectors() {
    let zeros = vec![0.0; 16];
    let result = euclidean_distance_simd(&zeros, &zeros);
    assert!(result < 1e-6, "Distance between zeros should be 0");
}

#[test]
fn test_euclidean_simd_negative_values() {
    let a = vec![-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0];
    let b = vec![-5.0, -6.0, -7.0, -8.0, -9.0, -10.0, -11.0, -12.0];

    let simd_result = euclidean_distance_simd(&a, &b);
    let scalar_result = scalar_euclidean(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-5,
        "Negative values Euclidean mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

#[test]
fn test_euclidean_simd_mixed_signs() {
    let a = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
    let b = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0];

    let simd_result = euclidean_distance_simd(&a, &b);
    let scalar_result = scalar_euclidean(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-4,
        "Mixed signs Euclidean mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

// ============================================================================
// Dot Product Tests
// ============================================================================

#[test]
fn test_dot_product_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let simd_result = dot_product_simd(&a, &b);
    let scalar_result = scalar_dot_product(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-4,
        "Dot product mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

#[test]
fn test_dot_product_simd_vs_scalar_exact_simd_width() {
    let a: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let b: Vec<f32> = (1..=8).map(|i| i as f32).collect();

    let simd_result = dot_product_simd(&a, &b);
    let scalar_result = scalar_dot_product(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-4,
        "8-element dot product mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

#[test]
fn test_dot_product_simd_vs_scalar_non_aligned() {
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65, 127, 129] {
        let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();

        let simd_result = dot_product_simd(&a, &b);
        let scalar_result = scalar_dot_product(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 0.1,
            "Size {} dot product mismatch: SIMD={}, scalar={}",
            size,
            simd_result,
            scalar_result
        );
    }
}

#[test]
fn test_dot_product_simd_common_embedding_sizes() {
    for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
        let a: Vec<f32> = (0..dim).map(|i| ((i % 10) as f32) * 0.1).collect();
        let b: Vec<f32> = (0..dim).map(|i| (((i + 5) % 10) as f32) * 0.1).collect();

        let simd_result = dot_product_simd(&a, &b);
        let scalar_result = scalar_dot_product(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 0.5,
            "Dim {} dot product mismatch: SIMD={}, scalar={}",
            dim,
            simd_result,
            scalar_result
        );
    }
}

#[test]
fn test_dot_product_simd_orthogonal_vectors() {
    let a = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    let b = vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];

    let result = dot_product_simd(&a, &b);
    assert!(result.abs() < 1e-6, "Orthogonal dot product should be 0");
}

// ============================================================================
// Cosine Similarity Tests
// ============================================================================

#[test]
fn test_cosine_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let simd_result = cosine_similarity_simd(&a, &b);
    let scalar_result = scalar_cosine_similarity(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-4,
        "Cosine mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

#[test]
fn test_cosine_simd_vs_scalar_non_aligned() {
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65] {
        let a: Vec<f32> = (1..=size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (1..=size).map(|i| (i as f32) * 0.2).collect();

        let simd_result = cosine_similarity_simd(&a, &b);
        let scalar_result = scalar_cosine_similarity(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 0.01,
            "Size {} cosine mismatch: SIMD={}, scalar={}",
            size,
            simd_result,
            scalar_result
        );
    }
}

#[test]
fn test_cosine_simd_identical_vectors() {
    let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = cosine_similarity_simd(&v, &v);
    assert!(
        (result - 1.0).abs() < 1e-5,
        "Identical vectors should have similarity 1.0, got {}",
        result
    );
}

#[test]
fn test_cosine_simd_opposite_vectors() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![-1.0, -2.0, -3.0, -4.0];

    let result = cosine_similarity_simd(&a, &b);
    assert!(
        (result + 1.0).abs() < 1e-5,
        "Opposite vectors should have similarity -1.0, got {}",
        result
    );
}

#[test]
fn test_cosine_simd_orthogonal_vectors() {
    let a = vec![1.0, 0.0, 0.0, 0.0];
    let b = vec![0.0, 1.0, 0.0, 0.0];

    let result = cosine_similarity_simd(&a, &b);
    assert!(
        result.abs() < 1e-5,
        "Orthogonal vectors should have similarity 0, got {}",
        result
    );
}

// ============================================================================
// Manhattan Distance Tests
// ============================================================================

#[test]
fn test_manhattan_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let simd_result = manhattan_distance_simd(&a, &b);
    let scalar_result = scalar_manhattan(&a, &b);

    assert!(
        (simd_result - scalar_result).abs() < 1e-4,
        "Manhattan mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}

#[test]
fn test_manhattan_simd_vs_scalar_non_aligned() {
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65] {
        let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();

        let simd_result = manhattan_distance_simd(&a, &b);
        let scalar_result = scalar_manhattan(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 0.01,
            "Size {} Manhattan mismatch: SIMD={}, scalar={}",
            size,
            simd_result,
            scalar_result
        );
    }
}

#[test]
fn test_manhattan_simd_identical_vectors() {
    let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = manhattan_distance_simd(&v, &v);
    assert!(
        result < 1e-6,
        "Manhattan to self should be 0, got {}",
        result
    );
}

// ============================================================================
// Numerical Stability Tests
// ============================================================================

#[test]
fn test_simd_large_values() {
    // Test with large but finite values
    let large_val = 1e10;
    let a: Vec<f32> = (0..16).map(|i| large_val + (i as f32)).collect();
    let b: Vec<f32> = (0..16).map(|i| large_val + (i as f32) + 1.0).collect();

    let simd_result = euclidean_distance_simd(&a, &b);
    let scalar_result = scalar_euclidean(&a, &b);

    assert!(
        simd_result.is_finite() && scalar_result.is_finite(),
        "Results should be finite for large values"
    );
    assert!(
        (simd_result - scalar_result).abs() < 0.1,
        "Large values mismatch: SIMD={}, scalar={}",
        simd_result,
        scalar_result
    );
}
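
// SIMD and scalar reductions associate differently (eight partial lane sums
// versus one running total), so bit-exact equality is not guaranteed in
// floating point; the tolerances used throughout these tests allow for that
// reordering.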
|
||||
|
||||
#[test]
|
||||
fn test_simd_small_values() {
|
||||
// Test with small values
|
||||
let small_val = 1e-10;
|
||||
let a: Vec<f32> = (0..16).map(|i| small_val * (i as f32 + 1.0)).collect();
|
||||
let b: Vec<f32> = (0..16).map(|i| small_val * (i as f32 + 2.0)).collect();
|
||||
|
||||
let simd_result = euclidean_distance_simd(&a, &b);
|
||||
let scalar_result = scalar_euclidean(&a, &b);
|
||||
|
||||
assert!(
|
||||
simd_result.is_finite() && scalar_result.is_finite(),
|
||||
"Results should be finite for small values"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simd_denormalized_values() {
|
||||
// Test with denormalized floats
|
||||
let a = vec![f32::MIN_POSITIVE; 8];
|
||||
let b = vec![f32::MIN_POSITIVE * 2.0; 8];
|
||||
|
||||
let simd_result = euclidean_distance_simd(&a, &b);
|
||||
let scalar_result = scalar_euclidean(&a, &b);
|
||||
|
||||
assert!(
|
||||
simd_result.is_finite() && scalar_result.is_finite(),
|
||||
"Results should be finite for denormalized values"
|
||||
);
|
||||
}
|
||||
|
||||

// ============================================================================
// Legacy Alias Tests
// ============================================================================
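
// A sketch of the aliasing pattern the test below verifies, assuming the
// legacy `*_avx2` names simply forward to the runtime-dispatched SIMD entry
// points (hypothetical shape; the real definitions live in ruvector-core):
//
// #[inline]
// pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
//     euclidean_distance_simd(a, b)
// }
//
// Exact `assert_eq!` on floats is safe here precisely because both names are
// expected to run the identical code path.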

#[test]
fn test_legacy_avx2_aliases_match_simd() {
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let b = vec![9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0];

    // Legacy AVX2 functions should produce the same results as the SIMD functions
    assert_eq!(
        euclidean_distance_avx2(&a, &b),
        euclidean_distance_simd(&a, &b)
    );
    assert_eq!(dot_product_avx2(&a, &b), dot_product_simd(&a, &b));
    assert_eq!(
        cosine_similarity_avx2(&a, &b),
        cosine_similarity_simd(&a, &b)
    );
}

// ============================================================================
// Batch Operation Tests
// ============================================================================
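
// The consistency check below maps the single-pair kernel over many vectors.
// A hypothetical batch helper with the same contract (illustrative only, not
// an API of ruvector-core) would be exactly that map:
#[allow(dead_code)]
fn batch_euclidean_sketch(query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
    vectors
        .iter()
        .map(|v| euclidean_distance_simd(query, v))
        .collect()
}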

#[test]
fn test_simd_batch_consistency() {
    let query: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
    let vectors: Vec<Vec<f32>> = (0..100)
        .map(|j| (0..64).map(|i| ((i + j) as f32) * 0.1).collect())
        .collect();

    // Compute distances using SIMD
    let simd_distances: Vec<f32> = vectors
        .iter()
        .map(|v| euclidean_distance_simd(&query, v))
        .collect();

    // Compute distances using scalar
    let scalar_distances: Vec<f32> = vectors
        .iter()
        .map(|v| scalar_euclidean(&query, v))
        .collect();

    // Compare
    for (i, (simd, scalar)) in simd_distances
        .iter()
        .zip(scalar_distances.iter())
        .enumerate()
    {
        assert!(
            (simd - scalar).abs() < 0.01,
            "Vector {} mismatch: SIMD={}, scalar={}",
            i,
            simd,
            scalar
        );
    }
}

// ============================================================================
// Edge Case Tests
// ============================================================================
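
// Lengths 1 and 2 matter because wide kernels typically process full 8-lane
// chunks and finish the remainder in scalar code, so tiny inputs exercise
// only that tail. A self-contained scalar sketch of the chunk/tail split
// (illustrative; the real kernel lives in ruvector-core):
#[allow(dead_code)]
fn euclidean_chunked_sketch(a: &[f32], b: &[f32]) -> f32 {
    let full = a.len() - a.len() % 8; // elements covered by 8-wide chunks
    let head: f32 = a[..full]
        .iter()
        .zip(&b[..full])
        .map(|(x, y)| (x - y) * (x - y))
        .sum(); // vectorized in the real kernel
    let tail: f32 = a[full..]
        .iter()
        .zip(&b[full..])
        .map(|(x, y)| (x - y) * (x - y))
        .sum(); // scalar remainder: the only path hit by lengths 1 and 2
    (head + tail).sqrt()
}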

#[test]
fn test_simd_single_element() {
    let a = vec![1.0];
    let b = vec![2.0];

    let euclidean = euclidean_distance_simd(&a, &b);
    let dot = dot_product_simd(&a, &b);
    let manhattan = manhattan_distance_simd(&a, &b);

    assert!((euclidean - 1.0).abs() < 1e-6);
    assert!((dot - 2.0).abs() < 1e-6);
    assert!((manhattan - 1.0).abs() < 1e-6);
}

#[test]
fn test_simd_two_elements() {
    let a = vec![1.0, 0.0];
    let b = vec![0.0, 1.0];

    let euclidean = euclidean_distance_simd(&a, &b);
    let expected = (2.0_f32).sqrt(); // sqrt(1 + 1)

    assert!(
        (euclidean - expected).abs() < 1e-5,
        "Two element test: got {}, expected {}",
        euclidean,
        expected
    );
}

// ============================================================================
// Stress Tests for SIMD
// ============================================================================

#[test]
fn test_simd_many_operations() {
    let a: Vec<f32> = (0..512).map(|i| (i as f32) * 0.001).collect();
    let b: Vec<f32> = (0..512).map(|i| ((i + 256) as f32) * 0.001).collect();

    // Perform many operations to stress-test the kernels
    for _ in 0..1000 {
        let _ = euclidean_distance_simd(&a, &b);
        let _ = dot_product_simd(&a, &b);
        let _ = cosine_similarity_simd(&a, &b);
        let _ = manhattan_distance_simd(&a, &b);
    }

    // Final verification
    let result = euclidean_distance_simd(&a, &b);
    assert!(
        result.is_finite(),
        "Result should be finite after stress test"
    );
}
555
crates/ruvector-core/tests/unit_tests.rs
Normal file
@@ -0,0 +1,555 @@
//! Unit tests with mocking using mockall (London School TDD)
//!
//! These tests use mocks to isolate components so their behavior can be
//! verified independently of real collaborators.

use mockall::mock;
use mockall::predicate::*;
use ruvector_core::error::{Result, RuvectorError};
use ruvector_core::types::*;
use std::collections::HashMap;

// ============================================================================
// Mock Definitions
// ============================================================================

// Mock for storage operations
mock! {
    pub Storage {
        fn insert(&self, entry: &VectorEntry) -> Result<VectorId>;
        fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>>;
        fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
        fn delete(&self, id: &str) -> Result<bool>;
        fn len(&self) -> Result<usize>;
        fn is_empty(&self) -> Result<bool>;
        fn all_ids(&self) -> Result<Vec<VectorId>>;
    }
}

// Mock for index operations
mock! {
    pub Index {
        fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()>;
        fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()>;
        fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
        fn remove(&mut self, id: &VectorId) -> Result<bool>;
        fn len(&self) -> usize;
        fn is_empty(&self) -> bool;
    }
}
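
// Illustrative sketch (not one of the original tests) of driving a generated
// mock in the London style: set an expectation, exercise it, and mockall
// verifies the interaction when the mock is dropped. MockStorage is the type
// generated by the mock! block above.
#[test]
fn sketch_mock_storage_len_expectation() {
    let mut storage = MockStorage::new();
    storage.expect_len().times(1).returning(|| Ok(0));
    assert_eq!(storage.len().unwrap(), 0);
}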

// ============================================================================
// Distance Metric Tests
// ============================================================================

#[cfg(test)]
mod distance_tests {
    use super::*;
    use ruvector_core::distance::*;

    #[test]
    fn test_euclidean_same_vector() {
        let v = vec![1.0, 2.0, 3.0];
        let dist = euclidean_distance(&v, &v);
        assert!(dist < 0.001, "Distance to self should be ~0");
    }

    #[test]
    fn test_euclidean_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let dist = euclidean_distance(&a, &b);
        assert!(
            (dist - 1.414).abs() < 0.01,
            "Expected sqrt(2), got {}",
            dist
        );
    }

    #[test]
    fn test_cosine_parallel_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![2.0, 4.0, 6.0]; // Parallel to a
        let dist = cosine_distance(&a, &b);
        assert!(
            dist < 0.01,
            "Parallel vectors should have ~0 cosine distance, got {}",
            dist
        );
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let dist = cosine_distance(&a, &b);
        assert!(
            dist > 0.9 && dist < 1.1,
            "Orthogonal vectors should have distance ~1, got {}",
            dist
        );
    }
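
    // Worked check of the convention the two cosine tests pin down (assuming
    // cosine_distance(a, b) = 1 - cos(theta)): parallel vectors give
    // cos(theta) = 1, hence distance ~0; orthogonal vectors give
    // cos(theta) = 0, hence distance ~1.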

    #[test]
    fn test_dot_product_positive() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let dist = dot_product_distance(&a, &b);
        assert!(
            dist < 0.0,
            "Dot product distance should be negative for similar vectors"
        );
    }

    #[test]
    fn test_manhattan_distance() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 1.0, 1.0];
        let dist = manhattan_distance(&a, &b);
        assert!(
            (dist - 3.0).abs() < 0.001,
            "Manhattan distance should be 3.0, got {}",
            dist
        );
    }

    #[test]
    fn test_dimension_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        let result = distance(&a, &b, DistanceMetric::Euclidean);
        assert!(result.is_err(), "Should error on dimension mismatch");

        match result {
            Err(RuvectorError::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 2);
                assert_eq!(actual, 3);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }
}

// ============================================================================
// Quantization Tests
// ============================================================================

#[cfg(test)]
mod quantization_tests {
    use super::*;
    use ruvector_core::quantization::*;

    #[test]
    fn test_scalar_quantization_reconstruction() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = ScalarQuantized::quantize(&original);
        let reconstructed = quantized.reconstruct();

        assert_eq!(original.len(), reconstructed.len());

        for (orig, recon) in original.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            assert!(error < 0.1, "Reconstruction error {} too large", error);
        }
    }
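
    // Why the 0.1 bound above is comfortable (illustrative arithmetic,
    // assuming the usual min/max-to-u8 scheme the name ScalarQuantized
    // suggests): the values span [1.0, 5.0], so one quantization step is
    // (5.0 - 1.0) / 255 ~ 0.0157, and the worst-case round-trip error is
    // about half a step, roughly 0.008.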

    #[test]
    fn test_scalar_quantization_distance() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 2.1, 3.1];

        let qa = ScalarQuantized::quantize(&a);
        let qb = ScalarQuantized::quantize(&b);

        let quantized_dist = qa.distance(&qb);
        assert!(quantized_dist >= 0.0, "Distance should be non-negative");
    }

    #[test]
    fn test_binary_quantization_sign_preservation() {
        let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5, -0.5];
        let quantized = BinaryQuantized::quantize(&vector);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
            assert_eq!(orig.signum(), *recon, "Sign should be preserved");
        }
    }

    #[test]
    fn test_binary_quantization_hamming() {
        let a = vec![1.0, 1.0, 1.0, 1.0];
        let b = vec![1.0, 1.0, -1.0, -1.0];

        let qa = BinaryQuantized::quantize(&a);
        let qb = BinaryQuantized::quantize(&b);

        let dist = qa.distance(&qb);
        assert_eq!(dist, 2.0, "Hamming distance should be 2.0");
    }
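
    // The binary scheme presumably packs one sign bit per dimension, so the
    // distance reduces to popcount(xor) over the packed words: a and b above
    // disagree in exactly two sign positions, hence the exact 2.0.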

    #[test]
    fn test_binary_quantization_zero_distance() {
        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let quantized = BinaryQuantized::quantize(&vector);

        let dist = quantized.distance(&quantized);
        assert_eq!(dist, 0.0, "Distance to self should be 0");
    }
}

// ============================================================================
// Storage Layer Tests
// ============================================================================

#[cfg(test)]
mod storage_tests {
    use super::*;
    use ruvector_core::storage::VectorStorage;
    use tempfile::tempdir;

    #[test]
    fn test_insert_with_explicit_id() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("explicit_id".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert_eq!(id, "explicit_id");

        Ok(())
    }

    #[test]
    fn test_insert_auto_generates_id() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: None,
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert!(!id.is_empty(), "Should generate a UUID");
        assert!(id.contains('-'), "Should be a valid UUID format");

        Ok(())
    }

    #[test]
    fn test_insert_with_metadata() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), serde_json::json!("value"));

        let entry = VectorEntry {
            id: Some("meta_test".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: Some(metadata.clone()),
        };

        storage.insert(&entry)?;
        let retrieved = storage.get("meta_test")?.unwrap();

        assert_eq!(retrieved.metadata, Some(metadata));

        Ok(())
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3).unwrap();

        let entry = VectorEntry {
            id: None,
            vector: vec![1.0, 2.0], // Wrong dimension
            metadata: None,
        };

        let result = storage.insert(&entry);
        assert!(result.is_err());

        match result {
            Err(RuvectorError::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 3);
                assert_eq!(actual, 2);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    #[test]
    fn test_get_nonexistent() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let result = storage.get("nonexistent")?;
        assert!(result.is_none());

        Ok(())
    }

    #[test]
    fn test_delete_nonexistent() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let deleted = storage.delete("nonexistent")?;
        assert!(!deleted);

        Ok(())
    }

    #[test]
    fn test_batch_insert_empty() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let ids = storage.insert_batch(&[])?;
        assert_eq!(ids.len(), 0);

        Ok(())
    }

    #[test]
    fn test_batch_insert_dimension_mismatch() {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3).unwrap();

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0], // Wrong dimension
                metadata: None,
            },
        ];

        let result = storage.insert_batch(&entries);
        assert!(
            result.is_err(),
            "Should error on dimension mismatch in batch"
        );
    }

    #[test]
    fn test_all_ids_empty() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let ids = storage.all_ids()?;
        assert_eq!(ids.len(), 0);

        Ok(())
    }

    #[test]
    fn test_all_ids_after_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        storage.insert(&VectorEntry {
            id: Some("id1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        storage.insert(&VectorEntry {
            id: Some("id2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        let ids = storage.all_ids()?;
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"id1".to_string()));
        assert!(ids.contains(&"id2".to_string()));

        Ok(())
    }
}

// ============================================================================
// VectorDB Tests (High-level API)
// ============================================================================

#[cfg(test)]
mod vector_db_tests {
    use super::*;
    use ruvector_core::VectorDB;
    use tempfile::tempdir;

    #[test]
    fn test_empty_database() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;

        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);
        assert_eq!(db.len()?, 0);

        Ok(())
    }
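
    // The DbOptions boilerplate repeats in every test below; a small helper
    // along these lines (hypothetical, shown for illustration only) would
    // keep each test to one line of setup:
    //
    // fn test_options(dir: &std::path::Path, dims: usize) -> DbOptions {
    //     let mut options = DbOptions::default();
    //     options.storage_path = dir.join("test.db").to_string_lossy().to_string();
    //     options.dimensions = dims;
    //     options
    // }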

    #[test]
    fn test_insert_updates_len() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;

        let db = VectorDB::new(options)?;

        db.insert(VectorEntry {
            id: None,
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        assert_eq!(db.len()?, 1);
        assert!(!db.is_empty()?);

        Ok(())
    }

    #[test]
    fn test_delete_updates_len() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;

        let db = VectorDB::new(options)?;

        let id = db.insert(VectorEntry {
            id: Some("test_id".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        assert_eq!(db.len()?, 1);

        let deleted = db.delete(&id)?;
        assert!(deleted);
        assert_eq!(db.len()?, 0);

        Ok(())
    }

    #[test]
    fn test_search_empty_database() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;
        options.hnsw_config = None; // Use flat index

        let db = VectorDB::new(options)?;

        let results = db.search(SearchQuery {
            vector: vec![1.0, 2.0, 3.0],
            k: 10,
            filter: None,
            ef_search: None,
        })?;

        assert_eq!(results.len(), 0);

        Ok(())
    }

    #[test]
    fn test_search_with_filter() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;
        options.hnsw_config = None;

        let db = VectorDB::new(options)?;

        // Insert vectors with metadata
        let mut meta1 = HashMap::new();
        meta1.insert("category".to_string(), serde_json::json!("A"));

        let mut meta2 = HashMap::new();
        meta2.insert("category".to_string(), serde_json::json!("B"));

        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: Some(meta1),
        })?;

        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.9, 0.1, 0.0],
            metadata: Some(meta2),
        })?;

        // Search with filter
        let mut filter = HashMap::new();
        filter.insert("category".to_string(), serde_json::json!("A"));

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 10,
            filter: Some(filter),
            ef_search: None,
        })?;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "v1");

        Ok(())
    }

    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;

        let db = VectorDB::new(options)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 0.0, 0.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![0.0, 1.0, 0.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![0.0, 0.0, 1.0],
                metadata: None,
            },
        ];

        let ids = db.insert_batch(entries)?;
        assert_eq!(ids.len(), 3);
        assert_eq!(db.len()?, 3);

        Ok(())
    }
}