Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
101
crates/ruvector-core/Cargo.toml
Normal file
101
crates/ruvector-core/Cargo.toml
Normal file
@@ -0,0 +1,101 @@
|
||||
[package]
|
||||
name = "ruvector-core"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
repository.workspace = true
|
||||
readme = "README.md"
|
||||
description = "High-performance Rust vector database core with HNSW indexing"
|
||||
|
||||
[dependencies]
|
||||
# Core functionality
|
||||
redb = { workspace = true, optional = true }
|
||||
memmap2 = { workspace = true, optional = true }
|
||||
hnsw_rs = { workspace = true, optional = true }
|
||||
simsimd = { workspace = true, optional = true }
|
||||
rayon = { workspace = true, optional = true }
|
||||
crossbeam = { workspace = true, optional = true }
|
||||
|
||||
# Serialization
|
||||
rkyv = { workspace = true }
|
||||
bincode = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
# Error handling
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
# Math and numerics
|
||||
ndarray = { workspace = true, features = ["serde"] }
|
||||
rand = { workspace = true }
|
||||
rand_distr = { workspace = true }
|
||||
|
||||
# Performance
|
||||
dashmap = { workspace = true }
|
||||
parking_lot = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
|
||||
# Time and UUID
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4"] }
|
||||
|
||||
# HTTP client for API embeddings (not available in WASM)
|
||||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { workspace = true }
|
||||
proptest = { workspace = true }
|
||||
mockall = { workspace = true }
|
||||
tempfile = "3.13"
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
name = "distance_metrics"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "hnsw_search"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "quantization_bench"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "batch_operations"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "comprehensive_bench"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "real_benchmark"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_simd"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "bench_memory"
|
||||
harness = false
|
||||
|
||||
[features]
|
||||
default = ["simd", "storage", "hnsw", "api-embeddings", "parallel"]
|
||||
simd = ["simsimd"] # SIMD acceleration (not available in WASM)
|
||||
parallel = ["rayon", "crossbeam"] # Parallel processing (not available in WASM)
|
||||
storage = ["redb", "memmap2"] # File-based storage (not available in WASM)
|
||||
hnsw = ["hnsw_rs"] # HNSW indexing (not available in WASM due to mmap dependency)
|
||||
memory-only = [] # Pure in-memory storage for WASM
|
||||
uuid-support = [] # Deprecated: uuid is now always included
|
||||
real-embeddings = [] # Feature flag for embedding provider API (use ApiEmbedding for production)
|
||||
api-embeddings = ["reqwest"] # API-based embeddings (not available in WASM)
|
||||
|
||||
[lib]
|
||||
crate-type = ["rlib"]
|
||||
bench = false
|
||||
471
crates/ruvector-core/README.md
Normal file
471
crates/ruvector-core/README.md
Normal file
@@ -0,0 +1,471 @@
|
||||
# Ruvector Core
|
||||
|
||||
[](https://crates.io/crates/ruvector-core)
|
||||
[](https://docs.rs/ruvector-core)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.rust-lang.org)
|
||||
|
||||
**The pure-Rust vector database engine behind RuVector -- HNSW indexing, quantization, and SIMD acceleration in a single crate.**
|
||||
|
||||
`ruvector-core` is the foundational library that powers the entire [RuVector](https://github.com/ruvnet/ruvector) ecosystem. It gives you a production-grade vector database you can embed directly into any Rust application: insert vectors, search them in under a millisecond, filter by metadata, and compress storage up to 32x -- all without external services. If you need vector search as a library instead of a server, this is the crate.
|
||||
|
||||
| | ruvector-core | Typical Vector Database |
|
||||
|---|---|---|
|
||||
| **Deployment** | Embed as a Rust dependency -- no server, no network calls | Run a separate service, manage connections |
|
||||
| **Query latency** | <0.5 ms p50 at 1M vectors with HNSW | ~1-5 ms depending on network and index |
|
||||
| **Memory compression** | Scalar (4x), Product (8-32x), Binary (32x) quantization built in | Often requires paid tiers or external tools |
|
||||
| **SIMD acceleration** | SimSIMD hardware-optimized distance calculations, automatic | Manual tuning or not available |
|
||||
| **Search modes** | Dense vectors, sparse BM25, hybrid, MMR diversity, filtered -- all in one API | Typically dense-only; hybrid and filtering are add-ons |
|
||||
| **Storage** | Zero-copy mmap with `redb` -- instant loading, no deserialization | Load time scales with dataset size |
|
||||
| **Concurrency** | Lock-free indexing with parallel batch processing via Rayon | Varies; many require single-writer locks |
|
||||
| **Dependencies** | Minimal -- pure Rust, compiles anywhere `rustc` runs | Often depends on C/C++ libraries (BLAS, LAPACK) |
|
||||
| **Cost** | Free forever -- open source (MIT) | Per-vector or per-query pricing on managed tiers |
|
||||
|
||||
## Installation
|
||||
|
||||
Add `ruvector-core` to your `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
ruvector-core = "0.1.0"
|
||||
```
|
||||
|
||||
### Feature Flags
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
ruvector-core = { version = "0.1.0", features = ["simd", "uuid-support"] }
|
||||
```
|
||||
|
||||
Available features:
|
||||
- `simd` (default): Enable SIMD-optimized distance calculations
|
||||
- `uuid-support` (default): Enable UUID generation for vector IDs
|
||||
|
||||
## Key Features
|
||||
|
||||
| Feature | What It Does | Why It Matters |
|
||||
|---------|-------------|----------------|
|
||||
| **HNSW Indexing** | Hierarchical Navigable Small World graphs for O(log n) approximate nearest neighbor search | Sub-millisecond queries at million-vector scale |
|
||||
| **Multiple Distance Metrics** | Euclidean, Cosine, Dot Product, Manhattan | Match the metric to your embedding model without conversion |
|
||||
| **Scalar Quantization** | Compress vectors to 8-bit integers (4x reduction) | Cut memory by 75% with 98% recall preserved |
|
||||
| **Product Quantization** | Split vectors into subspaces with codebooks (8-32x reduction) | Store millions of vectors on a single machine |
|
||||
| **Binary Quantization** | 1-bit representation (32x reduction) | Ultra-fast screening pass for massive datasets |
|
||||
| **SIMD Distance** | Hardware-accelerated distance via SimSIMD | Up to 80K QPS on 8 cores without code changes |
|
||||
| **Zero-Copy I/O** | Memory-mapped storage loads instantly | No deserialization step -- open a file and search immediately |
|
||||
| **Hybrid Search** | Combine dense vector similarity with sparse BM25 text scoring | One query handles both semantic and keyword matching |
|
||||
| **Metadata Filtering** | Apply key-value filters during search | No post-filtering needed -- results are already filtered |
|
||||
| **MMR Diversification** | Maximal Marginal Relevance re-ranking | Avoid redundant results when top-K are too similar |
|
||||
| **Conformal Prediction** | Uncertainty quantification on search results | Know when to trust (or distrust) a match |
|
||||
| **Lock-Free Indexing** | Concurrent reads and writes without blocking | High-throughput ingestion while serving queries |
|
||||
| **Batch Processing** | Parallel insert and search via Rayon | Saturate all cores for bulk operations |
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```rust
|
||||
use ruvector_core::{VectorDB, DbOptions, VectorEntry, SearchQuery, DistanceMetric};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Create a new vector database
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 384; // Vector dimensions
|
||||
options.storage_path = "./my_vectors.db".to_string();
|
||||
options.distance_metric = DistanceMetric::Cosine;
|
||||
|
||||
let db = VectorDB::new(options)?;
|
||||
|
||||
// Insert vectors
|
||||
db.insert(VectorEntry {
|
||||
id: Some("doc1".to_string()),
|
||||
vector: vec![0.1, 0.2, 0.3, /* ... 384 dimensions */],
|
||||
metadata: None,
|
||||
})?;
|
||||
|
||||
db.insert(VectorEntry {
|
||||
id: Some("doc2".to_string()),
|
||||
vector: vec![0.4, 0.5, 0.6, /* ... 384 dimensions */],
|
||||
metadata: None,
|
||||
})?;
|
||||
|
||||
// Search for similar vectors
|
||||
let results = db.search(SearchQuery {
|
||||
vector: vec![0.1, 0.2, 0.3, /* ... 384 dimensions */],
|
||||
k: 10, // Return top 10 results
|
||||
filter: None,
|
||||
ef_search: None,
|
||||
})?;
|
||||
|
||||
for result in results {
|
||||
println!("ID: {}, Score: {}", result.id, result.score);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
### Batch Operations
|
||||
|
||||
```rust
|
||||
use ruvector_core::{VectorDB, VectorEntry};
|
||||
|
||||
// Insert multiple vectors efficiently
|
||||
let entries = vec![
|
||||
VectorEntry {
|
||||
id: Some("doc1".to_string()),
|
||||
vector: vec![0.1, 0.2, 0.3],
|
||||
metadata: None,
|
||||
},
|
||||
VectorEntry {
|
||||
id: Some("doc2".to_string()),
|
||||
vector: vec![0.4, 0.5, 0.6],
|
||||
metadata: None,
|
||||
},
|
||||
];
|
||||
|
||||
let ids = db.insert_batch(entries)?;
|
||||
println!("Inserted {} vectors", ids.len());
|
||||
```
|
||||
|
||||
### With Metadata Filtering
|
||||
|
||||
```rust
|
||||
use std::collections::HashMap;
|
||||
use serde_json::json;
|
||||
|
||||
// Insert with metadata
|
||||
db.insert(VectorEntry {
|
||||
id: Some("product1".to_string()),
|
||||
vector: vec![0.1, 0.2, 0.3],
|
||||
metadata: Some(HashMap::from([
|
||||
("category".to_string(), json!("electronics")),
|
||||
("price".to_string(), json!(299.99)),
|
||||
])),
|
||||
})?;
|
||||
|
||||
// Search with metadata filter
|
||||
let results = db.search(SearchQuery {
|
||||
vector: vec![0.1, 0.2, 0.3],
|
||||
k: 10,
|
||||
filter: Some(HashMap::from([
|
||||
("category".to_string(), json!("electronics")),
|
||||
])),
|
||||
ef_search: None,
|
||||
})?;
|
||||
```
|
||||
|
||||
### HNSW Configuration
|
||||
|
||||
```rust
|
||||
use ruvector_core::{DbOptions, HnswConfig, DistanceMetric};
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 384;
|
||||
options.distance_metric = DistanceMetric::Cosine;
|
||||
|
||||
// Configure HNSW index parameters
|
||||
options.hnsw_config = Some(HnswConfig {
|
||||
m: 32, // Connections per layer (16-64 typical)
|
||||
ef_construction: 200, // Build-time accuracy (100-500 typical)
|
||||
ef_search: 100, // Search-time accuracy (50-200 typical)
|
||||
max_elements: 10_000_000, // Maximum vectors
|
||||
});
|
||||
|
||||
let db = VectorDB::new(options)?;
|
||||
```
|
||||
|
||||
### Quantization
|
||||
|
||||
```rust
|
||||
use ruvector_core::{DbOptions, QuantizationConfig};
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 384;
|
||||
|
||||
// Enable scalar quantization (4x compression)
|
||||
options.quantization = Some(QuantizationConfig::Scalar);
|
||||
|
||||
// Or product quantization (8-32x compression)
|
||||
options.quantization = Some(QuantizationConfig::Product {
|
||||
subspaces: 8, // Number of subspaces
|
||||
k: 256, // Codebook size
|
||||
});
|
||||
|
||||
let db = VectorDB::new(options)?;
|
||||
```
|
||||
|
||||
## API Overview
|
||||
|
||||
### Core Types
|
||||
|
||||
```rust
|
||||
// Main database interface
|
||||
pub struct VectorDB { /* ... */ }
|
||||
|
||||
// Vector entry with optional ID and metadata
|
||||
pub struct VectorEntry {
|
||||
pub id: Option<VectorId>,
|
||||
pub vector: Vec<f32>,
|
||||
pub metadata: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
// Search query parameters
|
||||
pub struct SearchQuery {
|
||||
pub vector: Vec<f32>,
|
||||
pub k: usize,
|
||||
pub filter: Option<HashMap<String, serde_json::Value>>,
|
||||
pub ef_search: Option<usize>,
|
||||
}
|
||||
|
||||
// Search result with score
|
||||
pub struct SearchResult {
|
||||
pub id: VectorId,
|
||||
pub score: f32,
|
||||
pub vector: Option<Vec<f32>>,
|
||||
pub metadata: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
```
|
||||
|
||||
### Main Operations
|
||||
|
||||
```rust
|
||||
impl VectorDB {
|
||||
// Create new database with options
|
||||
pub fn new(options: DbOptions) -> Result<Self>;
|
||||
|
||||
// Create with just dimensions (uses defaults)
|
||||
pub fn with_dimensions(dimensions: usize) -> Result<Self>;
|
||||
|
||||
// Insert single vector
|
||||
pub fn insert(&self, entry: VectorEntry) -> Result<VectorId>;
|
||||
|
||||
// Insert multiple vectors
|
||||
pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>>;
|
||||
|
||||
// Search for similar vectors
|
||||
pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>>;
|
||||
|
||||
// Delete vector by ID
|
||||
pub fn delete(&self, id: &str) -> Result<bool>;
|
||||
|
||||
// Get vector by ID
|
||||
pub fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
|
||||
|
||||
// Get total count
|
||||
pub fn len(&self) -> Result<usize>;
|
||||
|
||||
// Check if empty
|
||||
pub fn is_empty(&self) -> Result<bool>;
|
||||
}
|
||||
```
|
||||
|
||||
### Distance Metrics
|
||||
|
||||
```rust
|
||||
pub enum DistanceMetric {
|
||||
Euclidean, // L2 distance - default for embeddings
|
||||
Cosine, // Cosine similarity (1 - similarity)
|
||||
DotProduct, // Negative dot product (for maximization)
|
||||
Manhattan, // L1 distance
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Features
|
||||
|
||||
```rust
|
||||
// Hybrid search (dense + sparse)
|
||||
use ruvector_core::{HybridSearch, HybridConfig};
|
||||
|
||||
let hybrid = HybridSearch::new(HybridConfig {
|
||||
alpha: 0.7, // Balance between dense (0.7) and sparse (0.3)
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Filtered search with expressions
|
||||
use ruvector_core::{FilteredSearch, FilterExpression};
|
||||
|
||||
let filtered = FilteredSearch::new(db);
|
||||
let expr = FilterExpression::And(vec![
|
||||
FilterExpression::Equals("category".to_string(), json!("books")),
|
||||
FilterExpression::GreaterThan("price".to_string(), json!(10.0)),
|
||||
]);
|
||||
|
||||
// MMR diversification
|
||||
use ruvector_core::{MMRSearch, MMRConfig};
|
||||
|
||||
let mmr = MMRSearch::new(MMRConfig {
|
||||
lambda: 0.5, // Balance relevance (0.5) and diversity (0.5)
|
||||
..Default::default()
|
||||
});
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
### Latency (Single Query)
|
||||
|
||||
```
|
||||
Operation Flat Index HNSW Index
|
||||
---------------------------------------------
|
||||
Search (1K vecs) ~0.1ms ~0.2ms
|
||||
Search (100K vecs) ~10ms ~0.5ms
|
||||
Search (1M vecs) ~100ms <1ms
|
||||
Insert ~0.1ms ~1ms
|
||||
Batch (1000) ~50ms ~500ms
|
||||
```
|
||||
|
||||
### Memory Usage (1M Vectors, 384 Dimensions)
|
||||
|
||||
```
|
||||
Configuration Memory Recall
|
||||
---------------------------------------------
|
||||
Full Precision (f32) ~1.5GB 100%
|
||||
Scalar Quantization ~400MB 98%
|
||||
Product Quantization ~200MB 95%
|
||||
Binary Quantization ~50MB 85%
|
||||
```
|
||||
|
||||
### Throughput (Queries Per Second)
|
||||
|
||||
```
|
||||
Configuration QPS Latency (p50)
|
||||
-----------------------------------------------------
|
||||
Single Thread ~2,000 ~0.5ms
|
||||
Multi-Thread (8 cores) ~50,000 <0.5ms
|
||||
With SIMD ~80,000 <0.3ms
|
||||
With Quantization ~100,000 <0.2ms
|
||||
```
|
||||
|
||||
## Configuration Guide
|
||||
|
||||
### For Maximum Accuracy
|
||||
|
||||
```rust
|
||||
let options = DbOptions {
|
||||
dimensions: 384,
|
||||
distance_metric: DistanceMetric::Cosine,
|
||||
hnsw_config: Some(HnswConfig {
|
||||
m: 64,
|
||||
ef_construction: 500,
|
||||
ef_search: 200,
|
||||
max_elements: 10_000_000,
|
||||
}),
|
||||
quantization: None, // Full precision
|
||||
..Default::default()
|
||||
};
|
||||
```
|
||||
|
||||
### For Maximum Speed
|
||||
|
||||
```rust
|
||||
let options = DbOptions {
|
||||
dimensions: 384,
|
||||
distance_metric: DistanceMetric::DotProduct,
|
||||
hnsw_config: Some(HnswConfig {
|
||||
m: 16,
|
||||
ef_construction: 100,
|
||||
ef_search: 50,
|
||||
max_elements: 10_000_000,
|
||||
}),
|
||||
quantization: Some(QuantizationConfig::Binary),
|
||||
..Default::default()
|
||||
};
|
||||
```
|
||||
|
||||
### For Balanced Performance
|
||||
|
||||
```rust
|
||||
let options = DbOptions::default(); // Recommended defaults
|
||||
```
|
||||
|
||||
## Building and Testing
|
||||
|
||||
### Build
|
||||
|
||||
```bash
|
||||
# Build with default features
|
||||
cargo build --release
|
||||
|
||||
# Build without SIMD
|
||||
cargo build --release --no-default-features --features uuid-support
|
||||
|
||||
# Build for specific target with optimizations
|
||||
RUSTFLAGS="-C target-cpu=native" cargo build --release
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
cargo test
|
||||
|
||||
# Run with specific features
|
||||
cargo test --features simd
|
||||
|
||||
# Run with logging
|
||||
RUST_LOG=debug cargo test
|
||||
```
|
||||
|
||||
### Benchmarks
|
||||
|
||||
```bash
|
||||
# Run all benchmarks
|
||||
cargo bench
|
||||
|
||||
# Run specific benchmark
|
||||
cargo bench --bench hnsw_search
|
||||
|
||||
# Run with features
|
||||
cargo bench --features simd
|
||||
```
|
||||
|
||||
Available benchmarks:
|
||||
- `distance_metrics` - SIMD-optimized distance calculations
|
||||
- `hnsw_search` - HNSW index search performance
|
||||
- `quantization_bench` - Quantization techniques
|
||||
- `batch_operations` - Batch insert/search operations
|
||||
- `comprehensive_bench` - Full system benchmarks
|
||||
|
||||
## Related Crates
|
||||
|
||||
`ruvector-core` is the foundation for platform-specific bindings:
|
||||
|
||||
- **[ruvector-node](../ruvector-node/)** - Node.js bindings via NAPI-RS
|
||||
- **[ruvector-wasm](../ruvector-wasm/)** - WebAssembly bindings for browsers
|
||||
- **[ruvector-gnn](../ruvector-gnn/)** - Graph Neural Network layer for learned search
|
||||
- **[ruvector-cli](../ruvector-cli/)** - Command-line interface
|
||||
- **[ruvector-bench](../ruvector-bench/)** - Performance benchmarks
|
||||
|
||||
## Documentation
|
||||
|
||||
- **[Main README](../../README.md)** - Complete project overview
|
||||
- **[Getting Started Guide](../../docs/guide/GETTING_STARTED.md)** - Quick start tutorial
|
||||
- **[Rust API Reference](../../docs/api/RUST_API.md)** - Detailed API documentation
|
||||
- **[Advanced Features Guide](../../docs/guide/ADVANCED_FEATURES.md)** - Quantization, indexing, tuning
|
||||
- **[Performance Tuning](../../docs/optimization/PERFORMANCE_TUNING_GUIDE.md)** - Optimization strategies
|
||||
- **[API Documentation](https://docs.rs/ruvector-core)** - Full API reference on docs.rs
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
Built with state-of-the-art algorithms and libraries:
|
||||
|
||||
- **[hnsw_rs](https://crates.io/crates/hnsw_rs)** - HNSW implementation
|
||||
- **[simsimd](https://crates.io/crates/simsimd)** - SIMD distance calculations
|
||||
- **[redb](https://crates.io/crates/redb)** - Embedded database
|
||||
- **[rayon](https://crates.io/crates/rayon)** - Data parallelism
|
||||
- **[memmap2](https://crates.io/crates/memmap2)** - Memory-mapped files
|
||||
|
||||
## License
|
||||
|
||||
**MIT License** - see [LICENSE](../../LICENSE) for details.
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
|
||||
**Part of [RuVector](https://github.com/ruvnet/ruvector) - Built by [rUv](https://ruv.io)**
|
||||
|
||||
[](https://github.com/ruvnet/ruvector)
|
||||
|
||||
[Documentation](https://docs.rs/ruvector-core) | [Crates.io](https://crates.io/crates/ruvector-core) | [GitHub](https://github.com/ruvnet/ruvector)
|
||||
|
||||
</div>
|
||||
204
crates/ruvector-core/benches/batch_operations.rs
Normal file
204
crates/ruvector-core/benches/batch_operations.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use ruvector_core::types::{DistanceMetric, SearchQuery};
|
||||
use ruvector_core::{DbOptions, VectorDB, VectorEntry};
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn bench_batch_insert(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("batch_insert");
|
||||
|
||||
for batch_size in [100, 1000, 10000].iter() {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(batch_size),
|
||||
batch_size,
|
||||
|bench, &size| {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
// Setup: Create DB and vectors
|
||||
let dir = tempdir().unwrap();
|
||||
let mut options = DbOptions::default();
|
||||
options.storage_path =
|
||||
dir.path().join("bench.db").to_string_lossy().to_string();
|
||||
options.dimensions = 128;
|
||||
options.hnsw_config = None; // Use flat index for faster insertion
|
||||
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
|
||||
let vectors: Vec<VectorEntry> = (0..size)
|
||||
.map(|i| VectorEntry {
|
||||
id: Some(format!("vec_{}", i)),
|
||||
vector: (0..128).map(|j| ((i + j) as f32) * 0.01).collect(),
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
(db, vectors, dir)
|
||||
},
|
||||
|(db, vectors, _dir)| {
|
||||
// Benchmark: Batch insert
|
||||
db.insert_batch(black_box(vectors)).unwrap()
|
||||
},
|
||||
criterion::BatchSize::LargeInput,
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_individual_insert_vs_batch(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("individual_vs_batch_insert");
|
||||
let size = 1000;
|
||||
|
||||
// Individual inserts
|
||||
group.bench_function("individual_1000", |bench| {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
let dir = tempdir().unwrap();
|
||||
let mut options = DbOptions::default();
|
||||
options.storage_path = dir.path().join("bench.db").to_string_lossy().to_string();
|
||||
options.dimensions = 64;
|
||||
options.hnsw_config = None;
|
||||
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
let vectors: Vec<VectorEntry> = (0..size)
|
||||
.map(|i| VectorEntry {
|
||||
id: Some(format!("vec_{}", i)),
|
||||
vector: vec![i as f32; 64],
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
(db, vectors, dir)
|
||||
},
|
||||
|(db, vectors, _dir)| {
|
||||
for vector in vectors {
|
||||
db.insert(black_box(vector)).unwrap();
|
||||
}
|
||||
},
|
||||
criterion::BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
|
||||
// Batch insert
|
||||
group.bench_function("batch_1000", |bench| {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
let dir = tempdir().unwrap();
|
||||
let mut options = DbOptions::default();
|
||||
options.storage_path = dir.path().join("bench.db").to_string_lossy().to_string();
|
||||
options.dimensions = 64;
|
||||
options.hnsw_config = None;
|
||||
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
let vectors: Vec<VectorEntry> = (0..size)
|
||||
.map(|i| VectorEntry {
|
||||
id: Some(format!("vec_{}", i)),
|
||||
vector: vec![i as f32; 64],
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
(db, vectors, dir)
|
||||
},
|
||||
|(db, vectors, _dir)| db.insert_batch(black_box(vectors)).unwrap(),
|
||||
criterion::BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_parallel_searches(c: &mut Criterion) {
|
||||
let dir = tempdir().unwrap();
|
||||
let mut options = DbOptions::default();
|
||||
options.storage_path = dir
|
||||
.path()
|
||||
.join("search_bench.db")
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
options.dimensions = 128;
|
||||
options.distance_metric = DistanceMetric::Cosine;
|
||||
options.hnsw_config = None;
|
||||
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
|
||||
// Insert test data
|
||||
let vectors: Vec<VectorEntry> = (0..1000)
|
||||
.map(|i| VectorEntry {
|
||||
id: Some(format!("vec_{}", i)),
|
||||
vector: (0..128).map(|j| ((i + j) as f32) * 0.01).collect(),
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
db.insert_batch(vectors).unwrap();
|
||||
|
||||
// Benchmark multiple sequential searches
|
||||
c.bench_function("sequential_searches_100", |bench| {
|
||||
bench.iter(|| {
|
||||
for i in 0..100 {
|
||||
let query: Vec<f32> = (0..128).map(|j| ((i + j) as f32) * 0.01).collect();
|
||||
let _ = db
|
||||
.search(SearchQuery {
|
||||
vector: black_box(query),
|
||||
k: 10,
|
||||
filter: None,
|
||||
ef_search: None,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn bench_batch_delete(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("batch_delete");
|
||||
|
||||
for size in [100, 1000].iter() {
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, &size| {
|
||||
bench.iter_batched(
|
||||
|| {
|
||||
// Setup: Create DB with vectors
|
||||
let dir = tempdir().unwrap();
|
||||
let mut options = DbOptions::default();
|
||||
options.storage_path =
|
||||
dir.path().join("bench.db").to_string_lossy().to_string();
|
||||
options.dimensions = 32;
|
||||
options.hnsw_config = None;
|
||||
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
|
||||
let vectors: Vec<VectorEntry> = (0..size)
|
||||
.map(|i| VectorEntry {
|
||||
id: Some(format!("vec_{}", i)),
|
||||
vector: vec![i as f32; 32],
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let ids = db.insert_batch(vectors).unwrap();
|
||||
(db, ids, dir)
|
||||
},
|
||||
|(db, ids, _dir)| {
|
||||
// Benchmark: Delete all
|
||||
for id in ids {
|
||||
db.delete(black_box(&id)).unwrap();
|
||||
}
|
||||
},
|
||||
criterion::BatchSize::LargeInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_batch_insert,
|
||||
bench_individual_insert_vs_batch,
|
||||
bench_parallel_searches,
|
||||
bench_batch_delete
|
||||
);
|
||||
criterion_main!(benches);
|
||||
474
crates/ruvector-core/benches/bench_memory.rs
Normal file
474
crates/ruvector-core/benches/bench_memory.rs
Normal file
@@ -0,0 +1,474 @@
|
||||
//! Memory Allocation and Pool Benchmarks
|
||||
//!
|
||||
//! This module benchmarks arena allocation, cache-optimized storage,
|
||||
//! and memory access patterns.
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ruvector_core::arena::Arena;
|
||||
use ruvector_core::cache_optimized::SoAVectorStorage;
|
||||
|
||||
// ============================================================================
|
||||
// Arena Allocation Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_arena_allocation(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("arena_allocation");
|
||||
|
||||
for count in [10, 100, 1000, 10000] {
|
||||
group.throughput(Throughput::Elements(count));
|
||||
|
||||
// Benchmark arena allocation
|
||||
group.bench_with_input(BenchmarkId::new("arena", count), &count, |bench, &count| {
|
||||
bench.iter(|| {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
for _ in 0..count {
|
||||
let _vec = arena.alloc_vec::<f32>(black_box(64));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Compare with standard Vec allocation
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("std_vec", count),
|
||||
&count,
|
||||
|bench, &count| {
|
||||
bench.iter(|| {
|
||||
let mut vecs = Vec::with_capacity(count as usize);
|
||||
for _ in 0..count {
|
||||
vecs.push(Vec::<f32>::with_capacity(black_box(64)));
|
||||
}
|
||||
vecs
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_arena_allocation_sizes(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("arena_allocation_sizes");
|
||||
|
||||
for size in [8, 32, 64, 128, 256, 512, 1024, 4096] {
|
||||
group.throughput(Throughput::Bytes(size as u64 * 4)); // f32 = 4 bytes
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("alloc", size), &size, |bench, &size| {
|
||||
bench.iter(|| {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
for _ in 0..1000 {
|
||||
let _vec = arena.alloc_vec::<f32>(black_box(size));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_arena_reset_reuse(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("arena_reset_reuse");
|
||||
|
||||
for iterations in [10, 100, 1000] {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("with_reset", iterations),
|
||||
&iterations,
|
||||
|bench, &iterations| {
|
||||
bench.iter(|| {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
for _ in 0..iterations {
|
||||
// Allocate
|
||||
for _ in 0..100 {
|
||||
let _vec = arena.alloc_vec::<f32>(64);
|
||||
}
|
||||
// Reset for reuse
|
||||
arena.reset();
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("without_reset", iterations),
|
||||
&iterations,
|
||||
|bench, &iterations| {
|
||||
bench.iter(|| {
|
||||
for _ in 0..iterations {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
for _ in 0..100 {
|
||||
let _vec = arena.alloc_vec::<f32>(64);
|
||||
}
|
||||
// No reset, create new arena each time
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_arena_push_operations(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("arena_push");
|
||||
|
||||
for count in [100, 1000, 10000] {
|
||||
group.throughput(Throughput::Elements(count));
|
||||
|
||||
// Arena push
|
||||
group.bench_with_input(BenchmarkId::new("arena", count), &count, |bench, &count| {
|
||||
bench.iter(|| {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
let mut vec = arena.alloc_vec::<f32>(count as usize);
|
||||
for i in 0..count {
|
||||
vec.push(black_box(i as f32));
|
||||
}
|
||||
vec
|
||||
});
|
||||
});
|
||||
|
||||
// Standard Vec push
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("std_vec", count),
|
||||
&count,
|
||||
|bench, &count| {
|
||||
bench.iter(|| {
|
||||
let mut vec = Vec::with_capacity(count as usize);
|
||||
for i in 0..count {
|
||||
vec.push(black_box(i as f32));
|
||||
}
|
||||
vec
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SoA Vector Storage Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_soa_storage_push(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("soa_storage_push");
|
||||
|
||||
for dim in [64, 128, 256, 384, 512, 768] {
|
||||
let vector: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01).collect();
|
||||
|
||||
group.throughput(Throughput::Elements(dim as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("soa", dim), &dim, |bench, _| {
|
||||
bench.iter(|| {
|
||||
let mut storage = SoAVectorStorage::new(dim, 128);
|
||||
for _ in 0..1000 {
|
||||
storage.push(black_box(&vector));
|
||||
}
|
||||
storage
|
||||
});
|
||||
});
|
||||
|
||||
// Compare with Vec<Vec<f32>>
|
||||
group.bench_with_input(BenchmarkId::new("vec_of_vec", dim), &dim, |bench, _| {
|
||||
bench.iter(|| {
|
||||
let mut storage: Vec<Vec<f32>> = Vec::with_capacity(1000);
|
||||
for _ in 0..1000 {
|
||||
storage.push(black_box(vector.clone()));
|
||||
}
|
||||
storage
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_soa_storage_get(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("soa_storage_get");
|
||||
|
||||
for dim in [128, 384, 768] {
|
||||
let mut storage = SoAVectorStorage::new(dim, 128);
|
||||
|
||||
for i in 0..10000 {
|
||||
let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.001).collect();
|
||||
storage.push(&vector);
|
||||
}
|
||||
|
||||
let mut output = vec![0.0_f32; dim];
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("sequential", dim), &dim, |bench, _| {
|
||||
bench.iter(|| {
|
||||
for i in 0..10000 {
|
||||
storage.get(black_box(i), &mut output);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("random", dim), &dim, |bench, _| {
|
||||
let indices: Vec<usize> = (0..10000).map(|i| (i * 37 + 13) % 10000).collect();
|
||||
bench.iter(|| {
|
||||
for &idx in &indices {
|
||||
storage.get(black_box(idx), &mut output);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_soa_dimension_slice(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("soa_dimension_slice");
|
||||
|
||||
for dim in [64, 128, 256, 512] {
|
||||
let mut storage = SoAVectorStorage::new(dim, 128);
|
||||
|
||||
for i in 0..10000 {
|
||||
let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.001).collect();
|
||||
storage.push(&vector);
|
||||
}
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("access_all_dims", dim),
|
||||
&dim,
|
||||
|bench, &dim| {
|
||||
bench.iter(|| {
|
||||
let mut sum = 0.0_f32;
|
||||
for d in 0..dim {
|
||||
let slice = storage.dimension_slice(black_box(d));
|
||||
sum += slice.iter().sum::<f32>();
|
||||
}
|
||||
sum
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("access_single_dim", dim),
|
||||
&dim,
|
||||
|bench, _| {
|
||||
bench.iter(|| {
|
||||
let slice = storage.dimension_slice(black_box(0));
|
||||
slice.iter().sum::<f32>()
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_soa_batch_distances(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("soa_batch_distances");
|
||||
|
||||
for (dim, count) in [
|
||||
(128, 1000),
|
||||
(384, 1000),
|
||||
(768, 1000),
|
||||
(128, 10000),
|
||||
(384, 5000),
|
||||
] {
|
||||
let mut storage = SoAVectorStorage::new(dim, 128);
|
||||
|
||||
for i in 0..count {
|
||||
let vector: Vec<f32> = (0..dim)
|
||||
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
|
||||
.collect();
|
||||
storage.push(&vector);
|
||||
}
|
||||
|
||||
let query: Vec<f32> = (0..dim).map(|j| j as f32 * 0.002).collect();
|
||||
let mut distances = vec![0.0_f32; count];
|
||||
|
||||
group.throughput(Throughput::Elements(count as u64));
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new(format!("{}d_x{}", dim, count), dim),
|
||||
&dim,
|
||||
|bench, _| {
|
||||
bench.iter(|| {
|
||||
storage.batch_euclidean_distances(black_box(&query), &mut distances);
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Access Pattern Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_memory_layout_comparison(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("memory_layout");
|
||||
|
||||
let dim = 384;
|
||||
let count = 10000;
|
||||
|
||||
// SoA layout
|
||||
let mut soa_storage = SoAVectorStorage::new(dim, 128);
|
||||
for i in 0..count {
|
||||
let vector: Vec<f32> = (0..dim)
|
||||
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
|
||||
.collect();
|
||||
soa_storage.push(&vector);
|
||||
}
|
||||
|
||||
// AoS layout (Vec<Vec<f32>>)
|
||||
let aos_storage: Vec<Vec<f32>> = (0..count)
|
||||
.map(|i| {
|
||||
(0..dim)
|
||||
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let query: Vec<f32> = (0..dim).map(|j| j as f32 * 0.002).collect();
|
||||
let mut soa_distances = vec![0.0_f32; count];
|
||||
|
||||
group.bench_function("soa_batch_euclidean", |bench| {
|
||||
bench.iter(|| {
|
||||
soa_storage.batch_euclidean_distances(black_box(&query), &mut soa_distances);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("aos_naive_euclidean", |bench| {
|
||||
bench.iter(|| {
|
||||
let distances: Vec<f32> = aos_storage
|
||||
.iter()
|
||||
.map(|v| {
|
||||
query
|
||||
.iter()
|
||||
.zip(v.iter())
|
||||
.map(|(a, b)| (a - b) * (a - b))
|
||||
.sum::<f32>()
|
||||
.sqrt()
|
||||
})
|
||||
.collect();
|
||||
distances
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_cache_efficiency(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("cache_efficiency");
|
||||
|
||||
let dim = 512;
|
||||
|
||||
// Test with different vector counts to observe cache effects
|
||||
for count in [100, 1000, 10000, 50000] {
|
||||
let mut storage = SoAVectorStorage::new(dim, 128);
|
||||
|
||||
for i in 0..count {
|
||||
let vector: Vec<f32> = (0..dim)
|
||||
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
|
||||
.collect();
|
||||
storage.push(&vector);
|
||||
}
|
||||
|
||||
let query: Vec<f32> = (0..dim).map(|j| j as f32 * 0.001).collect();
|
||||
let mut distances = vec![0.0_f32; count];
|
||||
|
||||
group.throughput(Throughput::Elements(count as u64));
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("batch_distance", count),
|
||||
&count,
|
||||
|bench, _| {
|
||||
bench.iter(|| {
|
||||
storage.batch_euclidean_distances(black_box(&query), &mut distances);
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Growth and Reallocation Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_soa_growth(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("soa_growth");
|
||||
|
||||
// Test growth from small initial capacity
|
||||
group.bench_function("grow_from_small", |bench| {
|
||||
bench.iter(|| {
|
||||
let mut storage = SoAVectorStorage::new(128, 4); // Very small initial
|
||||
for i in 0..10000 {
|
||||
let vector: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.001).collect();
|
||||
storage.push(black_box(&vector));
|
||||
}
|
||||
storage
|
||||
});
|
||||
});
|
||||
|
||||
// Test with pre-allocated capacity
|
||||
group.bench_function("preallocated", |bench| {
|
||||
bench.iter(|| {
|
||||
let mut storage = SoAVectorStorage::new(128, 16384); // Pre-allocate
|
||||
for i in 0..10000 {
|
||||
let vector: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.001).collect();
|
||||
storage.push(black_box(&vector));
|
||||
}
|
||||
storage
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mixed Type Allocation Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_arena_mixed_types(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("arena_mixed_types");
|
||||
|
||||
group.bench_function("mixed_allocations", |bench| {
|
||||
bench.iter(|| {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
for _ in 0..100 {
|
||||
let _f32_vec = arena.alloc_vec::<f32>(black_box(64));
|
||||
let _f64_vec = arena.alloc_vec::<f64>(black_box(32));
|
||||
let _u32_vec = arena.alloc_vec::<u32>(black_box(128));
|
||||
let _u8_vec = arena.alloc_vec::<u8>(black_box(256));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("uniform_allocations", |bench| {
|
||||
bench.iter(|| {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
for _ in 0..400 {
|
||||
let _f32_vec = arena.alloc_vec::<f32>(black_box(64));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Criterion Groups
|
||||
// ============================================================================
|
||||
|
||||
// Register every arena / SoA benchmark with the criterion harness; the
// generated `benches` group is executed in listed order by `criterion_main!`.
criterion_group!(
    benches,
    bench_arena_allocation,
    bench_arena_allocation_sizes,
    bench_arena_reset_reuse,
    bench_arena_push_operations,
    bench_soa_storage_push,
    bench_soa_storage_get,
    bench_soa_dimension_slice,
    bench_soa_batch_distances,
    bench_memory_layout_comparison,
    bench_cache_efficiency,
    bench_soa_growth,
    bench_arena_mixed_types,
);

criterion_main!(benches);
|
||||
335
crates/ruvector-core/benches/bench_simd.rs
Normal file
335
crates/ruvector-core/benches/bench_simd.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
//! SIMD Performance Benchmarks
|
||||
//!
|
||||
//! This module benchmarks SIMD-optimized distance calculations
|
||||
//! across various vector dimensions and operation types.
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ruvector_core::simd_intrinsics::*;
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Build a deterministic pair of `dim`-length test vectors; the second is
/// the first shifted by 100 steps so every component differs.
fn generate_vectors(dim: usize) -> (Vec<f32>, Vec<f32>) {
    let make = |offset: usize| -> Vec<f32> {
        (0..dim).map(|i| ((i + offset) as f32) * 0.01).collect()
    };
    (make(0), make(100))
}
|
||||
|
||||
/// Build one `dim`-length query plus `count` deterministic candidate
/// vectors; candidate `j` is the index ramp offset by `j * 10`.
fn generate_batch_vectors(dim: usize, count: usize) -> (Vec<f32>, Vec<Vec<f32>>) {
    let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
    let mut vectors = Vec::with_capacity(count);
    for j in 0..count {
        let row: Vec<f32> = (0..dim).map(|i| ((i + j * 10) as f32) * 0.01).collect();
        vectors.push(row);
    }
    (query, vectors)
}
|
||||
|
||||
// ============================================================================
|
||||
// Euclidean Distance Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_euclidean_by_dimension(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("euclidean_by_dimension");
|
||||
|
||||
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.throughput(Throughput::Elements(dim as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
|
||||
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_euclidean_small_vectors(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("euclidean_small_vectors");
|
||||
|
||||
// Test small vector sizes that may not benefit from SIMD
|
||||
for dim in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16] {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
|
||||
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_euclidean_non_aligned(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("euclidean_non_aligned");
|
||||
|
||||
// Test non-SIMD-aligned sizes
|
||||
for dim in [31, 33, 63, 65, 127, 129, 255, 257, 383, 385, 511, 513] {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
|
||||
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Dot Product Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_dot_product_by_dimension(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("dot_product_by_dimension");
|
||||
|
||||
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.throughput(Throughput::Elements(dim as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
|
||||
bench.iter(|| dot_product_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_dot_product_common_embeddings(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("dot_product_common_embeddings");
|
||||
|
||||
// Common embedding model dimensions
|
||||
let dims = [
|
||||
(128, "small"),
|
||||
(384, "all-MiniLM-L6"),
|
||||
(512, "e5-small"),
|
||||
(768, "all-mpnet-base"),
|
||||
(1024, "e5-large"),
|
||||
(1536, "text-embedding-ada-002"),
|
||||
(2048, "llama-7b-hidden"),
|
||||
];
|
||||
|
||||
for (dim, name) in dims {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.bench_with_input(BenchmarkId::new(name, dim), &dim, |bench, _| {
|
||||
bench.iter(|| dot_product_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cosine Similarity Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_cosine_by_dimension(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("cosine_by_dimension");
|
||||
|
||||
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.throughput(Throughput::Elements(dim as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
|
||||
bench.iter(|| cosine_similarity_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Manhattan Distance Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_manhattan_by_dimension(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("manhattan_by_dimension");
|
||||
|
||||
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.throughput(Throughput::Elements(dim as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
|
||||
bench.iter(|| manhattan_distance_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Batch Operations Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_batch_euclidean(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("batch_euclidean");
|
||||
|
||||
for count in [10, 100, 1000, 10000] {
|
||||
let (query, vectors) = generate_batch_vectors(384, count);
|
||||
|
||||
group.throughput(Throughput::Elements(count as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("384d", count), &count, |bench, _| {
|
||||
bench.iter(|| {
|
||||
for v in &vectors {
|
||||
euclidean_distance_simd(black_box(&query), black_box(v));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_batch_dot_product(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("batch_dot_product");
|
||||
|
||||
for count in [10, 100, 1000, 10000] {
|
||||
let (query, vectors) = generate_batch_vectors(768, count);
|
||||
|
||||
group.throughput(Throughput::Elements(count as u64));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("768d", count), &count, |bench, _| {
|
||||
bench.iter(|| {
|
||||
for v in &vectors {
|
||||
dot_product_simd(black_box(&query), black_box(v));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Comparison Benchmarks (All Metrics)
|
||||
// ============================================================================
|
||||
|
||||
fn bench_all_metrics_comparison(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("metrics_comparison");
|
||||
|
||||
let dim = 384; // Common embedding dimension
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
group.bench_function("euclidean", |bench| {
|
||||
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
|
||||
group.bench_function("dot_product", |bench| {
|
||||
bench.iter(|| dot_product_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
|
||||
group.bench_function("cosine", |bench| {
|
||||
bench.iter(|| cosine_similarity_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
|
||||
group.bench_function("manhattan", |bench| {
|
||||
bench.iter(|| manhattan_distance_simd(black_box(&a), black_box(&b)));
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Access Pattern Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_sequential_vs_random_access(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("access_patterns");
|
||||
|
||||
let dim = 512;
|
||||
let count = 1000;
|
||||
|
||||
// Generate vectors
|
||||
let vectors: Vec<Vec<f32>> = (0..count)
|
||||
.map(|j| (0..dim).map(|i| ((i + j * 10) as f32) * 0.01).collect())
|
||||
.collect();
|
||||
let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
|
||||
|
||||
// Sequential access indices
|
||||
let sequential_indices: Vec<usize> = (0..count).collect();
|
||||
|
||||
// Random-ish access indices
|
||||
let random_indices: Vec<usize> = (0..count)
|
||||
.map(|i| (i * 37 + 13) % count) // Pseudo-random
|
||||
.collect();
|
||||
|
||||
group.bench_function("sequential", |bench| {
|
||||
bench.iter(|| {
|
||||
for &idx in &sequential_indices {
|
||||
euclidean_distance_simd(black_box(&query), black_box(&vectors[idx]));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("random", |bench| {
|
||||
bench.iter(|| {
|
||||
for &idx in &random_indices {
|
||||
euclidean_distance_simd(black_box(&query), black_box(&vectors[idx]));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Throughput Measurement
|
||||
// ============================================================================
|
||||
|
||||
fn bench_throughput_ops_per_second(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("throughput");
|
||||
group.sample_size(50);
|
||||
|
||||
for dim in [128, 384, 768, 1536] {
|
||||
let (a, b) = generate_vectors(dim);
|
||||
|
||||
// Report throughput in operations/second
|
||||
group.bench_with_input(BenchmarkId::new("euclidean_ops", dim), &dim, |bench, _| {
|
||||
bench.iter(|| {
|
||||
// Perform 100 operations per iteration
|
||||
for _ in 0..100 {
|
||||
euclidean_distance_simd(black_box(&a), black_box(&b));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("dot_product_ops", dim),
|
||||
&dim,
|
||||
|bench, _| {
|
||||
bench.iter(|| {
|
||||
for _ in 0..100 {
|
||||
dot_product_simd(black_box(&a), black_box(&b));
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Criterion Groups
|
||||
// ============================================================================
|
||||
|
||||
// Register every SIMD kernel benchmark with the criterion harness; the
// generated `benches` group is executed in listed order by `criterion_main!`.
criterion_group!(
    benches,
    bench_euclidean_by_dimension,
    bench_euclidean_small_vectors,
    bench_euclidean_non_aligned,
    bench_dot_product_by_dimension,
    bench_dot_product_common_embeddings,
    bench_cosine_by_dimension,
    bench_manhattan_by_dimension,
    bench_batch_euclidean,
    bench_batch_dot_product,
    bench_all_metrics_comparison,
    bench_sequential_vs_random_access,
    bench_throughput_ops_per_second,
);

criterion_main!(benches);
|
||||
262
crates/ruvector-core/benches/comprehensive_bench.rs
Normal file
262
crates/ruvector-core/benches/comprehensive_bench.rs
Normal file
@@ -0,0 +1,262 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ruvector_core::arena::Arena;
|
||||
use ruvector_core::cache_optimized::SoAVectorStorage;
|
||||
use ruvector_core::distance::*;
|
||||
use ruvector_core::lockfree::{LockFreeCounter, LockFreeStats, ObjectPool};
|
||||
use ruvector_core::simd_intrinsics::*;
|
||||
use ruvector_core::types::DistanceMetric;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
|
||||
// Benchmark SIMD intrinsics vs SimSIMD
|
||||
fn bench_simd_comparison(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("simd_comparison");
|
||||
|
||||
for size in [128, 384, 768, 1536].iter() {
|
||||
let a: Vec<f32> = (0..*size).map(|i| i as f32 * 0.1).collect();
|
||||
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32 * 0.1).collect();
|
||||
|
||||
// Euclidean distance
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("euclidean_simsimd", size),
|
||||
size,
|
||||
|bench, _| {
|
||||
bench.iter(|| euclidean_distance(black_box(&a), black_box(&b)));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("euclidean_avx2", size),
|
||||
size,
|
||||
|bench, _| {
|
||||
bench.iter(|| euclidean_distance_avx2(black_box(&a), black_box(&b)));
|
||||
},
|
||||
);
|
||||
|
||||
// Dot product
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("dot_product_simsimd", size),
|
||||
size,
|
||||
|bench, _| {
|
||||
bench.iter(|| dot_product_distance(black_box(&a), black_box(&b)));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("dot_product_avx2", size),
|
||||
size,
|
||||
|bench, _| {
|
||||
bench.iter(|| dot_product_avx2(black_box(&a), black_box(&b)));
|
||||
},
|
||||
);
|
||||
|
||||
// Cosine similarity
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("cosine_simsimd", size),
|
||||
size,
|
||||
|bench, _| {
|
||||
bench.iter(|| cosine_distance(black_box(&a), black_box(&b)));
|
||||
},
|
||||
);
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("cosine_avx2", size), size, |bench, _| {
|
||||
bench.iter(|| cosine_similarity_avx2(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark Structure-of-Arrays vs Array-of-Structures
|
||||
fn bench_cache_optimization(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("cache_optimization");
|
||||
|
||||
let dimensions = 384;
|
||||
let num_vectors = 10000;
|
||||
|
||||
// Prepare data
|
||||
let vectors: Vec<Vec<f32>> = (0..num_vectors)
|
||||
.map(|i| (0..dimensions).map(|j| (i * j) as f32 * 0.001).collect())
|
||||
.collect();
|
||||
|
||||
let query: Vec<f32> = (0..dimensions).map(|i| i as f32 * 0.01).collect();
|
||||
|
||||
// Array-of-Structures (traditional Vec<Vec<f32>>)
|
||||
group.bench_function("aos_batch_distance", |bench| {
|
||||
bench.iter(|| {
|
||||
let mut distances: Vec<f32> = Vec::with_capacity(num_vectors);
|
||||
for vector in &vectors {
|
||||
let dist = euclidean_distance(black_box(&query), black_box(vector));
|
||||
distances.push(dist);
|
||||
}
|
||||
black_box(distances)
|
||||
});
|
||||
});
|
||||
|
||||
// Structure-of-Arrays
|
||||
let mut soa_storage = SoAVectorStorage::new(dimensions, num_vectors);
|
||||
for vector in &vectors {
|
||||
soa_storage.push(vector);
|
||||
}
|
||||
|
||||
group.bench_function("soa_batch_distance", |bench| {
|
||||
bench.iter(|| {
|
||||
let mut distances = vec![0.0; num_vectors];
|
||||
soa_storage.batch_euclidean_distances(black_box(&query), &mut distances);
|
||||
black_box(distances)
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark arena allocation vs standard allocation
|
||||
fn bench_arena_allocation(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("arena_allocation");
|
||||
|
||||
let num_allocations = 1000;
|
||||
let vec_size = 100;
|
||||
|
||||
group.bench_function("standard_allocation", |bench| {
|
||||
bench.iter(|| {
|
||||
let mut vecs = Vec::new();
|
||||
for _ in 0..num_allocations {
|
||||
let mut v = Vec::with_capacity(vec_size);
|
||||
for j in 0..vec_size {
|
||||
v.push(j as f32);
|
||||
}
|
||||
vecs.push(v);
|
||||
}
|
||||
black_box(vecs)
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("arena_allocation", |bench| {
|
||||
bench.iter(|| {
|
||||
let arena = Arena::new(1024 * 1024);
|
||||
let mut vecs = Vec::new();
|
||||
for _ in 0..num_allocations {
|
||||
let mut v = arena.alloc_vec::<f32>(vec_size);
|
||||
for j in 0..vec_size {
|
||||
v.push(j as f32);
|
||||
}
|
||||
vecs.push(v);
|
||||
}
|
||||
black_box(vecs)
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark lock-free operations vs locked operations
|
||||
fn bench_lockfree(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("lockfree");
|
||||
|
||||
// Counter benchmark
|
||||
group.bench_function("lockfree_counter_single_thread", |bench| {
|
||||
let counter = LockFreeCounter::new(0);
|
||||
bench.iter(|| {
|
||||
for _ in 0..10000 {
|
||||
counter.increment();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("lockfree_counter_multi_thread", |bench| {
|
||||
bench.iter(|| {
|
||||
let counter = Arc::new(LockFreeCounter::new(0));
|
||||
let mut handles = vec![];
|
||||
|
||||
for _ in 0..4 {
|
||||
let counter_clone = Arc::clone(&counter);
|
||||
handles.push(thread::spawn(move || {
|
||||
for _ in 0..2500 {
|
||||
counter_clone.increment();
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
black_box(counter.get())
|
||||
});
|
||||
});
|
||||
|
||||
// Stats collector benchmark
|
||||
group.bench_function("lockfree_stats", |bench| {
|
||||
let stats = LockFreeStats::new();
|
||||
bench.iter(|| {
|
||||
for i in 0..1000 {
|
||||
stats.record_query(i);
|
||||
}
|
||||
black_box(stats.snapshot())
|
||||
});
|
||||
});
|
||||
|
||||
// Object pool benchmark
|
||||
group.bench_function("object_pool_acquire_release", |bench| {
|
||||
let pool = ObjectPool::new(10, || Vec::<f32>::with_capacity(1000));
|
||||
bench.iter(|| {
|
||||
let mut obj = pool.acquire();
|
||||
for i in 0..100 {
|
||||
obj.push(i as f32);
|
||||
}
|
||||
black_box(&*obj);
|
||||
});
|
||||
});
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Benchmark thread scaling
|
||||
fn bench_thread_scaling(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("thread_scaling");
|
||||
|
||||
let dimensions = 384;
|
||||
let num_vectors = 10000;
|
||||
let query: Vec<f32> = (0..dimensions).map(|i| i as f32 * 0.01).collect();
|
||||
let vectors: Vec<Vec<f32>> = (0..num_vectors)
|
||||
.map(|i| (0..dimensions).map(|j| (i * j) as f32 * 0.001).collect())
|
||||
.collect();
|
||||
|
||||
for num_threads in [1, 2, 4, 8].iter() {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("parallel_distance", num_threads),
|
||||
num_threads,
|
||||
|bench, &threads| {
|
||||
bench.iter(|| {
|
||||
rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(threads)
|
||||
.build()
|
||||
.unwrap()
|
||||
.install(|| {
|
||||
let result = batch_distances(
|
||||
black_box(&query),
|
||||
black_box(&vectors),
|
||||
DistanceMetric::Euclidean,
|
||||
)
|
||||
.unwrap();
|
||||
black_box(result)
|
||||
})
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Register the comprehensive (SIMD / cache / arena / lock-free / scaling)
// benchmarks; `criterion_main!` generates the harness entry point.
criterion_group!(
    benches,
    bench_simd_comparison,
    bench_cache_optimization,
    bench_arena_allocation,
    bench_lockfree,
    bench_thread_scaling
);
criterion_main!(benches);
|
||||
74
crates/ruvector-core/benches/distance_metrics.rs
Normal file
74
crates/ruvector-core/benches/distance_metrics.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use ruvector_core::distance::*;
|
||||
use ruvector_core::types::DistanceMetric;
|
||||
|
||||
fn bench_euclidean(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("euclidean_distance");
|
||||
|
||||
for size in [128, 384, 768, 1536].iter() {
|
||||
let a: Vec<f32> = (0..*size).map(|i| i as f32).collect();
|
||||
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32).collect();
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, _| {
|
||||
bench.iter(|| euclidean_distance(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_cosine(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("cosine_distance");
|
||||
|
||||
for size in [128, 384, 768, 1536].iter() {
|
||||
let a: Vec<f32> = (0..*size).map(|i| i as f32).collect();
|
||||
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32).collect();
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, _| {
|
||||
bench.iter(|| cosine_distance(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_dot_product(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("dot_product_distance");
|
||||
|
||||
for size in [128, 384, 768, 1536].iter() {
|
||||
let a: Vec<f32> = (0..*size).map(|i| i as f32).collect();
|
||||
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32).collect();
|
||||
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, _| {
|
||||
bench.iter(|| dot_product_distance(black_box(&a), black_box(&b)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_batch_distances(c: &mut Criterion) {
|
||||
let query: Vec<f32> = (0..384).map(|i| i as f32).collect();
|
||||
let vectors: Vec<Vec<f32>> = (0..1000)
|
||||
.map(|_| (0..384).map(|i| (i as f32) * 1.1).collect())
|
||||
.collect();
|
||||
|
||||
c.bench_function("batch_distances_1000x384", |b| {
|
||||
b.iter(|| {
|
||||
batch_distances(
|
||||
black_box(&query),
|
||||
black_box(&vectors),
|
||||
DistanceMetric::Cosine,
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Register the plain distance-metric benchmarks; `criterion_main!` generates
// the harness entry point for the `benches` group.
criterion_group!(
    benches,
    bench_euclidean,
    bench_cosine,
    bench_dot_product,
    bench_batch_distances
);
criterion_main!(benches);
|
||||
56
crates/ruvector-core/benches/hnsw_search.rs
Normal file
56
crates/ruvector-core/benches/hnsw_search.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, SearchQuery};
|
||||
use ruvector_core::{VectorDB, VectorEntry};
|
||||
|
||||
fn bench_hnsw_search(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("hnsw_search");
|
||||
|
||||
// Create temp database
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let options = DbOptions {
|
||||
dimensions: 128,
|
||||
distance_metric: DistanceMetric::Cosine,
|
||||
storage_path: temp_dir
|
||||
.path()
|
||||
.join("test.db")
|
||||
.to_string_lossy()
|
||||
.to_string(),
|
||||
hnsw_config: Some(HnswConfig::default()),
|
||||
quantization: None,
|
||||
};
|
||||
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
|
||||
// Insert test vectors
|
||||
let vectors: Vec<VectorEntry> = (0..1000)
|
||||
.map(|i| VectorEntry {
|
||||
id: Some(format!("v{}", i)),
|
||||
vector: (0..128).map(|j| ((i + j) as f32) * 0.1).collect(),
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
db.insert_batch(vectors).unwrap();
|
||||
|
||||
// Benchmark search
|
||||
let query: Vec<f32> = (0..128).map(|i| i as f32).collect();
|
||||
|
||||
for k in [1, 10, 100].iter() {
|
||||
group.bench_with_input(BenchmarkId::from_parameter(k), k, |bench, &k| {
|
||||
bench.iter(|| {
|
||||
db.search(SearchQuery {
|
||||
vector: black_box(query.clone()),
|
||||
k: black_box(k),
|
||||
filter: None,
|
||||
ef_search: None,
|
||||
})
|
||||
.unwrap()
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Register the HNSW search benchmark and emit the harness entry point.
criterion_group!(benches, bench_hnsw_search);
criterion_main!(benches);
|
||||
77
crates/ruvector-core/benches/quantization_bench.rs
Normal file
77
crates/ruvector-core/benches/quantization_bench.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use ruvector_core::quantization::*;
|
||||
|
||||
fn bench_scalar_quantization(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("scalar_quantization");
|
||||
|
||||
for size in [128, 384, 768, 1536].iter() {
|
||||
let vector: Vec<f32> = (0..*size).map(|i| i as f32 * 0.1).collect();
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("encode", size), size, |bench, _| {
|
||||
bench.iter(|| ScalarQuantized::quantize(black_box(&vector)));
|
||||
});
|
||||
|
||||
let quantized = ScalarQuantized::quantize(&vector);
|
||||
group.bench_with_input(BenchmarkId::new("decode", size), size, |bench, _| {
|
||||
bench.iter(|| quantized.reconstruct());
|
||||
});
|
||||
|
||||
let quantized2 = ScalarQuantized::quantize(&vector);
|
||||
group.bench_with_input(BenchmarkId::new("distance", size), size, |bench, _| {
|
||||
bench.iter(|| quantized.distance(black_box(&quantized2)));
|
||||
});
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_binary_quantization(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("binary_quantization");
|
||||
|
||||
for size in [128, 384, 768, 1536].iter() {
|
||||
let vector: Vec<f32> = (0..*size)
|
||||
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
|
||||
.collect();
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("encode", size), size, |bench, _| {
|
||||
bench.iter(|| BinaryQuantized::quantize(black_box(&vector)));
|
||||
});
|
||||
|
||||
let quantized = BinaryQuantized::quantize(&vector);
|
||||
group.bench_with_input(BenchmarkId::new("decode", size), size, |bench, _| {
|
||||
bench.iter(|| quantized.reconstruct());
|
||||
});
|
||||
|
||||
let quantized2 = BinaryQuantized::quantize(&vector);
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("hamming_distance", size),
|
||||
size,
|
||||
|bench, _| {
|
||||
bench.iter(|| quantized.distance(black_box(&quantized2)));
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_quantization_compression_ratio(c: &mut Criterion) {
|
||||
let dimensions = 384;
|
||||
let vector: Vec<f32> = (0..dimensions).map(|i| i as f32 * 0.01).collect();
|
||||
|
||||
c.bench_function("scalar_vs_binary_encoding", |b| {
|
||||
b.iter(|| {
|
||||
let scalar = ScalarQuantized::quantize(black_box(&vector));
|
||||
let binary = BinaryQuantized::quantize(black_box(&vector));
|
||||
(scalar, binary)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Register all quantization benchmarks and emit the harness entry point.
criterion_group!(
    benches,
    bench_scalar_quantization,
    bench_binary_quantization,
    bench_quantization_compression_ratio
);
criterion_main!(benches);
|
||||
217
crates/ruvector-core/benches/real_benchmark.rs
Normal file
217
crates/ruvector-core/benches/real_benchmark.rs
Normal file
@@ -0,0 +1,217 @@
|
||||
//! Real Benchmarks for RuVector Core
|
||||
//!
|
||||
//! These are ACTUAL performance measurements, not simulations.
|
||||
//! Run with: cargo bench -p ruvector-core --bench real_benchmark
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ruvector_core::types::{DbOptions, HnswConfig};
|
||||
use ruvector_core::{DistanceMetric, SearchQuery, VectorDB, VectorEntry};
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Generate random vectors for benchmarking
|
||||
/// Deterministically generates `count` pseudo-random vectors of `dim`
/// components, each component in the range [-1.0, 1.0).
///
/// Components are derived by hashing the flat element index with the
/// standard library's `DefaultHasher`, so the data is reproducible
/// without pulling in an RNG dependency.
fn generate_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Map a flat index to a value in [-1.0, 1.0).
    let component = |flat_index: usize| -> f32 {
        let mut hasher = DefaultHasher::new();
        flat_index.hash(&mut hasher);
        ((hasher.finish() % 2000) as f32 / 1000.0) - 1.0
    };

    let mut all = Vec::with_capacity(count);
    for i in 0..count {
        let mut row = Vec::with_capacity(dim);
        for j in 0..dim {
            row.push(component(i * dim + j));
        }
        all.push(row);
    }
    all
}
|
||||
|
||||
/// Benchmark: Vector insertion (single)
|
||||
fn bench_insert_single(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("insert_single");
|
||||
|
||||
for dim in [64, 128, 256, 512].iter() {
|
||||
let vectors = generate_vectors(1000, *dim);
|
||||
|
||||
group.throughput(Throughput::Elements(1));
|
||||
group.bench_with_input(BenchmarkId::new("dimensions", dim), dim, |b, &dim| {
|
||||
let dir = tempdir().unwrap();
|
||||
let options = DbOptions {
|
||||
storage_path: dir.path().join("bench.db").to_string_lossy().to_string(),
|
||||
dimensions: dim,
|
||||
distance_metric: DistanceMetric::Cosine,
|
||||
hnsw_config: Some(HnswConfig::default()),
|
||||
quantization: None,
|
||||
};
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
let mut idx = 0;
|
||||
|
||||
b.iter(|| {
|
||||
let entry = VectorEntry {
|
||||
id: None,
|
||||
vector: vectors[idx % vectors.len()].clone(),
|
||||
metadata: None,
|
||||
};
|
||||
let _ = black_box(db.insert(entry));
|
||||
idx += 1;
|
||||
});
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Vector insertion (batch)
|
||||
fn bench_insert_batch(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("insert_batch");
|
||||
|
||||
for batch_size in [100, 500, 1000].iter() {
|
||||
let vectors = generate_vectors(*batch_size, 128);
|
||||
|
||||
group.throughput(Throughput::Elements(*batch_size as u64));
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("batch_size", batch_size),
|
||||
batch_size,
|
||||
|b, &batch_size| {
|
||||
b.iter(|| {
|
||||
let dir = tempdir().unwrap();
|
||||
let options = DbOptions {
|
||||
storage_path: dir.path().join("bench.db").to_string_lossy().to_string(),
|
||||
dimensions: 128,
|
||||
distance_metric: DistanceMetric::Cosine,
|
||||
hnsw_config: Some(HnswConfig::default()),
|
||||
quantization: None,
|
||||
};
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
|
||||
let entries: Vec<VectorEntry> = vectors
|
||||
.iter()
|
||||
.map(|v| VectorEntry {
|
||||
id: None,
|
||||
vector: v.clone(),
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
black_box(db.insert_batch(entries).unwrap())
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Search (k-NN)
|
||||
fn bench_search(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("search");
|
||||
|
||||
// Pre-populate database
|
||||
let dir = tempdir().unwrap();
|
||||
let options = DbOptions {
|
||||
storage_path: dir.path().join("bench.db").to_string_lossy().to_string(),
|
||||
dimensions: 128,
|
||||
distance_metric: DistanceMetric::Cosine,
|
||||
hnsw_config: Some(HnswConfig {
|
||||
m: 16,
|
||||
ef_construction: 100,
|
||||
ef_search: 50,
|
||||
max_elements: 100000,
|
||||
}),
|
||||
quantization: None,
|
||||
};
|
||||
let db = VectorDB::new(options).unwrap();
|
||||
|
||||
// Insert 10k vectors
|
||||
let vectors = generate_vectors(10000, 128);
|
||||
let entries: Vec<VectorEntry> = vectors
|
||||
.iter()
|
||||
.map(|v| VectorEntry {
|
||||
id: None,
|
||||
vector: v.clone(),
|
||||
metadata: None,
|
||||
})
|
||||
.collect();
|
||||
db.insert_batch(entries).unwrap();
|
||||
|
||||
// Generate query vectors
|
||||
let queries = generate_vectors(100, 128);
|
||||
|
||||
for k in [10, 50, 100].iter() {
|
||||
group.throughput(Throughput::Elements(1));
|
||||
group.bench_with_input(BenchmarkId::new("top_k", k), k, |b, &k| {
|
||||
let mut query_idx = 0;
|
||||
b.iter(|| {
|
||||
let query = &queries[query_idx % queries.len()];
|
||||
let search_query = SearchQuery {
|
||||
vector: query.clone(),
|
||||
k,
|
||||
filter: None,
|
||||
ef_search: None,
|
||||
};
|
||||
let results = black_box(db.search(search_query));
|
||||
query_idx += 1;
|
||||
results
|
||||
});
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Distance computation (raw)
|
||||
fn bench_distance(c: &mut Criterion) {
|
||||
use ruvector_core::distance::{cosine_distance, dot_product_distance, euclidean_distance};
|
||||
|
||||
let mut group = c.benchmark_group("distance");
|
||||
|
||||
for dim in [64, 128, 256, 512, 1024].iter() {
|
||||
let v1: Vec<f32> = (0..*dim).map(|i| (i as f32 * 0.01).sin()).collect();
|
||||
let v2: Vec<f32> = (0..*dim).map(|i| (i as f32 * 0.02).cos()).collect();
|
||||
|
||||
group.throughput(Throughput::Elements(1));
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("cosine", dim), dim, |b, _| {
|
||||
b.iter(|| black_box(cosine_distance(&v1, &v2)));
|
||||
});
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("euclidean", dim), dim, |b, _| {
|
||||
b.iter(|| black_box(euclidean_distance(&v1, &v2)));
|
||||
});
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("dot_product", dim), dim, |b, _| {
|
||||
b.iter(|| black_box(dot_product_distance(&v1, &v2)));
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
/// Benchmark: Quantization
|
||||
fn bench_quantization(c: &mut Criterion) {
|
||||
use ruvector_core::quantization::{QuantizedVector, ScalarQuantized};
|
||||
|
||||
let mut group = c.benchmark_group("quantization");
|
||||
|
||||
for dim in [128, 256, 512].iter() {
|
||||
let vector: Vec<f32> = (0..*dim).map(|i| (i as f32 * 0.01).sin()).collect();
|
||||
|
||||
group.bench_with_input(BenchmarkId::new("scalar_quantize", dim), dim, |b, _| {
|
||||
b.iter(|| black_box(ScalarQuantized::quantize(&vector)));
|
||||
});
|
||||
|
||||
let quantized = ScalarQuantized::quantize(&vector);
|
||||
group.bench_with_input(BenchmarkId::new("scalar_distance", dim), dim, |b, _| {
|
||||
b.iter(|| black_box(quantized.distance(&quantized)));
|
||||
});
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
// Register every benchmark in this file and emit the harness entry point.
criterion_group!(
    benches,
    bench_distance,
    bench_quantization,
    bench_insert_single,
    bench_insert_batch,
    bench_search,
);

criterion_main!(benches);
|
||||
296
crates/ruvector-core/docs/EMBEDDINGS.md
Normal file
296
crates/ruvector-core/docs/EMBEDDINGS.md
Normal file
@@ -0,0 +1,296 @@
|
||||
# Text Embeddings for AgenticDB
|
||||
|
||||
This guide explains how to use real text embeddings with AgenticDB in ruvector-core.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Default (Hash-based - Testing Only)
|
||||
|
||||
```rust
|
||||
use ruvector_core::{AgenticDB, types::DbOptions};
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 128;
|
||||
options.storage_path = "agenticdb.db".to_string();
|
||||
|
||||
// Uses hash-based embeddings by default (fast but not semantic)
|
||||
let db = AgenticDB::new(options)?;
|
||||
|
||||
// Store and retrieve episodes
|
||||
let episode_id = db.store_episode(
|
||||
"Solve math problem".to_string(),
|
||||
vec!["read".to_string(), "calculate".to_string()],
|
||||
vec!["got 42".to_string()],
|
||||
"Should show work".to_string(),
|
||||
)?;
|
||||
```
|
||||
|
||||
⚠️ **Warning**: Hash-based embeddings don't understand semantic meaning!
|
||||
- "dog" and "cat" will NOT be similar
|
||||
- "dog" and "god" WILL be similar (same characters)
|
||||
|
||||
## Production: API-based Embeddings (Recommended)
|
||||
|
||||
### OpenAI
|
||||
|
||||
```rust
|
||||
use ruvector_core::{AgenticDB, ApiEmbedding, types::DbOptions};
|
||||
use std::sync::Arc;
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 1536; // text-embedding-3-small
|
||||
options.storage_path = "agenticdb.db".to_string();
|
||||
|
||||
let api_key = std::env::var("OPENAI_API_KEY")?;
|
||||
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
|
||||
|
||||
let db = AgenticDB::with_embedding_provider(options, provider)?;
|
||||
|
||||
// Now you have semantic embeddings!
|
||||
let episodes = db.retrieve_similar_episodes("mathematics", 5)?;
|
||||
```
|
||||
|
||||
**OpenAI Models:**
|
||||
- `text-embedding-3-small` - 1536 dims, $0.02/1M tokens (recommended)
|
||||
- `text-embedding-3-large` - 3072 dims, $0.13/1M tokens (best quality)
|
||||
|
||||
### Cohere
|
||||
|
||||
```rust
|
||||
let api_key = std::env::var("COHERE_API_KEY")?;
|
||||
let provider = Arc::new(ApiEmbedding::cohere(&api_key, "embed-english-v3.0"));
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 1024; // Cohere embedding size
|
||||
|
||||
let db = AgenticDB::with_embedding_provider(options, provider)?;
|
||||
```
|
||||
|
||||
### Voyage AI
|
||||
|
||||
```rust
|
||||
let api_key = std::env::var("VOYAGE_API_KEY")?;
|
||||
let provider = Arc::new(ApiEmbedding::voyage(&api_key, "voyage-2"));
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 1024; // voyage-2 size
|
||||
|
||||
let db = AgenticDB::with_embedding_provider(options, provider)?;
|
||||
```
|
||||
|
||||
## Custom Embedding Provider
|
||||
|
||||
Implement the `EmbeddingProvider` trait for any embedding system:
|
||||
|
||||
```rust
|
||||
use ruvector_core::embeddings::EmbeddingProvider;
|
||||
use ruvector_core::error::Result;
|
||||
|
||||
struct MyCustomEmbedding {
|
||||
// Your model here
|
||||
}
|
||||
|
||||
impl EmbeddingProvider for MyCustomEmbedding {
|
||||
fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
// Your embedding logic
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn dimensions(&self) -> usize {
|
||||
384 // Your embedding dimensions
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"MyCustomEmbedding"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## ONNX Runtime (Local, No API Costs)
|
||||
|
||||
For production use without API costs, use ONNX Runtime with pre-exported models:
|
||||
|
||||
```rust
|
||||
// See examples/onnx-embeddings for complete implementation
|
||||
use ort::{Session, Environment, Value};
|
||||
|
||||
struct OnnxEmbedding {
|
||||
session: Session,
|
||||
dimensions: usize,
|
||||
}
|
||||
|
||||
impl EmbeddingProvider for OnnxEmbedding {
|
||||
fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
// Tokenize text
|
||||
// Run ONNX inference
|
||||
// Return embeddings
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"OnnxEmbedding"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Exporting Models to ONNX
|
||||
|
||||
```python
|
||||
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_id = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
model.save_pretrained("./onnx-model")
|
||||
tokenizer.save_pretrained("./onnx-model")
|
||||
```
|
||||
|
||||
## Feature Flags
|
||||
|
||||
### `real-embeddings` (Optional)
|
||||
|
||||
This feature flag enables the `CandleEmbedding` type (currently a stub):
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
ruvector-core = { version = "0.1", features = ["real-embeddings"] }
|
||||
```
|
||||
|
||||
However, we recommend using API-based providers instead of implementing Candle integration yourself.
|
||||
|
||||
## Complete Example
|
||||
|
||||
```rust
|
||||
use ruvector_core::{AgenticDB, ApiEmbedding, types::DbOptions};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Setup
|
||||
let api_key = std::env::var("OPENAI_API_KEY")?;
|
||||
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 1536;
|
||||
options.storage_path = "agenticdb.db".to_string();
|
||||
|
||||
let db = AgenticDB::with_embedding_provider(options, provider)?;
|
||||
|
||||
println!("Using: {}", db.embedding_provider_name());
|
||||
|
||||
// Store reflexion episodes
|
||||
let ep1 = db.store_episode(
|
||||
"Debug memory leak in Rust".to_string(),
|
||||
vec!["profile".to_string(), "find leak".to_string()],
|
||||
vec!["fixed with Arc".to_string()],
|
||||
"Should explain reference counting".to_string(),
|
||||
)?;
|
||||
|
||||
let ep2 = db.store_episode(
|
||||
"Optimize Python performance".to_string(),
|
||||
vec!["profile".to_string(), "vectorize".to_string()],
|
||||
vec!["10x speedup".to_string()],
|
||||
"Should mention NumPy".to_string(),
|
||||
)?;
|
||||
|
||||
// Semantic search - will find Rust episode for memory-related query
|
||||
let episodes = db.retrieve_similar_episodes("memory management", 5)?;
|
||||
for episode in episodes {
|
||||
println!("Task: {}", episode.task);
|
||||
println!("Critique: {}", episode.critique);
|
||||
}
|
||||
|
||||
// Create skills
|
||||
db.create_skill(
|
||||
"Memory Profiling".to_string(),
|
||||
"Profile application memory usage to find leaks".to_string(),
|
||||
Default::default(),
|
||||
vec!["valgrind".to_string(), "massif".to_string()],
|
||||
)?;
|
||||
|
||||
// Search skills semantically
|
||||
let skills = db.search_skills("finding memory leaks", 3)?;
|
||||
for skill in skills {
|
||||
println!("Skill: {} - {}", skill.name, skill.description);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### API-based (OpenAI, Cohere, Voyage)
|
||||
- **Pros**: Always up-to-date, no model storage, easy to use
|
||||
- **Cons**: Network latency, API costs, requires internet
|
||||
- **Best for**: Production apps with internet access
|
||||
|
||||
### ONNX Runtime (Local)
|
||||
- **Pros**: No API costs, offline support, fast inference
|
||||
- **Cons**: Model storage (~100MB), setup complexity
|
||||
- **Best for**: Edge deployment, high-volume apps
|
||||
|
||||
### Hash-based (Default)
|
||||
- **Pros**: Zero dependencies, instant, no setup
|
||||
- **Cons**: Not semantic, only for testing
|
||||
- **Best for**: Development, unit tests
|
||||
|
||||
## Recommendations
|
||||
|
||||
1. **Development/Testing**: Use hash-based (default)
|
||||
2. **Production (Cloud)**: Use `ApiEmbedding::openai()`
|
||||
3. **Production (Edge/Offline)**: Implement ONNX provider
|
||||
4. **Custom Models**: Implement `EmbeddingProvider` trait
|
||||
|
||||
## Migration Path
|
||||
|
||||
```rust
|
||||
// Start with hash for development
|
||||
let db = AgenticDB::new(options)?;
|
||||
|
||||
// Switch to API for staging
|
||||
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
|
||||
let db = AgenticDB::with_embedding_provider(options, provider)?;
|
||||
|
||||
// Move to ONNX for production scale
|
||||
let provider = Arc::new(OnnxEmbedding::from_file("model.onnx")?);
|
||||
let db = AgenticDB::with_embedding_provider(options, provider)?;
|
||||
```
|
||||
|
||||
The beauty is: **your AgenticDB code doesn't change**, just the provider!
|
||||
|
||||
## Error Handling
|
||||
|
||||
```rust
|
||||
use ruvector_core::error::RuvectorError;
|
||||
|
||||
match AgenticDB::with_embedding_provider(options, provider) {
|
||||
Ok(db) => {
|
||||
// Use db
|
||||
}
|
||||
Err(RuvectorError::InvalidDimension(msg)) => {
|
||||
eprintln!("Dimension mismatch: {}", msg);
|
||||
}
|
||||
Err(RuvectorError::ModelLoadError(msg)) => {
|
||||
eprintln!("Failed to load model: {}", msg);
|
||||
}
|
||||
Err(RuvectorError::ModelInferenceError(msg)) => {
|
||||
eprintln!("Inference failed: {}", msg);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error: {}", e);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [AgenticDB API Documentation](../src/agenticdb.rs)
|
||||
- [Embedding Provider Trait](../src/embeddings.rs)
|
||||
- [ONNX Examples](../../examples/onnx-embeddings/)
|
||||
- [Integration Tests](../tests/embeddings_test.rs)
|
||||
184
crates/ruvector-core/examples/embeddings_example.rs
Normal file
184
crates/ruvector-core/examples/embeddings_example.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
//! Example of using different embedding providers with AgenticDB
|
||||
//!
|
||||
//! Run with:
|
||||
//! ```bash
|
||||
//! # Default hash-based (testing only)
|
||||
//! cargo run --example embeddings_example
|
||||
//!
|
||||
//! # With OpenAI API (requires OPENAI_API_KEY env var)
|
||||
//! OPENAI_API_KEY=sk-... cargo run --example embeddings_example --features real-embeddings
|
||||
//! ```
|
||||
|
||||
use ruvector_core::types::DbOptions;
|
||||
use ruvector_core::{AgenticDB, ApiEmbedding, HashEmbedding};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("=== AgenticDB Embeddings Example ===\n");
|
||||
|
||||
// Determine which provider to use
|
||||
let use_api = std::env::var("OPENAI_API_KEY").is_ok();
|
||||
|
||||
let (db, provider_name) = if use_api {
|
||||
println!("Using OpenAI API embeddings (real semantic search)");
|
||||
let api_key = std::env::var("OPENAI_API_KEY")?;
|
||||
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 1536; // OpenAI text-embedding-3-small
|
||||
options.storage_path = "/tmp/agenticdb_api.db".to_string();
|
||||
|
||||
let db = AgenticDB::with_embedding_provider(options, provider)?;
|
||||
(db, "OpenAI API")
|
||||
} else {
|
||||
println!("Using hash-based embeddings (testing only - not semantic)");
|
||||
println!("Set OPENAI_API_KEY to use real embeddings\n");
|
||||
|
||||
let mut options = DbOptions::default();
|
||||
options.dimensions = 128;
|
||||
options.storage_path = "/tmp/agenticdb_hash.db".to_string();
|
||||
|
||||
let db = AgenticDB::new(options)?;
|
||||
(db, "Hash-based")
|
||||
};
|
||||
|
||||
println!("Provider: {}\n", db.embedding_provider_name());
|
||||
|
||||
// Store some reflexion episodes
|
||||
println!("--- Storing Reflexion Episodes ---");
|
||||
|
||||
let ep1 = db.store_episode(
|
||||
"Fix Rust borrow checker error".to_string(),
|
||||
vec![
|
||||
"Identified lifetime issue".to_string(),
|
||||
"Added explicit lifetime annotations".to_string(),
|
||||
"Refactored to use references".to_string(),
|
||||
],
|
||||
vec!["Code compiles now".to_string()],
|
||||
"Should explain borrow checker rules better".to_string(),
|
||||
)?;
|
||||
println!(
|
||||
"✓ Stored episode: Fix Rust borrow checker error (ID: {})",
|
||||
ep1
|
||||
);
|
||||
|
||||
let ep2 = db.store_episode(
|
||||
"Optimize Python data processing".to_string(),
|
||||
vec![
|
||||
"Profiled with cProfile".to_string(),
|
||||
"Vectorized with NumPy".to_string(),
|
||||
"Parallelized with multiprocessing".to_string(),
|
||||
],
|
||||
vec!["10x performance improvement".to_string()],
|
||||
"Could have used Pandas for better readability".to_string(),
|
||||
)?;
|
||||
println!(
|
||||
"✓ Stored episode: Optimize Python data processing (ID: {})",
|
||||
ep2
|
||||
);
|
||||
|
||||
let ep3 = db.store_episode(
|
||||
"Debug JavaScript async issue".to_string(),
|
||||
vec![
|
||||
"Added console.log statements".to_string(),
|
||||
"Used Chrome DevTools debugger".to_string(),
|
||||
"Fixed Promise chain".to_string(),
|
||||
],
|
||||
vec!["Race condition resolved".to_string()],
|
||||
"Should use async/await instead of callbacks".to_string(),
|
||||
)?;
|
||||
println!(
|
||||
"✓ Stored episode: Debug JavaScript async issue (ID: {})\n",
|
||||
ep3
|
||||
);
|
||||
|
||||
// Create some skills
|
||||
println!("--- Creating Skills ---");
|
||||
|
||||
let skill1 = db.create_skill(
|
||||
"Memory Profiling".to_string(),
|
||||
"Profile application memory usage to detect leaks and optimize allocation".to_string(),
|
||||
Default::default(),
|
||||
vec![
|
||||
"valgrind".to_string(),
|
||||
"massif".to_string(),
|
||||
"heaptrack".to_string(),
|
||||
],
|
||||
)?;
|
||||
println!("✓ Created skill: Memory Profiling (ID: {})", skill1);
|
||||
|
||||
let skill2 = db.create_skill(
|
||||
"Async Programming".to_string(),
|
||||
"Write asynchronous code using promises, async/await, or futures".to_string(),
|
||||
Default::default(),
|
||||
vec![
|
||||
"Promise.all()".to_string(),
|
||||
"async/await".to_string(),
|
||||
"tokio".to_string(),
|
||||
],
|
||||
)?;
|
||||
println!("✓ Created skill: Async Programming (ID: {})", skill2);
|
||||
|
||||
let skill3 = db.create_skill(
|
||||
"Performance Optimization".to_string(),
|
||||
"Profile and optimize code performance using profilers and benchmarks".to_string(),
|
||||
Default::default(),
|
||||
vec![
|
||||
"perf".to_string(),
|
||||
"criterion".to_string(),
|
||||
"flamegraph".to_string(),
|
||||
],
|
||||
)?;
|
||||
println!(
|
||||
"✓ Created skill: Performance Optimization (ID: {})\n",
|
||||
skill3
|
||||
);
|
||||
|
||||
// Search episodes
|
||||
println!("--- Searching Episodes ---");
|
||||
let query = "memory problems in programming";
|
||||
println!("Query: \"{}\"", query);
|
||||
|
||||
let episodes = db.retrieve_similar_episodes(query, 3)?;
|
||||
println!("Found {} similar episodes:\n", episodes.len());
|
||||
|
||||
for (i, episode) in episodes.iter().enumerate() {
|
||||
println!("{}. Task: {}", i + 1, episode.task);
|
||||
println!(" Critique: {}", episode.critique);
|
||||
println!(" Actions: {}", episode.actions.join(" → "));
|
||||
println!();
|
||||
}
|
||||
|
||||
if use_api {
|
||||
println!("ℹ️ With OpenAI embeddings, results are semantically similar!");
|
||||
println!(" 'memory problems' should match 'Rust borrow checker' and 'memory profiling'");
|
||||
} else {
|
||||
println!("⚠️ Hash-based embeddings are NOT semantic!");
|
||||
println!(" Results are based on character overlap, not meaning.");
|
||||
println!(" Set OPENAI_API_KEY to see real semantic search.");
|
||||
}
|
||||
|
||||
// Search skills
|
||||
println!("\n--- Searching Skills ---");
|
||||
let query = "handling asynchronous operations";
|
||||
println!("Query: \"{}\"", query);
|
||||
|
||||
let skills = db.search_skills(query, 3)?;
|
||||
println!("Found {} similar skills:\n", skills.len());
|
||||
|
||||
for (i, skill) in skills.iter().enumerate() {
|
||||
println!("{}. {}", i + 1, skill.name);
|
||||
println!(" Description: {}", skill.description);
|
||||
println!(" Examples: {}", skill.examples.join(", "));
|
||||
println!();
|
||||
}
|
||||
|
||||
println!("=== Example Complete ===");
|
||||
println!("\nTips:");
|
||||
println!("- Use hash-based embeddings for testing/development");
|
||||
println!("- Use API embeddings (OpenAI, Cohere, Voyage) for production");
|
||||
println!("- Implement ONNX provider for offline/edge deployment");
|
||||
println!("- See docs/EMBEDDINGS.md for full guide");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
264
crates/ruvector-core/examples/neon_benchmark.rs
Normal file
264
crates/ruvector-core/examples/neon_benchmark.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
//! Quick benchmark to compare NEON SIMD vs scalar performance on Apple Silicon
|
||||
//!
|
||||
//! Run with: cargo run --example neon_benchmark --release -p ruvector-core
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
fn main() {
|
||||
println!("╔════════════════════════════════════════════════════════════╗");
|
||||
println!("║ NEON SIMD Benchmark for Apple Silicon (M4 Pro) ║");
|
||||
println!("╚════════════════════════════════════════════════════════════╝\n");
|
||||
|
||||
// Test parameters
|
||||
let dimensions = 128; // Common embedding dimension
|
||||
let num_vectors = 10_000;
|
||||
let num_queries = 1_000;
|
||||
|
||||
// Generate test data
|
||||
let vectors: Vec<Vec<f32>> = (0..num_vectors)
|
||||
.map(|i| {
|
||||
(0..dimensions)
|
||||
.map(|j| ((i * j) % 1000) as f32 / 1000.0)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let queries: Vec<Vec<f32>> = (0..num_queries)
|
||||
.map(|i| {
|
||||
(0..dimensions)
|
||||
.map(|j| ((i * j + 500) % 1000) as f32 / 1000.0)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
println!("Configuration:");
|
||||
println!(" - Dimensions: {}", dimensions);
|
||||
println!(" - Vectors: {}", num_vectors);
|
||||
println!(" - Queries: {}", num_queries);
|
||||
println!(
|
||||
" - Total distance calculations: {}\n",
|
||||
num_vectors * num_queries
|
||||
);
|
||||
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
println!("Platform: ARM64 (Apple Silicon) - NEON enabled ✓\n");
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
println!("Platform: x86_64 - AVX2 detection enabled\n");
|
||||
|
||||
// Benchmark Euclidean distance (SIMD)
|
||||
println!("═══════════════════════════════════════════════════════════════");
|
||||
println!("Euclidean Distance:");
|
||||
println!("═══════════════════════════════════════════════════════════════");
|
||||
|
||||
let start = Instant::now();
|
||||
let mut simd_sum = 0.0f32;
|
||||
for query in &queries {
|
||||
for vec in &vectors {
|
||||
simd_sum += euclidean_simd(query, vec);
|
||||
}
|
||||
}
|
||||
let simd_time = start.elapsed();
|
||||
println!(
|
||||
" SIMD: {:>8.2} ms (checksum: {:.4})",
|
||||
simd_time.as_secs_f64() * 1000.0,
|
||||
simd_sum
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut scalar_sum = 0.0f32;
|
||||
for query in &queries {
|
||||
for vec in &vectors {
|
||||
scalar_sum += euclidean_scalar(query, vec);
|
||||
}
|
||||
}
|
||||
let scalar_time = start.elapsed();
|
||||
println!(
|
||||
" Scalar: {:>8.2} ms (checksum: {:.4})",
|
||||
scalar_time.as_secs_f64() * 1000.0,
|
||||
scalar_sum
|
||||
);
|
||||
|
||||
let speedup = scalar_time.as_secs_f64() / simd_time.as_secs_f64();
|
||||
println!(" Speedup: {:.2}x\n", speedup);
|
||||
|
||||
// Benchmark Dot Product (SIMD)
|
||||
println!("═══════════════════════════════════════════════════════════════");
|
||||
println!("Dot Product:");
|
||||
println!("═══════════════════════════════════════════════════════════════");
|
||||
|
||||
let start = Instant::now();
|
||||
let mut simd_sum = 0.0f32;
|
||||
for query in &queries {
|
||||
for vec in &vectors {
|
||||
simd_sum += dot_simd(query, vec);
|
||||
}
|
||||
}
|
||||
let simd_time = start.elapsed();
|
||||
println!(
|
||||
" SIMD: {:>8.2} ms (checksum: {:.4})",
|
||||
simd_time.as_secs_f64() * 1000.0,
|
||||
simd_sum
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut scalar_sum = 0.0f32;
|
||||
for query in &queries {
|
||||
for vec in &vectors {
|
||||
scalar_sum += dot_scalar(query, vec);
|
||||
}
|
||||
}
|
||||
let scalar_time = start.elapsed();
|
||||
println!(
|
||||
" Scalar: {:>8.2} ms (checksum: {:.4})",
|
||||
scalar_time.as_secs_f64() * 1000.0,
|
||||
scalar_sum
|
||||
);
|
||||
|
||||
let speedup = scalar_time.as_secs_f64() / simd_time.as_secs_f64();
|
||||
println!(" Speedup: {:.2}x\n", speedup);
|
||||
|
||||
// Benchmark Cosine Similarity (SIMD)
|
||||
println!("═══════════════════════════════════════════════════════════════");
|
||||
println!("Cosine Similarity:");
|
||||
println!("═══════════════════════════════════════════════════════════════");
|
||||
|
||||
let start = Instant::now();
|
||||
let mut simd_sum = 0.0f32;
|
||||
for query in &queries {
|
||||
for vec in &vectors {
|
||||
simd_sum += cosine_simd(query, vec);
|
||||
}
|
||||
}
|
||||
let simd_time = start.elapsed();
|
||||
println!(
|
||||
" SIMD: {:>8.2} ms (checksum: {:.4})",
|
||||
simd_time.as_secs_f64() * 1000.0,
|
||||
simd_sum
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut scalar_sum = 0.0f32;
|
||||
for query in &queries {
|
||||
for vec in &vectors {
|
||||
scalar_sum += cosine_scalar(query, vec);
|
||||
}
|
||||
}
|
||||
let scalar_time = start.elapsed();
|
||||
println!(
|
||||
" Scalar: {:>8.2} ms (checksum: {:.4})",
|
||||
scalar_time.as_secs_f64() * 1000.0,
|
||||
scalar_sum
|
||||
);
|
||||
|
||||
let speedup = scalar_time.as_secs_f64() / simd_time.as_secs_f64();
|
||||
println!(" Speedup: {:.2}x\n", speedup);
|
||||
|
||||
println!("═══════════════════════════════════════════════════════════════");
|
||||
println!("Benchmark complete!");
|
||||
}
|
||||
|
||||
// SIMD implementations (use the crate's SIMD functions)
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
use std::arch::aarch64::*;
|
||||
|
||||
#[inline]
fn euclidean_simd(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean (L2) distance. On aarch64 this uses NEON intrinsics,
    // processing 4 f32 lanes per iteration with a scalar tail; on every other
    // architecture it falls through to the scalar reference implementation.
    //
    // NOTE(review): chunking is driven by a.len() only — if b is shorter than
    // a, the vector loads read out of bounds. Assumes callers always pass
    // equal-length slices; confirm at call sites.
    #[cfg(target_arch = "aarch64")]
    // SAFETY: NEON is baseline on aarch64, so the intrinsics are always
    // available; the raw-pointer loads stay in bounds provided both slices
    // hold at least a.len() elements (see NOTE above).
    unsafe {
        let len = a.len();
        let mut sum = vdupq_n_f32(0.0);
        let chunks = len / 4;
        for i in 0..chunks {
            let idx = i * 4;
            let va = vld1q_f32(a.as_ptr().add(idx));
            let vb = vld1q_f32(b.as_ptr().add(idx));
            let diff = vsubq_f32(va, vb);
            // Fused multiply-add: sum += diff * diff
            sum = vfmaq_f32(sum, diff, diff);
        }
        // Horizontal add of the 4 lanes, then a scalar tail for len % 4.
        let mut total = vaddvq_f32(sum);
        for i in (chunks * 4)..len {
            let diff = a[i] - b[i];
            total += diff * diff;
        }
        total.sqrt()
    }
    #[cfg(not(target_arch = "aarch64"))]
    euclidean_scalar(a, b)
}
|
||||
|
||||
#[inline]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Scalar reference implementation of Euclidean (L2) distance, used both
    // as the non-aarch64 fallback and as the benchmark baseline.
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc.sqrt()
}
|
||||
|
||||
#[inline]
fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
    // Dot product. NEON path on aarch64 (4 f32 lanes per iteration plus a
    // scalar tail); scalar fallback elsewhere.
    //
    // NOTE(review): iteration is bounded by a.len() only — assumes
    // b.len() >= a.len(); confirm at call sites.
    #[cfg(target_arch = "aarch64")]
    // SAFETY: NEON is baseline on aarch64; the pointer loads stay in bounds
    // as long as both slices hold at least a.len() elements.
    unsafe {
        let len = a.len();
        let mut sum = vdupq_n_f32(0.0);
        let chunks = len / 4;
        for i in 0..chunks {
            let idx = i * 4;
            let va = vld1q_f32(a.as_ptr().add(idx));
            let vb = vld1q_f32(b.as_ptr().add(idx));
            // Fused multiply-add: sum += va * vb
            sum = vfmaq_f32(sum, va, vb);
        }
        // Horizontal lane sum, then scalar tail for the remaining len % 4.
        let mut total = vaddvq_f32(sum);
        for i in (chunks * 4)..len {
            total += a[i] * b[i];
        }
        total
    }
    #[cfg(not(target_arch = "aarch64"))]
    dot_scalar(a, b)
}
|
||||
|
||||
#[inline]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Scalar reference implementation of the dot product.
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
|
||||
|
||||
#[inline]
fn cosine_simd(a: &[f32], b: &[f32]) -> f32 {
    // Cosine similarity: dot(a, b) / (|a| * |b|). NEON path accumulates the
    // dot product and both squared norms in a single pass; scalar fallback
    // elsewhere. Returns NaN when either vector has zero norm (0/0).
    //
    // NOTE(review): chunking is bounded by a.len() only — assumes
    // b.len() >= a.len(); confirm at call sites.
    #[cfg(target_arch = "aarch64")]
    // SAFETY: NEON is baseline on aarch64; pointer loads stay in bounds as
    // long as both slices hold at least a.len() elements.
    unsafe {
        let len = a.len();
        let mut dot = vdupq_n_f32(0.0);
        let mut norm_a = vdupq_n_f32(0.0);
        let mut norm_b = vdupq_n_f32(0.0);
        let chunks = len / 4;
        for i in 0..chunks {
            let idx = i * 4;
            let va = vld1q_f32(a.as_ptr().add(idx));
            let vb = vld1q_f32(b.as_ptr().add(idx));
            // Three fused multiply-adds per 4-lane chunk.
            dot = vfmaq_f32(dot, va, vb);
            norm_a = vfmaq_f32(norm_a, va, va);
            norm_b = vfmaq_f32(norm_b, vb, vb);
        }
        // Horizontal lane sums, then scalar tail for the remaining len % 4.
        let mut dot_sum = vaddvq_f32(dot);
        let mut norm_a_sum = vaddvq_f32(norm_a);
        let mut norm_b_sum = vaddvq_f32(norm_b);
        for i in (chunks * 4)..len {
            dot_sum += a[i] * b[i];
            norm_a_sum += a[i] * a[i];
            norm_b_sum += b[i] * b[i];
        }
        dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
    }
    #[cfg(not(target_arch = "aarch64"))]
    cosine_scalar(a, b)
}
|
||||
|
||||
#[inline]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Scalar reference implementation of cosine similarity:
    // dot(a, b) / (|a| * |b|). Yields NaN for a zero-norm input, matching
    // the SIMD variant.
    let mut dot = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
    }
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}
|
||||
545
crates/ruvector-core/src/advanced/hypergraph.rs
Normal file
545
crates/ruvector-core/src/advanced/hypergraph.rs
Normal file
@@ -0,0 +1,545 @@
|
||||
//! # Hypergraph Support for N-ary Relationships
|
||||
//!
|
||||
//! Implements hypergraph structures for representing complex multi-entity relationships
|
||||
//! beyond traditional pairwise similarity. Based on HyperGraphRAG (NeurIPS 2025) architecture.
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::types::{DistanceMetric, VectorId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Hyperedge connecting multiple vectors with description and embedding
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperedge {
    /// Unique identifier for the hyperedge (UUID v4, assigned by `new`)
    pub id: String,
    /// Vector IDs connected by this hyperedge
    pub nodes: Vec<VectorId>,
    /// Natural language description of the relationship
    pub description: String,
    /// Embedding of the hyperedge description
    pub embedding: Vec<f32>,
    /// Confidence weight (0.0-1.0, clamped on construction)
    pub confidence: f32,
    /// Optional metadata
    pub metadata: HashMap<String, String>,
}

/// Temporal hyperedge with time attributes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalHyperedge {
    /// Base hyperedge
    pub hyperedge: Hyperedge,
    /// Creation timestamp (Unix epoch seconds)
    pub timestamp: u64,
    /// Optional expiration timestamp (Unix epoch seconds); None = never expires
    pub expires_at: Option<u64>,
    /// Temporal context (hourly, daily, monthly)
    pub granularity: TemporalGranularity,
}

/// Temporal granularity for indexing
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TemporalGranularity {
    /// 3600-second buckets
    Hourly,
    /// 86400-second buckets
    Daily,
    /// 30-day buckets (approximate months)
    Monthly,
    /// 365-day buckets (approximate years)
    Yearly,
}

impl Hyperedge {
    /// Create a new hyperedge
    ///
    /// Assigns a fresh UUID v4 id and clamps `confidence` into [0.0, 1.0].
    pub fn new(
        nodes: Vec<VectorId>,
        description: String,
        embedding: Vec<f32>,
        confidence: f32,
    ) -> Self {
        Self {
            id: Uuid::new_v4().to_string(),
            nodes,
            description,
            embedding,
            confidence: confidence.clamp(0.0, 1.0),
            metadata: HashMap::new(),
        }
    }

    /// Get hyperedge order (number of nodes)
    pub fn order(&self) -> usize {
        self.nodes.len()
    }

    /// Check if hyperedge contains a specific node (linear scan over nodes)
    pub fn contains_node(&self, node: &VectorId) -> bool {
        self.nodes.contains(node)
    }
}

impl TemporalHyperedge {
    /// Create a new temporal hyperedge with current timestamp
    ///
    /// # Panics
    /// Panics only if the system clock reports a time before the Unix epoch.
    pub fn new(hyperedge: Hyperedge, granularity: TemporalGranularity) -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Self {
            hyperedge,
            timestamp,
            expires_at: None,
            granularity,
        }
    }

    /// Check if hyperedge is expired
    ///
    /// Edges without an `expires_at` never expire.
    pub fn is_expired(&self) -> bool {
        if let Some(expires_at) = self.expires_at {
            let now = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs();
            now > expires_at
        } else {
            false
        }
    }

    /// Get time bucket for indexing
    ///
    /// The bucket is the creation timestamp integer-divided by the
    /// granularity period (months approximated as 30 days, years as 365).
    pub fn time_bucket(&self) -> u64 {
        match self.granularity {
            TemporalGranularity::Hourly => self.timestamp / 3600,
            TemporalGranularity::Daily => self.timestamp / 86400,
            TemporalGranularity::Monthly => self.timestamp / (86400 * 30),
            TemporalGranularity::Yearly => self.timestamp / (86400 * 365),
        }
    }
}
|
||||
|
||||
/// Hypergraph index with bipartite graph storage
///
/// Entities and hyperedges form the two sides of a bipartite incidence
/// graph, stored redundantly in both directions so traversal can expand
/// either side with a single map lookup.
pub struct HypergraphIndex {
    /// Entity nodes
    entities: HashMap<VectorId, Vec<f32>>,
    /// Hyperedges
    hyperedges: HashMap<String, Hyperedge>,
    /// Temporal hyperedges indexed by time bucket
    temporal_index: HashMap<u64, Vec<String>>,
    /// Bipartite graph: entity -> hyperedge IDs
    entity_to_hyperedges: HashMap<VectorId, HashSet<String>>,
    /// Bipartite graph: hyperedge -> entity IDs
    hyperedge_to_entities: HashMap<String, HashSet<VectorId>>,
    /// Distance metric for embeddings
    distance_metric: DistanceMetric,
}
|
||||
|
||||
impl HypergraphIndex {
|
||||
/// Create a new hypergraph index
|
||||
pub fn new(distance_metric: DistanceMetric) -> Self {
|
||||
Self {
|
||||
entities: HashMap::new(),
|
||||
hyperedges: HashMap::new(),
|
||||
temporal_index: HashMap::new(),
|
||||
entity_to_hyperedges: HashMap::new(),
|
||||
hyperedge_to_entities: HashMap::new(),
|
||||
distance_metric,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an entity node
|
||||
pub fn add_entity(&mut self, id: VectorId, embedding: Vec<f32>) {
|
||||
self.entities.insert(id.clone(), embedding);
|
||||
self.entity_to_hyperedges.entry(id).or_default();
|
||||
}
|
||||
|
||||
/// Add a hyperedge
|
||||
pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) -> Result<()> {
|
||||
let edge_id = hyperedge.id.clone();
|
||||
|
||||
// Verify all nodes exist
|
||||
for node in &hyperedge.nodes {
|
||||
if !self.entities.contains_key(node) {
|
||||
return Err(RuvectorError::InvalidInput(format!(
|
||||
"Entity {} not found in hypergraph",
|
||||
node
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Update bipartite graph
|
||||
for node in &hyperedge.nodes {
|
||||
self.entity_to_hyperedges
|
||||
.entry(node.clone())
|
||||
.or_default()
|
||||
.insert(edge_id.clone());
|
||||
}
|
||||
|
||||
let nodes_set: HashSet<VectorId> = hyperedge.nodes.iter().cloned().collect();
|
||||
self.hyperedge_to_entities
|
||||
.insert(edge_id.clone(), nodes_set);
|
||||
|
||||
self.hyperedges.insert(edge_id, hyperedge);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a temporal hyperedge
|
||||
pub fn add_temporal_hyperedge(&mut self, temporal_edge: TemporalHyperedge) -> Result<()> {
|
||||
let bucket = temporal_edge.time_bucket();
|
||||
let edge_id = temporal_edge.hyperedge.id.clone();
|
||||
|
||||
self.add_hyperedge(temporal_edge.hyperedge)?;
|
||||
|
||||
self.temporal_index.entry(bucket).or_default().push(edge_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Search hyperedges by embedding similarity
|
||||
pub fn search_hyperedges(&self, query_embedding: &[f32], k: usize) -> Vec<(String, f32)> {
|
||||
let mut results: Vec<(String, f32)> = self
|
||||
.hyperedges
|
||||
.iter()
|
||||
.map(|(id, edge)| {
|
||||
let distance = self.compute_distance(query_embedding, &edge.embedding);
|
||||
(id.clone(), distance)
|
||||
})
|
||||
.collect();
|
||||
|
||||
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
results.truncate(k);
|
||||
results
|
||||
}
|
||||
|
||||
/// Get k-hop neighbors in hypergraph
|
||||
/// Returns all nodes reachable within k hops from the start node
|
||||
pub fn k_hop_neighbors(&self, start_node: VectorId, k: usize) -> HashSet<VectorId> {
|
||||
let mut visited = HashSet::new();
|
||||
let mut current_layer = HashSet::new();
|
||||
current_layer.insert(start_node.clone());
|
||||
visited.insert(start_node); // Start node is at distance 0
|
||||
|
||||
for _hop in 0..k {
|
||||
let mut next_layer = HashSet::new();
|
||||
|
||||
for node in current_layer.iter() {
|
||||
// Get all hyperedges containing this node
|
||||
if let Some(hyperedges) = self.entity_to_hyperedges.get(node) {
|
||||
for edge_id in hyperedges {
|
||||
// Get all nodes in this hyperedge
|
||||
if let Some(nodes) = self.hyperedge_to_entities.get(edge_id) {
|
||||
for neighbor in nodes.iter() {
|
||||
if !visited.contains(neighbor) {
|
||||
visited.insert(neighbor.clone());
|
||||
next_layer.insert(neighbor.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if next_layer.is_empty() {
|
||||
break;
|
||||
}
|
||||
current_layer = next_layer;
|
||||
}
|
||||
|
||||
visited
|
||||
}
|
||||
|
||||
/// Query temporal hyperedges in a time range
|
||||
pub fn query_temporal_range(&self, start_bucket: u64, end_bucket: u64) -> Vec<String> {
|
||||
let mut results = Vec::new();
|
||||
for bucket in start_bucket..=end_bucket {
|
||||
if let Some(edges) = self.temporal_index.get(&bucket) {
|
||||
results.extend(edges.iter().cloned());
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
/// Get hyperedge by ID
|
||||
pub fn get_hyperedge(&self, id: &str) -> Option<&Hyperedge> {
|
||||
self.hyperedges.get(id)
|
||||
}
|
||||
|
||||
/// Get statistics
|
||||
pub fn stats(&self) -> HypergraphStats {
|
||||
let total_edges = self.hyperedges.len();
|
||||
let total_entities = self.entities.len();
|
||||
let avg_degree = if total_entities > 0 {
|
||||
self.entity_to_hyperedges
|
||||
.values()
|
||||
.map(|edges| edges.len())
|
||||
.sum::<usize>() as f32
|
||||
/ total_entities as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
HypergraphStats {
|
||||
total_entities,
|
||||
total_hyperedges: total_edges,
|
||||
avg_entity_degree: avg_degree,
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
crate::distance::distance(a, b, self.distance_metric).unwrap_or(f32::MAX)
|
||||
}
|
||||
}
|
||||
|
||||
/// Hypergraph statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HypergraphStats {
    /// Number of entity nodes in the index.
    pub total_entities: usize,
    /// Number of hyperedges in the index.
    pub total_hyperedges: usize,
    /// Mean number of incident hyperedges per entity.
    pub avg_entity_degree: f32,
}

/// Causal hypergraph memory for agent reasoning
///
/// Layers causal bookkeeping (success counts, latency estimates) on top of a
/// hypergraph index and ranks hyperedges with a weighted utility function
/// U = alpha * similarity + beta * causal_uplift - gamma * latency.
pub struct CausalMemory {
    /// Hypergraph index
    index: HypergraphIndex,
    /// Causal relationship tracking: (cause_id, effect_id) -> success_count
    causal_counts: HashMap<(VectorId, VectorId), u32>,
    /// Action latencies: action_id -> avg_latency_ms
    latencies: HashMap<VectorId, f32>,
    /// Utility function weights
    alpha: f32, // similarity weight
    beta: f32,  // causal uplift weight
    gamma: f32, // latency penalty weight
}
|
||||
|
||||
impl CausalMemory {
|
||||
/// Create a new causal memory with default utility weights
|
||||
pub fn new(distance_metric: DistanceMetric) -> Self {
|
||||
Self {
|
||||
index: HypergraphIndex::new(distance_metric),
|
||||
causal_counts: HashMap::new(),
|
||||
latencies: HashMap::new(),
|
||||
alpha: 0.7,
|
||||
beta: 0.2,
|
||||
gamma: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set custom utility function weights
|
||||
pub fn with_weights(mut self, alpha: f32, beta: f32, gamma: f32) -> Self {
|
||||
self.alpha = alpha;
|
||||
self.beta = beta;
|
||||
self.gamma = gamma;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a causal relationship
|
||||
pub fn add_causal_edge(
|
||||
&mut self,
|
||||
cause: VectorId,
|
||||
effect: VectorId,
|
||||
context: Vec<VectorId>,
|
||||
description: String,
|
||||
embedding: Vec<f32>,
|
||||
latency_ms: f32,
|
||||
) -> Result<()> {
|
||||
// Create hyperedge connecting cause, effect, and context
|
||||
let mut nodes = vec![cause.clone(), effect.clone()];
|
||||
nodes.extend(context);
|
||||
|
||||
let hyperedge = Hyperedge::new(nodes, description, embedding, 1.0);
|
||||
self.index.add_hyperedge(hyperedge)?;
|
||||
|
||||
// Update causal counts
|
||||
*self
|
||||
.causal_counts
|
||||
.entry((cause.clone(), effect.clone()))
|
||||
.or_insert(0) += 1;
|
||||
|
||||
// Update latency
|
||||
let entry = self.latencies.entry(cause).or_insert(0.0);
|
||||
*entry = (*entry + latency_ms) / 2.0; // Running average
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Query with utility function: U = α·similarity + β·causal_uplift - γ·latency
|
||||
pub fn query_with_utility(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
action_id: VectorId,
|
||||
k: usize,
|
||||
) -> Vec<(String, f32)> {
|
||||
let mut results: Vec<(String, f32)> = self
|
||||
.index
|
||||
.hyperedges
|
||||
.iter()
|
||||
.filter(|(_, edge)| edge.contains_node(&action_id))
|
||||
.map(|(id, edge)| {
|
||||
let similarity = 1.0
|
||||
- self
|
||||
.index
|
||||
.compute_distance(query_embedding, &edge.embedding);
|
||||
let causal_uplift = self.compute_causal_uplift(&edge.nodes);
|
||||
let latency = self.latencies.get(&action_id).copied().unwrap_or(0.0);
|
||||
|
||||
let utility = self.alpha * similarity + self.beta * causal_uplift
|
||||
- self.gamma * (latency / 1000.0); // Normalize latency to 0-1 range
|
||||
|
||||
(id.clone(), utility)
|
||||
})
|
||||
.collect();
|
||||
|
||||
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); // Sort by utility descending
|
||||
results.truncate(k);
|
||||
results
|
||||
}
|
||||
|
||||
fn compute_causal_uplift(&self, nodes: &[VectorId]) -> f32 {
|
||||
if nodes.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute average causal strength for pairs in this hyperedge
|
||||
let mut total_uplift = 0.0;
|
||||
let mut count = 0;
|
||||
|
||||
for i in 0..nodes.len() - 1 {
|
||||
for j in i + 1..nodes.len() {
|
||||
if let Some(&success_count) = self
|
||||
.causal_counts
|
||||
.get(&(nodes[i].clone(), nodes[j].clone()))
|
||||
{
|
||||
total_uplift += (success_count as f32).ln_1p(); // Log scale
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
total_uplift / count as f32
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Get hypergraph index
|
||||
pub fn index(&self) -> &HypergraphIndex {
|
||||
&self.index
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Construction stores fields verbatim; order() is the node count and
    // contains_node() distinguishes members from non-members.
    #[test]
    fn test_hyperedge_creation() {
        let nodes = vec!["1".to_string(), "2".to_string(), "3".to_string()];
        let desc = "Test relationship".to_string();
        let embedding = vec![0.1, 0.2, 0.3];
        let edge = Hyperedge::new(nodes, desc, embedding, 0.95);

        assert_eq!(edge.order(), 3);
        assert!(edge.contains_node(&"1".to_string()));
        assert!(!edge.contains_node(&"4".to_string()));
        assert_eq!(edge.confidence, 0.95);
    }

    // A freshly created temporal edge has no expiry and lands in a
    // positive hourly time bucket.
    #[test]
    fn test_temporal_hyperedge() {
        let nodes = vec!["1".to_string(), "2".to_string()];
        let desc = "Temporal relationship".to_string();
        let embedding = vec![0.1, 0.2];
        let edge = Hyperedge::new(nodes, desc, embedding, 1.0);

        let temporal = TemporalHyperedge::new(edge, TemporalGranularity::Hourly);

        assert!(!temporal.is_expired());
        assert!(temporal.time_bucket() > 0);
    }

    // Entity and hyperedge insertion are reflected in stats().
    #[test]
    fn test_hypergraph_index() {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

        // Add entities
        index.add_entity("1".to_string(), vec![1.0, 0.0, 0.0]);
        index.add_entity("2".to_string(), vec![0.0, 1.0, 0.0]);
        index.add_entity("3".to_string(), vec![0.0, 0.0, 1.0]);

        // Add hyperedge
        let edge = Hyperedge::new(
            vec!["1".to_string(), "2".to_string(), "3".to_string()],
            "Triple relationship".to_string(),
            vec![0.5, 0.5, 0.5],
            0.9,
        );
        index.add_hyperedge(edge).unwrap();

        let stats = index.stats();
        assert_eq!(stats.total_entities, 3);
        assert_eq!(stats.total_hyperedges, 1);
    }

    // Chain 1-2, 2-3, 3-4: within 2 hops of "1" we reach up to "3"
    // (the start node itself is always included).
    #[test]
    fn test_k_hop_neighbors() {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

        // Create a small hypergraph
        index.add_entity("1".to_string(), vec![1.0]);
        index.add_entity("2".to_string(), vec![1.0]);
        index.add_entity("3".to_string(), vec![1.0]);
        index.add_entity("4".to_string(), vec![1.0]);

        let edge1 = Hyperedge::new(
            vec!["1".to_string(), "2".to_string()],
            "e1".to_string(),
            vec![1.0],
            1.0,
        );
        let edge2 = Hyperedge::new(
            vec!["2".to_string(), "3".to_string()],
            "e2".to_string(),
            vec![1.0],
            1.0,
        );
        let edge3 = Hyperedge::new(
            vec!["3".to_string(), "4".to_string()],
            "e3".to_string(),
            vec![1.0],
            1.0,
        );

        index.add_hyperedge(edge1).unwrap();
        index.add_hyperedge(edge2).unwrap();
        index.add_hyperedge(edge3).unwrap();

        let neighbors = index.k_hop_neighbors("1".to_string(), 2);
        assert!(neighbors.contains(&"1".to_string()));
        assert!(neighbors.contains(&"2".to_string()));
        assert!(neighbors.contains(&"3".to_string()));
    }

    // A recorded causal edge containing the queried action is returned by
    // the utility-ranked query.
    #[test]
    fn test_causal_memory() {
        let mut memory = CausalMemory::new(DistanceMetric::Cosine);

        memory.index.add_entity("1".to_string(), vec![1.0, 0.0]);
        memory.index.add_entity("2".to_string(), vec![0.0, 1.0]);

        memory
            .add_causal_edge(
                "1".to_string(),
                "2".to_string(),
                vec![],
                "Action 1 causes effect 2".to_string(),
                vec![0.5, 0.5],
                100.0,
            )
            .unwrap();

        let results = memory.query_with_utility(&[0.6, 0.4], "1".to_string(), 5);
        assert!(!results.is_empty());
    }
}
|
||||
441
crates/ruvector-core/src/advanced/learned_index.rs
Normal file
441
crates/ruvector-core/src/advanced/learned_index.rs
Normal file
@@ -0,0 +1,441 @@
|
||||
//! # Learned Index Structures
|
||||
//!
|
||||
//! Experimental learned indexes using neural networks to approximate data distribution.
|
||||
//! Based on Recursive Model Index (RMI) concept with bounded error correction.
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::types::VectorId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Trait for learned index structures
///
/// A learned index approximates the key -> position mapping with a model
/// and corrects mispredictions with a bounded local search.
pub trait LearnedIndex {
    /// Predict position for a key
    ///
    /// # Errors
    /// Implementations may fail on dimension mismatch or an unbuilt index.
    fn predict(&self, key: &[f32]) -> Result<usize>;

    /// Insert a key-value pair
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()>;

    /// Search for a key; `Ok(None)` when not found.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>>;

    /// Get index statistics
    fn stats(&self) -> IndexStats;
}

/// Statistics for learned indexes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Number of indexed entries.
    pub total_entries: usize,
    /// Approximate in-memory size of the models (see implementations for
    /// what is and is not counted).
    pub model_size_bytes: usize,
    /// Mean absolute prediction error, in positions.
    pub avg_error: f32,
    /// Largest observed prediction error, in positions.
    pub max_error: usize,
}

/// Simple linear model for CDF approximation
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    /// One weight per key dimension (train_simple only fits index 0).
    weights: Vec<f32>,
    /// Intercept term.
    bias: f32,
}
|
||||
|
||||
impl LinearModel {
|
||||
fn new(dimensions: usize) -> Self {
|
||||
Self {
|
||||
weights: vec![0.0; dimensions],
|
||||
bias: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
fn predict(&self, input: &[f32]) -> f32 {
|
||||
let mut result = self.bias;
|
||||
for (w, x) in self.weights.iter().zip(input.iter()) {
|
||||
result += w * x;
|
||||
}
|
||||
result.max(0.0)
|
||||
}
|
||||
|
||||
fn train_simple(&mut self, data: &[(Vec<f32>, usize)]) {
|
||||
if data.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Simple least squares approximation
|
||||
let n = data.len() as f32;
|
||||
let dim = self.weights.len();
|
||||
|
||||
// Reset weights
|
||||
self.weights.fill(0.0);
|
||||
self.bias = 0.0;
|
||||
|
||||
// Compute means
|
||||
let mut mean_x = vec![0.0; dim];
|
||||
let mut mean_y = 0.0;
|
||||
|
||||
for (x, y) in data {
|
||||
for (i, &val) in x.iter().enumerate() {
|
||||
mean_x[i] += val;
|
||||
}
|
||||
mean_y += *y as f32;
|
||||
}
|
||||
|
||||
for val in mean_x.iter_mut() {
|
||||
*val /= n;
|
||||
}
|
||||
mean_y /= n;
|
||||
|
||||
// Simple linear regression for first dimension
|
||||
if dim > 0 {
|
||||
let mut numerator = 0.0;
|
||||
let mut denominator = 0.0;
|
||||
|
||||
for (x, y) in data {
|
||||
let x_diff = x[0] - mean_x[0];
|
||||
let y_diff = *y as f32 - mean_y;
|
||||
numerator += x_diff * y_diff;
|
||||
denominator += x_diff * x_diff;
|
||||
}
|
||||
|
||||
if denominator.abs() > 1e-10 {
|
||||
self.weights[0] = numerator / denominator;
|
||||
}
|
||||
self.bias = mean_y - self.weights[0] * mean_x[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursive Model Index (RMI)
/// Multi-stage neural models making coarse-then-fine predictions
pub struct RecursiveModelIndex {
    /// Root model for coarse prediction (maps a key to a leaf-model index)
    root_model: LinearModel,
    /// Second-level models for fine prediction (absolute data positions)
    leaf_models: Vec<LinearModel>,
    /// Sorted data with error correction (sorted by first key component)
    data: Vec<(Vec<f32>, VectorId)>,
    /// Error bound for the bounded-scan fallback around a prediction
    max_error: usize,
    /// Dimensions of vectors
    dimensions: usize,
}
|
||||
|
||||
impl RecursiveModelIndex {
    /// Create a new RMI with specified number of leaf models
    pub fn new(dimensions: usize, num_leaf_models: usize) -> Self {
        let leaf_models = (0..num_leaf_models)
            .map(|_| LinearModel::new(dimensions))
            .collect();

        Self {
            root_model: LinearModel::new(dimensions),
            leaf_models,
            data: Vec::new(),
            // Fixed search window; not derived from observed training error.
            max_error: 100,
            dimensions,
        }
    }

    /// Build the index from data
    ///
    /// Sorts by the first key component, trains the root model to map keys
    /// to leaf-model indices, then trains each leaf model to map the keys
    /// in its contiguous chunk to absolute positions.
    ///
    /// # Errors
    /// Rejects empty data, zero-dimensional keys, and a zero leaf count.
    pub fn build(&mut self, mut data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
        if data.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Cannot build index from empty data".into(),
            ));
        }

        if data[0].0.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Cannot build index from vectors with zero dimensions".into(),
            ));
        }

        if self.leaf_models.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Cannot build index with zero leaf models".into(),
            ));
        }

        // Sort data by first dimension (simple heuristic); NaNs compare
        // Equal, so their relative order is unspecified.
        data.sort_by(|a, b| {
            a.0[0]
                .partial_cmp(&b.0[0])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let n = data.len();

        // Train root model to predict leaf model index
        let root_training_data: Vec<(Vec<f32>, usize)> = data
            .iter()
            .enumerate()
            .map(|(i, (key, _))| {
                let leaf_idx = (i * self.leaf_models.len()) / n;
                (key.clone(), leaf_idx)
            })
            .collect();

        self.root_model.train_simple(&root_training_data);

        // Train each leaf model on its chunk; targets are absolute offsets
        // into the sorted data.
        // NOTE: when n < num_leaf_models, chunk_size is 0, so every chunk
        // except the last is empty (train_simple no-ops on empty data) and
        // the last leaf covers all of the data.
        let num_leaf_models = self.leaf_models.len();
        let chunk_size = n / num_leaf_models;
        for (i, model) in self.leaf_models.iter_mut().enumerate() {
            let start = i * chunk_size;
            let end = if i == num_leaf_models - 1 {
                n
            } else {
                (i + 1) * chunk_size
            };

            if start < n {
                let leaf_data: Vec<(Vec<f32>, usize)> = data[start..end.min(n)]
                    .iter()
                    .enumerate()
                    .map(|(j, (key, _))| (key.clone(), start + j))
                    .collect();

                model.train_simple(&leaf_data);
            }
        }

        self.data = data;
        Ok(())
    }
}
|
||||
|
||||
impl LearnedIndex for RecursiveModelIndex {
    /// Two-stage prediction: the root model picks a leaf model, which then
    /// predicts a position; both results are clamped into valid ranges.
    fn predict(&self, key: &[f32]) -> Result<usize> {
        if key.len() != self.dimensions {
            return Err(RuvectorError::InvalidInput(
                "Key dimensions mismatch".into(),
            ));
        }

        if self.leaf_models.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Index not built: no leaf models available".into(),
            ));
        }

        if self.data.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Index not built: no data available".into(),
            ));
        }

        // Root model predicts leaf model
        let leaf_idx = self.root_model.predict(key) as usize;
        let leaf_idx = leaf_idx.min(self.leaf_models.len() - 1);

        // Leaf model predicts position
        let pos = self.leaf_models[leaf_idx].predict(key) as usize;
        let pos = pos.min(self.data.len().saturating_sub(1));

        Ok(pos)
    }

    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
        // For simplicity, append and mark for rebuild
        // Production implementation would use incremental updates
        // NOTE(review): appended entries break the sorted invariant and may
        // be invisible to search() until a rebuild — confirm callers rebuild
        // (e.g. via HybridIndex) before relying on lookups.
        self.data.push((key, value));
        Ok(())
    }

    /// Bounded linear scan (not a binary search, despite the original
    /// comment) of the +/- max_error window around the predicted position.
    /// A key whose true position lies farther away than max_error is missed
    /// and reported as absent.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
        if self.data.is_empty() {
            return Ok(None);
        }

        let predicted_pos = self.predict(key)?;

        // Scan the error window around the predicted position.
        let start = predicted_pos.saturating_sub(self.max_error);
        let end = (predicted_pos + self.max_error).min(self.data.len());

        for i in start..end {
            if self.data[i].0 == key {
                return Ok(Some(self.data[i].1.clone()));
            }
        }

        Ok(None)
    }

    /// Full-pass diagnostics: replays predict() over every entry to measure
    /// average and maximum position error — O(n), intended for diagnostics.
    fn stats(&self) -> IndexStats {
        // NOTE: size_of_val / size_of do not count the heap-allocated weight
        // vectors, so model_size_bytes underestimates the real footprint.
        let model_size = std::mem::size_of_val(&self.root_model)
            + self.leaf_models.len() * std::mem::size_of::<LinearModel>();

        // Compute average prediction error
        let mut total_error = 0.0;
        let mut max_error = 0;

        for (i, (key, _)) in self.data.iter().enumerate() {
            if let Ok(pred_pos) = self.predict(key) {
                let error = i.abs_diff(pred_pos);
                total_error += error as f32;
                max_error = max_error.max(error);
            }
        }

        let avg_error = if !self.data.is_empty() {
            total_error / self.data.len() as f32
        } else {
            0.0
        };

        IndexStats {
            total_entries: self.data.len(),
            model_size_bytes: model_size,
            avg_error,
            max_error,
        }
    }
}
|
||||
|
||||
/// Hybrid index combining a learned index for static data with an in-memory
/// write buffer for dynamic updates. (The buffer is a plain HashMap keyed by
/// the bincode-encoded key — not HNSW, despite the original description.)
pub struct HybridIndex {
    /// Learned index for static segment
    learned: RecursiveModelIndex,
    /// Dynamic updates buffer: bincode-encoded key -> value
    dynamic_buffer: HashMap<Vec<u8>, VectorId>,
    /// Threshold for rebuilding learned index
    rebuild_threshold: usize,
}

impl HybridIndex {
    /// Create a new hybrid index
    pub fn new(dimensions: usize, num_leaf_models: usize, rebuild_threshold: usize) -> Self {
        Self {
            learned: RecursiveModelIndex::new(dimensions, num_leaf_models),
            dynamic_buffer: HashMap::new(),
            rebuild_threshold,
        }
    }

    /// Build the learned portion from static data
    pub fn build_static(&mut self, data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
        self.learned.build(data)
    }

    /// Check if rebuild is needed (buffer reached the rebuild threshold)
    pub fn needs_rebuild(&self) -> bool {
        self.dynamic_buffer.len() >= self.rebuild_threshold
    }

    /// Rebuild learned index incorporating dynamic updates
    ///
    /// # Errors
    /// Fails if a buffered key cannot be decoded or if the rebuilt data is
    /// rejected by `RecursiveModelIndex::build`; the buffer is only cleared
    /// after a successful rebuild.
    pub fn rebuild(&mut self) -> Result<()> {
        let mut all_data: Vec<(Vec<f32>, VectorId)> = self.learned.data.clone();

        for (key_bytes, value) in &self.dynamic_buffer {
            // decode_from_slice returns (decoded_value, bytes_read); only
            // the decoded key is used here.
            let (key, _): (Vec<f32>, usize) =
                bincode::decode_from_slice(key_bytes, bincode::config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
            all_data.push((key, value.clone()));
        }

        self.learned.build(all_data)?;
        self.dynamic_buffer.clear();
        Ok(())
    }

    /// Canonical byte encoding of a key for use as a buffer map key.
    // NOTE(review): unwrap_or_default() maps any encoding failure to an
    // empty Vec, which would make all failing keys collide in the buffer —
    // confirm encoding &[f32] is infallible here, or propagate the error.
    fn serialize_key(key: &[f32]) -> Vec<u8> {
        bincode::encode_to_vec(key, bincode::config::standard()).unwrap_or_default()
    }
}

impl LearnedIndex for HybridIndex {
    /// Delegate position prediction to the learned (static) segment.
    fn predict(&self, key: &[f32]) -> Result<usize> {
        self.learned.predict(key)
    }

    /// Buffer the insert; it joins the learned segment on the next rebuild().
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
        let key_bytes = Self::serialize_key(&key);
        self.dynamic_buffer.insert(key_bytes, value);
        Ok(())
    }

    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
        // Check dynamic buffer first
        let key_bytes = Self::serialize_key(key);
        if let Some(value) = self.dynamic_buffer.get(&key_bytes) {
            return Ok(Some(value.clone()));
        }

        // Fall back to learned index
        self.learned.search(key)
    }

    /// Learned-segment statistics plus the buffered (not yet rebuilt) count.
    fn stats(&self) -> IndexStats {
        let mut stats = self.learned.stats();
        stats.total_entries += self.dynamic_buffer.len();
        stats
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A perfectly linear mapping should be predictable to within the
    /// data's positional range.
    #[test]
    fn test_linear_model() {
        let mut model = LinearModel::new(2);
        let data = vec![
            (vec![0.0, 0.0], 0),
            (vec![1.0, 1.0], 10),
            (vec![2.0, 2.0], 20),
        ];

        model.train_simple(&data);

        let pred = model.predict(&[1.5, 1.5]);
        // Loose bound: the prediction only has to land near the data range.
        assert!(pred >= 0.0 && pred <= 30.0);
    }

    /// Building an RMI over a smooth curve should index every entry with
    /// bounded average prediction error.
    #[test]
    fn test_rmi_build() {
        let mut rmi = RecursiveModelIndex::new(2, 4);

        let data: Vec<(Vec<f32>, VectorId)> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                (vec![x, x * x], i.to_string())
            })
            .collect();

        rmi.build(data).unwrap();

        let stats = rmi.stats();
        assert_eq!(stats.total_entries, 100);
        assert!(stats.avg_error < 50.0); // Should have reasonable error
    }

    /// An exact key present in the build data must be found.
    #[test]
    fn test_rmi_search() {
        let mut rmi = RecursiveModelIndex::new(1, 2);

        let data = vec![
            (vec![0.0], "0".to_string()),
            (vec![0.5], "1".to_string()),
            (vec![1.0], "2".to_string()),
        ];

        rmi.build(data).unwrap();

        let result = rmi.search(&[0.5]).unwrap();
        assert_eq!(result, Some("1".to_string()));
    }

    /// Dynamic inserts must be visible immediately (via the buffer) while
    /// static entries still resolve through the learned segment.
    #[test]
    fn test_hybrid_index() {
        let mut hybrid = HybridIndex::new(1, 2, 10);

        let static_data = vec![(vec![0.0], "0".to_string()), (vec![1.0], "1".to_string())];
        hybrid.build_static(static_data).unwrap();

        // Add dynamic updates
        hybrid.insert(vec![2.0], "2".to_string()).unwrap();

        assert_eq!(hybrid.search(&[2.0]).unwrap(), Some("2".to_string()));
        assert_eq!(hybrid.search(&[0.0]).unwrap(), Some("0".to_string()));
    }
}
|
||||
17
crates/ruvector-core/src/advanced/mod.rs
Normal file
17
crates/ruvector-core/src/advanced/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
//! # Advanced Techniques
//!
//! This module contains experimental and advanced features for next-generation vector search:
//! - **Hypergraphs**: n-ary relationships beyond pairwise similarity
//! - **Learned Indexes**: Neural network-based index structures
//! - **Neural Hashing**: Similarity-preserving binary projections
//! - **Topological Data Analysis**: Embedding quality assessment

pub mod hypergraph;
pub mod learned_index;
pub mod neural_hash;
pub mod tda;

// Flat re-exports so callers can write `advanced::X` without naming the
// submodule each type lives in.
pub use hypergraph::{CausalMemory, Hyperedge, HypergraphIndex, TemporalHyperedge};
pub use learned_index::{HybridIndex, LearnedIndex, RecursiveModelIndex};
pub use neural_hash::{DeepHashEmbedding, NeuralHash};
pub use tda::{EmbeddingQuality, TopologicalAnalyzer};
||||
427
crates/ruvector-core/src/advanced/neural_hash.rs
Normal file
427
crates/ruvector-core/src/advanced/neural_hash.rs
Normal file
@@ -0,0 +1,427 @@
|
||||
//! # Neural Hash Functions
|
||||
//!
|
||||
//! Learn similarity-preserving binary projections for extreme compression.
|
||||
//! Achieves 32-128x compression with 90-95% recall preservation.
|
||||
|
||||
use crate::types::VectorId;
|
||||
use ndarray::{Array1, Array2};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Neural hash function for similarity-preserving binary codes.
///
/// Implementations in this file pack bits LSB-first: bit `i` of the code is
/// stored in byte `i / 8` at position `i % 8`.
pub trait NeuralHash {
    /// Encode a vector to a packed binary code.
    fn encode(&self, vector: &[f32]) -> Vec<u8>;

    /// Compute the Hamming distance (number of differing bits) between two
    /// codes of equal length.
    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32;

    /// Estimate a similarity score from a Hamming distance over
    /// `code_bits`-bit codes.
    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32;
}
|
||||
|
||||
/// Deep hash embedding with learned projections.
///
/// A small multi-layer perceptron whose thresholded outputs form the binary
/// hash code.
#[derive(Clone, Serialize, Deserialize)]
pub struct DeepHashEmbedding {
    /// Projection (weight) matrices for each layer, shaped (out_dim, in_dim)
    projections: Vec<Array2<f32>>,
    /// Biases for each layer
    biases: Vec<Array1<f32>>,
    /// Number of output bits (width of the final layer)
    output_bits: usize,
    /// Expected input dimensionality
    input_dims: usize,
}
|
||||
|
||||
impl DeepHashEmbedding {
|
||||
/// Create a new deep hash embedding
|
||||
pub fn new(input_dims: usize, hidden_dims: Vec<usize>, output_bits: usize) -> Self {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut projections = Vec::new();
|
||||
let mut biases = Vec::new();
|
||||
|
||||
let mut layer_dims = vec![input_dims];
|
||||
layer_dims.extend(&hidden_dims);
|
||||
layer_dims.push(output_bits);
|
||||
|
||||
// Initialize random projections (Xavier initialization)
|
||||
for i in 0..layer_dims.len() - 1 {
|
||||
let in_dim = layer_dims[i];
|
||||
let out_dim = layer_dims[i + 1];
|
||||
|
||||
let scale = (2.0 / (in_dim + out_dim) as f32).sqrt();
|
||||
let proj = Array2::from_shape_fn((out_dim, in_dim), |_| {
|
||||
rng.gen::<f32>() * 2.0 * scale - scale
|
||||
});
|
||||
|
||||
let bias = Array1::zeros(out_dim);
|
||||
|
||||
projections.push(proj);
|
||||
biases.push(bias);
|
||||
}
|
||||
|
||||
Self {
|
||||
projections,
|
||||
biases,
|
||||
output_bits,
|
||||
input_dims,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the network
|
||||
fn forward(&self, input: &[f32]) -> Vec<f32> {
|
||||
let mut activations = Array1::from_vec(input.to_vec());
|
||||
|
||||
for (proj, bias) in self.projections.iter().zip(self.biases.iter()) {
|
||||
// Linear layer: y = Wx + b
|
||||
activations = proj.dot(&activations) + bias;
|
||||
|
||||
// ReLU activation (except last layer)
|
||||
if proj.nrows() != self.output_bits {
|
||||
activations.mapv_inplace(|x| x.max(0.0));
|
||||
}
|
||||
}
|
||||
|
||||
activations.to_vec()
|
||||
}
|
||||
|
||||
/// Train on pairs of similar/dissimilar examples
|
||||
pub fn train(
|
||||
&mut self,
|
||||
positive_pairs: &[(Vec<f32>, Vec<f32>)],
|
||||
negative_pairs: &[(Vec<f32>, Vec<f32>)],
|
||||
learning_rate: f32,
|
||||
epochs: usize,
|
||||
) {
|
||||
// Simplified training with contrastive loss
|
||||
// Production would use proper backpropagation
|
||||
for _ in 0..epochs {
|
||||
// Positive pairs should have small Hamming distance
|
||||
for (a, b) in positive_pairs {
|
||||
let code_a = self.encode(a);
|
||||
let code_b = self.encode(b);
|
||||
let dist = self.hamming_distance(&code_a, &code_b);
|
||||
|
||||
// If distance is too large, update towards similarity
|
||||
if dist as f32 > self.output_bits as f32 * 0.3 {
|
||||
self.update_weights(a, b, learning_rate, true);
|
||||
}
|
||||
}
|
||||
|
||||
// Negative pairs should have large Hamming distance
|
||||
for (a, b) in negative_pairs {
|
||||
let code_a = self.encode(a);
|
||||
let code_b = self.encode(b);
|
||||
let dist = self.hamming_distance(&code_a, &code_b);
|
||||
|
||||
// If distance is too small, update towards dissimilarity
|
||||
if (dist as f32) < self.output_bits as f32 * 0.6 {
|
||||
self.update_weights(a, b, learning_rate, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_weights(&mut self, a: &[f32], b: &[f32], lr: f32, attract: bool) {
|
||||
// Simplified gradient update (production would use proper autodiff)
|
||||
let direction = if attract { 1.0 } else { -1.0 };
|
||||
|
||||
// Update only the last layer for simplicity
|
||||
if let Some(last_proj) = self.projections.last_mut() {
|
||||
let a_arr = Array1::from_vec(a.to_vec());
|
||||
let b_arr = Array1::from_vec(b.to_vec());
|
||||
|
||||
for i in 0..last_proj.nrows() {
|
||||
for j in 0..last_proj.ncols() {
|
||||
let grad = direction * lr * (a_arr[j] - b_arr[j]);
|
||||
last_proj[[i, j]] += grad * 0.001; // Small update
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get dimensions
|
||||
pub fn dimensions(&self) -> (usize, usize) {
|
||||
(self.input_dims, self.output_bits)
|
||||
}
|
||||
}
|
||||
|
||||
impl NeuralHash for DeepHashEmbedding {
    /// Binarize the network output: bit i is set iff logit i is positive.
    /// Bits are packed LSB-first per byte. A dimension mismatch yields an
    /// all-zero code instead of panicking.
    fn encode(&self, vector: &[f32]) -> Vec<u8> {
        let n_bytes = self.output_bits.div_ceil(8);
        if vector.len() != self.input_dims {
            return vec![0; n_bytes];
        }

        let mut code = vec![0u8; n_bytes];
        for (i, &logit) in self.forward(vector).iter().enumerate() {
            if logit > 0.0 {
                code[i / 8] |= 1 << (i % 8);
            }
        }

        code
    }

    /// Popcount of the XOR of the two codes.
    fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
        code_a
            .iter()
            .zip(code_b)
            .map(|(x, y)| (x ^ y).count_ones())
            .sum()
    }

    /// Map normalized Hamming distance onto a cosine-like score in [-1, 1]:
    /// identical codes -> 1.0, fully complementary codes -> -1.0.
    fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
        1.0 - 2.0 * (hamming_dist as f32 / code_bits as f32)
    }
}
|
||||
|
||||
/// Simple LSH (Locality Sensitive Hashing) baseline.
///
/// Random-hyperplane sign hashing; each row of `projections` defines one
/// hyperplane and contributes one bit of the code.
#[derive(Clone, Serialize, Deserialize)]
pub struct SimpleLSH {
    /// Random projection vectors, shaped (num_bits, input_dims).
    /// NOTE(review): components are drawn uniformly from [-1, 1), not from a
    /// Gaussian, despite what the constructor's original comment said.
    projections: Array2<f32>,
    /// Number of hash bits
    num_bits: usize,
}
|
||||
|
||||
impl SimpleLSH {
|
||||
/// Create a new LSH with random projections
|
||||
pub fn new(input_dims: usize, num_bits: usize) -> Self {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// Random Gaussian projections
|
||||
let projections =
|
||||
Array2::from_shape_fn((num_bits, input_dims), |_| rng.gen::<f32>() * 2.0 - 1.0);
|
||||
|
||||
Self {
|
||||
projections,
|
||||
num_bits,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NeuralHash for SimpleLSH {
|
||||
fn encode(&self, vector: &[f32]) -> Vec<u8> {
|
||||
let input = Array1::from_vec(vector.to_vec());
|
||||
let projections = self.projections.dot(&input);
|
||||
|
||||
let mut bits = vec![0u8; self.num_bits.div_ceil(8)];
|
||||
|
||||
for (i, &val) in projections.iter().enumerate() {
|
||||
if val > 0.0 {
|
||||
let byte_idx = i / 8;
|
||||
let bit_idx = i % 8;
|
||||
bits[byte_idx] |= 1 << bit_idx;
|
||||
}
|
||||
}
|
||||
|
||||
bits
|
||||
}
|
||||
|
||||
fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
|
||||
code_a
|
||||
.iter()
|
||||
.zip(code_b.iter())
|
||||
.map(|(a, b)| (a ^ b).count_ones())
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
|
||||
let normalized_dist = hamming_dist as f32 / code_bits as f32;
|
||||
1.0 - 2.0 * normalized_dist
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash index for fast approximate nearest neighbor search.
///
/// Buckets vectors by their binary hash code, then re-ranks candidates by
/// exact cosine similarity at query time.
pub struct HashIndex<H: NeuralHash + Clone> {
    /// Hash function used to produce binary codes
    hasher: H,
    /// Hash tables: binary code -> list of vector IDs sharing that code
    tables: HashMap<Vec<u8>, Vec<VectorId>>,
    /// Original vectors, kept for exact similarity verification
    vectors: HashMap<VectorId, Vec<f32>>,
    /// Code width in bits (should match what `hasher` produces)
    code_bits: usize,
}
|
||||
|
||||
impl<H: NeuralHash + Clone> HashIndex<H> {
|
||||
/// Create a new hash index
|
||||
pub fn new(hasher: H, code_bits: usize) -> Self {
|
||||
Self {
|
||||
hasher,
|
||||
tables: HashMap::new(),
|
||||
vectors: HashMap::new(),
|
||||
code_bits,
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a vector
|
||||
pub fn insert(&mut self, id: VectorId, vector: Vec<f32>) {
|
||||
let code = self.hasher.encode(&vector);
|
||||
|
||||
self.tables.entry(code).or_default().push(id.clone());
|
||||
|
||||
self.vectors.insert(id, vector);
|
||||
}
|
||||
|
||||
/// Search for approximate nearest neighbors
|
||||
pub fn search(&self, query: &[f32], k: usize, max_hamming: u32) -> Vec<(VectorId, f32)> {
|
||||
let query_code = self.hasher.encode(query);
|
||||
|
||||
let mut candidates = Vec::new();
|
||||
|
||||
// Find all vectors within Hamming distance threshold
|
||||
for (code, ids) in &self.tables {
|
||||
let hamming = self.hasher.hamming_distance(&query_code, code);
|
||||
|
||||
if hamming <= max_hamming {
|
||||
for id in ids {
|
||||
if let Some(vec) = self.vectors.get(id) {
|
||||
let similarity = cosine_similarity(query, vec);
|
||||
candidates.push((id.clone(), similarity));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by similarity and return top-k
|
||||
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
|
||||
candidates.truncate(k);
|
||||
candidates
|
||||
}
|
||||
|
||||
/// Get compression ratio
|
||||
pub fn compression_ratio(&self) -> f32 {
|
||||
if self.vectors.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let original_size: usize = self
|
||||
.vectors
|
||||
.values()
|
||||
.map(|v| v.len() * std::mem::size_of::<f32>())
|
||||
.sum();
|
||||
|
||||
let compressed_size = self.tables.len() * self.code_bits.div_ceil(8);
|
||||
|
||||
original_size as f32 / compressed_size as f32
|
||||
}
|
||||
|
||||
/// Get statistics
|
||||
pub fn stats(&self) -> HashIndexStats {
|
||||
let buckets = self.tables.len();
|
||||
let total_vectors = self.vectors.len();
|
||||
let avg_bucket_size = if buckets > 0 {
|
||||
total_vectors as f32 / buckets as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
HashIndexStats {
|
||||
total_vectors,
|
||||
num_buckets: buckets,
|
||||
avg_bucket_size,
|
||||
compression_ratio: self.compression_ratio(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash index statistics.
#[derive(Debug, Clone)]
pub struct HashIndexStats {
    // Number of vectors stored in the index
    pub total_vectors: usize,
    // Number of distinct hash codes (buckets)
    pub num_buckets: usize,
    // Mean vectors per bucket (0.0 when the index is empty)
    pub avg_bucket_size: f32,
    // See HashIndex::compression_ratio for how this is computed
    pub compression_ratio: f32,
}
|
||||
|
||||
/// Cosine similarity of two equal-length vectors.
///
/// Returns 0.0 when either vector has zero magnitude, where the similarity
/// is undefined.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;

    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    // Product of norms is positive exactly when both norms are positive.
    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom > 0.0 {
        dot / denom
    } else {
        0.0
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A 16-bit code must be packed into exactly 2 bytes.
    #[test]
    fn test_deep_hash_encoding() {
        let hash = DeepHashEmbedding::new(4, vec![8], 16);
        let vector = vec![0.1, 0.2, 0.3, 0.4];

        let code = hash.encode(&vector);
        assert_eq!(code.len(), 2); // 16 bits = 2 bytes
    }

    /// hamming_distance operates directly on byte codes; the network's
    /// random weights are irrelevant here.
    #[test]
    fn test_hamming_distance() {
        let hash = DeepHashEmbedding::new(2, vec![], 8);

        let code_a = vec![0b10101010];
        let code_b = vec![0b11001100];

        let dist = hash.hamming_distance(&code_a, &code_b);
        assert_eq!(dist, 4); // 4 bits differ
    }

    /// LSH encoding is deterministic once the random projections are fixed.
    #[test]
    fn test_lsh_encoding() {
        let lsh = SimpleLSH::new(4, 16);
        let vector = vec![1.0, 2.0, 3.0, 4.0];

        let code = lsh.encode(&vector);
        assert_eq!(code.len(), 2);

        // Same vector should produce same code
        let code2 = lsh.encode(&vector);
        assert_eq!(code, code2);
    }

    /// End-to-end: insert three vectors and query with a generous Hamming
    /// budget (4 of 8 bits); the query vector itself must be retrievable.
    #[test]
    fn test_hash_index() {
        let lsh = SimpleLSH::new(3, 8);
        let mut index = HashIndex::new(lsh, 8);

        // Insert vectors
        index.insert("0".to_string(), vec![1.0, 0.0, 0.0]);
        index.insert("1".to_string(), vec![0.9, 0.1, 0.0]);
        index.insert("2".to_string(), vec![0.0, 1.0, 0.0]);

        // Search
        let results = index.search(&[1.0, 0.0, 0.0], 2, 4);

        assert!(!results.is_empty());
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 3);
    }

    /// 128 f32 components (512 bytes) hashed to 32 bits (4 bytes) must
    /// report a compression ratio above 1.
    #[test]
    fn test_compression_ratio() {
        let lsh = SimpleLSH::new(128, 32); // 128D -> 32 bits
        let mut index = HashIndex::new(lsh, 32);

        for i in 0..10 {
            let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 128.0).collect();
            index.insert(i.to_string(), vec);
        }

        let ratio = index.compression_ratio();
        assert!(ratio > 1.0); // Should have compression
    }
}
|
||||
496
crates/ruvector-core/src/advanced/tda.rs
Normal file
496
crates/ruvector-core/src/advanced/tda.rs
Normal file
@@ -0,0 +1,496 @@
|
||||
//! # Topological Data Analysis (TDA)
|
||||
//!
|
||||
//! Basic topological analysis for embedding quality assessment.
|
||||
//! Detects mode collapse, degeneracy, and topological structure.
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use ndarray::Array2;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Topological analyzer for embeddings.
///
/// Derives quality metrics (mode collapse, degeneracy, connectivity) from a
/// k-nearest-neighbor graph built over the embedding set.
pub struct TopologicalAnalyzer {
    /// k for the k-nearest-neighbors graph
    k_neighbors: usize,
    /// Distance threshold for edge creation (edges longer than this are
    /// dropped even among the k nearest)
    epsilon: f32,
}
|
||||
|
||||
impl TopologicalAnalyzer {
|
||||
/// Create a new topological analyzer
|
||||
pub fn new(k_neighbors: usize, epsilon: f32) -> Self {
|
||||
Self {
|
||||
k_neighbors,
|
||||
epsilon,
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze embedding quality
|
||||
pub fn analyze(&self, embeddings: &[Vec<f32>]) -> Result<EmbeddingQuality> {
|
||||
if embeddings.is_empty() {
|
||||
return Err(RuvectorError::InvalidInput("Empty embeddings".into()));
|
||||
}
|
||||
|
||||
let n = embeddings.len();
|
||||
let dim = embeddings[0].len();
|
||||
|
||||
// Build k-NN graph
|
||||
let graph = self.build_knn_graph(embeddings);
|
||||
|
||||
// Compute topological features
|
||||
let connected_components = self.count_connected_components(&graph, n);
|
||||
let clustering_coefficient = self.compute_clustering_coefficient(&graph);
|
||||
let degree_stats = self.compute_degree_statistics(&graph, n);
|
||||
|
||||
// Detect mode collapse
|
||||
let mode_collapse_score = self.detect_mode_collapse(embeddings);
|
||||
|
||||
// Compute embedding spread
|
||||
let spread = self.compute_spread(embeddings);
|
||||
|
||||
// Detect degeneracy (vectors collapsing to a lower-dimensional manifold)
|
||||
let degeneracy_score = self.detect_degeneracy(embeddings);
|
||||
|
||||
// Compute persistence features (simplified)
|
||||
let persistence_score = self.compute_persistence(&graph, embeddings);
|
||||
|
||||
// Overall quality score (0-1, higher is better)
|
||||
let quality_score = self.compute_quality_score(
|
||||
mode_collapse_score,
|
||||
degeneracy_score,
|
||||
connected_components,
|
||||
clustering_coefficient,
|
||||
spread,
|
||||
);
|
||||
|
||||
Ok(EmbeddingQuality {
|
||||
dimensions: dim,
|
||||
num_vectors: n,
|
||||
connected_components,
|
||||
clustering_coefficient,
|
||||
avg_degree: degree_stats.0,
|
||||
degree_std: degree_stats.1,
|
||||
mode_collapse_score,
|
||||
degeneracy_score,
|
||||
spread,
|
||||
persistence_score,
|
||||
quality_score,
|
||||
})
|
||||
}
|
||||
|
||||
fn build_knn_graph(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<usize>> {
|
||||
let n = embeddings.len();
|
||||
let mut graph = vec![Vec::new(); n];
|
||||
|
||||
for i in 0..n {
|
||||
let mut distances: Vec<(usize, f32)> = (0..n)
|
||||
.filter(|&j| i != j)
|
||||
.map(|j| {
|
||||
let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
|
||||
(j, dist)
|
||||
})
|
||||
.collect();
|
||||
|
||||
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
|
||||
// Add k nearest neighbors
|
||||
for (j, dist) in distances.iter().take(self.k_neighbors) {
|
||||
if *dist <= self.epsilon {
|
||||
graph[i].push(*j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
graph
|
||||
}
|
||||
|
||||
fn count_connected_components(&self, graph: &[Vec<usize>], n: usize) -> usize {
|
||||
let mut visited = vec![false; n];
|
||||
let mut components = 0;
|
||||
|
||||
for i in 0..n {
|
||||
if !visited[i] {
|
||||
components += 1;
|
||||
self.dfs(i, graph, &mut visited);
|
||||
}
|
||||
}
|
||||
|
||||
components
|
||||
}
|
||||
|
||||
#[allow(clippy::only_used_in_recursion)]
|
||||
fn dfs(&self, node: usize, graph: &[Vec<usize>], visited: &mut [bool]) {
|
||||
visited[node] = true;
|
||||
for &neighbor in &graph[node] {
|
||||
if !visited[neighbor] {
|
||||
self.dfs(neighbor, graph, visited);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_clustering_coefficient(&self, graph: &[Vec<usize>]) -> f32 {
|
||||
let mut total_coeff = 0.0;
|
||||
let mut count = 0;
|
||||
|
||||
for neighbors in graph {
|
||||
if neighbors.len() < 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let k = neighbors.len();
|
||||
let mut triangles = 0;
|
||||
|
||||
// Count triangles
|
||||
for i in 0..k {
|
||||
for j in i + 1..k {
|
||||
let ni = neighbors[i];
|
||||
let nj = neighbors[j];
|
||||
|
||||
if graph[ni].contains(&nj) {
|
||||
triangles += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let possible_triangles = k * (k - 1) / 2;
|
||||
if possible_triangles > 0 {
|
||||
total_coeff += triangles as f32 / possible_triangles as f32;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
total_coeff / count as f32
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_degree_statistics(&self, graph: &[Vec<usize>], n: usize) -> (f32, f32) {
|
||||
let degrees: Vec<f32> = graph
|
||||
.iter()
|
||||
.map(|neighbors| neighbors.len() as f32)
|
||||
.collect();
|
||||
|
||||
let avg = degrees.iter().sum::<f32>() / n as f32;
|
||||
let variance = degrees.iter().map(|&d| (d - avg).powi(2)).sum::<f32>() / n as f32;
|
||||
let std = variance.sqrt();
|
||||
|
||||
(avg, std)
|
||||
}
|
||||
|
||||
fn detect_mode_collapse(&self, embeddings: &[Vec<f32>]) -> f32 {
|
||||
// Compute pairwise distances
|
||||
let n = embeddings.len();
|
||||
let mut distances = Vec::new();
|
||||
|
||||
for i in 0..n {
|
||||
for j in i + 1..n {
|
||||
let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
|
||||
distances.push(dist);
|
||||
}
|
||||
}
|
||||
|
||||
if distances.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Compute coefficient of variation
|
||||
let mean = distances.iter().sum::<f32>() / distances.len() as f32;
|
||||
let variance =
|
||||
distances.iter().map(|&d| (d - mean).powi(2)).sum::<f32>() / distances.len() as f32;
|
||||
let std = variance.sqrt();
|
||||
|
||||
// High CV indicates good separation, low CV indicates collapse
|
||||
let cv = if mean > 0.0 { std / mean } else { 0.0 };
|
||||
|
||||
// Normalize to 0-1, where 0 is collapsed, 1 is good
|
||||
(cv * 2.0).min(1.0)
|
||||
}
|
||||
|
||||
fn compute_spread(&self, embeddings: &[Vec<f32>]) -> f32 {
|
||||
if embeddings.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let dim = embeddings[0].len();
|
||||
|
||||
// Compute mean
|
||||
let mut mean = vec![0.0; dim];
|
||||
for emb in embeddings {
|
||||
for (i, &val) in emb.iter().enumerate() {
|
||||
mean[i] += val;
|
||||
}
|
||||
}
|
||||
for val in mean.iter_mut() {
|
||||
*val /= embeddings.len() as f32;
|
||||
}
|
||||
|
||||
// Compute average distance from mean
|
||||
let mut total_dist = 0.0;
|
||||
for emb in embeddings {
|
||||
let dist = euclidean_distance(emb, &mean);
|
||||
total_dist += dist;
|
||||
}
|
||||
|
||||
total_dist / embeddings.len() as f32
|
||||
}
|
||||
|
||||
fn detect_degeneracy(&self, embeddings: &[Vec<f32>]) -> f32 {
|
||||
if embeddings.is_empty() || embeddings[0].is_empty() {
|
||||
return 1.0; // Fully degenerate
|
||||
}
|
||||
|
||||
let n = embeddings.len();
|
||||
let dim = embeddings[0].len();
|
||||
|
||||
if n < dim {
|
||||
return 0.0; // Cannot determine
|
||||
}
|
||||
|
||||
// Compute covariance matrix
|
||||
let cov = self.compute_covariance_matrix(embeddings);
|
||||
|
||||
// Estimate rank by counting significant singular values
|
||||
let singular_values = self.approximate_singular_values(&cov);
|
||||
|
||||
let significant = singular_values.iter().filter(|&&sv| sv > 1e-6).count();
|
||||
|
||||
// Degeneracy score: 0 = full rank, 1 = rank-1 (collapsed)
|
||||
1.0 - (significant as f32 / dim as f32)
|
||||
}
|
||||
|
||||
fn compute_covariance_matrix(&self, embeddings: &[Vec<f32>]) -> Array2<f32> {
|
||||
let n = embeddings.len();
|
||||
let dim = embeddings[0].len();
|
||||
|
||||
// Compute mean
|
||||
let mut mean = vec![0.0; dim];
|
||||
for emb in embeddings {
|
||||
for (i, &val) in emb.iter().enumerate() {
|
||||
mean[i] += val;
|
||||
}
|
||||
}
|
||||
for val in mean.iter_mut() {
|
||||
*val /= n as f32;
|
||||
}
|
||||
|
||||
// Compute covariance
|
||||
let mut cov = Array2::zeros((dim, dim));
|
||||
for emb in embeddings {
|
||||
for i in 0..dim {
|
||||
for j in 0..dim {
|
||||
cov[[i, j]] += (emb[i] - mean[i]) * (emb[j] - mean[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cov.mapv(|x| x / (n - 1) as f32);
|
||||
cov
|
||||
}
|
||||
|
||||
fn approximate_singular_values(&self, matrix: &Array2<f32>) -> Vec<f32> {
|
||||
// Power iteration for largest singular values (simplified)
|
||||
let dim = matrix.nrows();
|
||||
let mut values = Vec::new();
|
||||
|
||||
// Just return diagonal for approximation
|
||||
for i in 0..dim {
|
||||
values.push(matrix[[i, i]].abs());
|
||||
}
|
||||
|
||||
values.sort_by(|a, b| b.partial_cmp(a).unwrap());
|
||||
values
|
||||
}
|
||||
|
||||
fn compute_persistence(&self, _graph: &[Vec<usize>], embeddings: &[Vec<f32>]) -> f32 {
|
||||
// Simplified persistence: measure how graph structure changes with distance threshold
|
||||
let scales = vec![0.1, 0.5, 1.0, 2.0, 5.0];
|
||||
let mut component_counts = Vec::new();
|
||||
|
||||
for &scale in &scales {
|
||||
let scaled_analyzer = TopologicalAnalyzer::new(self.k_neighbors, scale);
|
||||
let scaled_graph = scaled_analyzer.build_knn_graph(embeddings);
|
||||
let components =
|
||||
scaled_analyzer.count_connected_components(&scaled_graph, embeddings.len());
|
||||
component_counts.push(components);
|
||||
}
|
||||
|
||||
// Persistence is the variation in component count across scales
|
||||
let max_components = *component_counts.iter().max().unwrap_or(&1);
|
||||
let min_components = *component_counts.iter().min().unwrap_or(&1);
|
||||
|
||||
(max_components - min_components) as f32 / max_components as f32
|
||||
}
|
||||
|
||||
fn compute_quality_score(
|
||||
&self,
|
||||
mode_collapse: f32,
|
||||
degeneracy: f32,
|
||||
components: usize,
|
||||
clustering: f32,
|
||||
spread: f32,
|
||||
) -> f32 {
|
||||
// Weighted combination of metrics
|
||||
let collapse_score = mode_collapse; // Higher is better
|
||||
let degeneracy_score = 1.0 - degeneracy; // Lower degeneracy is better
|
||||
let component_score = if components == 1 { 1.0 } else { 0.5 }; // Single component is good
|
||||
let clustering_score = clustering; // Higher clustering is good
|
||||
let spread_score = (spread / 10.0).min(1.0); // Reasonable spread
|
||||
|
||||
(collapse_score * 0.3
|
||||
+ degeneracy_score * 0.3
|
||||
+ component_score * 0.2
|
||||
+ clustering_score * 0.1
|
||||
+ spread_score * 0.1)
|
||||
.clamp(0.0, 1.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding quality metrics produced by [`TopologicalAnalyzer::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingQuality {
    /// Embedding dimensions
    pub dimensions: usize,
    /// Number of vectors analyzed
    pub num_vectors: usize,
    /// Number of connected components in the k-NN graph
    pub connected_components: usize,
    /// Clustering coefficient (0-1)
    pub clustering_coefficient: f32,
    /// Average node degree in the k-NN graph
    pub avg_degree: f32,
    /// Degree standard deviation
    pub degree_std: f32,
    /// Mode collapse score (0=collapsed, 1=good separation)
    pub mode_collapse_score: f32,
    /// Degeneracy score (0=full rank, 1=degenerate)
    pub degeneracy_score: f32,
    /// Average distance from the centroid
    pub spread: f32,
    /// Topological persistence score (component-count variation over scales)
    pub persistence_score: f32,
    /// Overall quality (0-1, higher is better)
    pub quality_score: f32,
}
|
||||
|
||||
impl EmbeddingQuality {
|
||||
/// Check if embeddings show signs of mode collapse
|
||||
pub fn has_mode_collapse(&self) -> bool {
|
||||
self.mode_collapse_score < 0.3
|
||||
}
|
||||
|
||||
/// Check if embeddings are degenerate
|
||||
pub fn is_degenerate(&self) -> bool {
|
||||
self.degeneracy_score > 0.7
|
||||
}
|
||||
|
||||
/// Check if embeddings are well-structured
|
||||
pub fn is_good_quality(&self) -> bool {
|
||||
self.quality_score > 0.7
|
||||
}
|
||||
|
||||
/// Get quality assessment
|
||||
pub fn assessment(&self) -> &str {
|
||||
if self.quality_score > 0.8 {
|
||||
"Excellent"
|
||||
} else if self.quality_score > 0.6 {
|
||||
"Good"
|
||||
} else if self.quality_score > 0.4 {
|
||||
"Fair"
|
||||
} else {
|
||||
"Poor"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Euclidean (L2) distance between two equal-length vectors.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();

    sum_of_squares.sqrt()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end analysis of two small, well-separated clusters must
    /// report sane dimensions/counts and a nonzero quality score.
    #[test]
    fn test_embedding_analysis() {
        let analyzer = TopologicalAnalyzer::new(3, 5.0);

        // Create well-separated embeddings
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.2],
            vec![5.0, 5.0],
            vec![5.1, 5.1],
        ];

        let quality = analyzer.analyze(&embeddings).unwrap();

        assert_eq!(quality.dimensions, 2);
        assert_eq!(quality.num_vectors, 5);
        assert!(quality.quality_score > 0.0);
    }

    #[test]
    fn test_mode_collapse_detection() {
        let analyzer = TopologicalAnalyzer::new(2, 10.0);

        // Well-separated embeddings (high CV should give high score)
        let good = vec![vec![0.0, 0.0], vec![5.0, 5.0], vec![10.0, 10.0]];
        let score_good = analyzer.detect_mode_collapse(&good);

        // Collapsed embeddings (all identical, CV = 0)
        let collapsed = vec![vec![1.0, 1.0], vec![1.0, 1.0], vec![1.0, 1.0]];
        let score_collapsed = analyzer.detect_mode_collapse(&collapsed);

        // Identical vectors should have score 0 (distances all same = CV 0)
        assert_eq!(score_collapsed, 0.0);

        // Well-separated should have higher score
        assert!(score_good > score_collapsed);
    }

    /// Two clusters farther apart than epsilon must land in separate
    /// components of the k-NN graph.
    #[test]
    fn test_connected_components() {
        let analyzer = TopologicalAnalyzer::new(1, 1.0);

        // Two separate clusters
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.5, 0.5],
            vec![10.0, 10.0],
            vec![10.5, 10.5],
        ];

        let graph = analyzer.build_knn_graph(&embeddings);
        let components = analyzer.count_connected_components(&graph, embeddings.len());

        assert!(components >= 2); // Should have at least 2 components
    }

    /// The convenience predicates must agree with their documented
    /// thresholds for a hand-built metrics struct.
    #[test]
    fn test_quality_assessment() {
        let quality = EmbeddingQuality {
            dimensions: 128,
            num_vectors: 1000,
            connected_components: 1,
            clustering_coefficient: 0.6,
            avg_degree: 5.0,
            degree_std: 1.2,
            mode_collapse_score: 0.8,
            degeneracy_score: 0.2,
            spread: 3.5,
            persistence_score: 0.4,
            quality_score: 0.75,
        };

        assert!(!quality.has_mode_collapse());
        assert!(!quality.is_degenerate());
        assert!(quality.is_good_quality());
        assert_eq!(quality.assessment(), "Good");
    }
}
|
||||
23
crates/ruvector-core/src/advanced_features.rs
Normal file
23
crates/ruvector-core/src/advanced_features.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
//! Advanced Features for Ruvector
//!
//! This module provides advanced vector database capabilities:
//! - Enhanced Product Quantization with precomputed lookup tables
//! - Filtered Search with automatic strategy selection
//! - MMR (Maximal Marginal Relevance) for diversity
//! - Hybrid Search combining vector and keyword matching
//! - Conformal Prediction for uncertainty quantification

// One submodule per feature.
pub mod conformal_prediction;
pub mod filtered_search;
pub mod hybrid_search;
pub mod mmr;
pub mod product_quantization;

// Re-exports so callers can reach the main types via `advanced_features::X`.
pub use conformal_prediction::{
    ConformalConfig, ConformalPredictor, NonconformityMeasure, PredictionSet,
};
pub use filtered_search::{FilterExpression, FilterStrategy, FilteredSearch};
pub use hybrid_search::{HybridConfig, HybridSearch, NormalizationStrategy, BM25};
pub use mmr::{MMRConfig, MMRSearch};
pub use product_quantization::{EnhancedPQ, LookupTable, PQConfig};
||||
@@ -0,0 +1,503 @@
|
||||
//! Conformal Prediction for Uncertainty Quantification
|
||||
//!
|
||||
//! Implements conformal prediction to provide statistically valid uncertainty estimates
|
||||
//! and prediction sets with guaranteed coverage (1-α).
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::types::{SearchResult, VectorId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for conformal prediction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConformalConfig {
    /// Significance level (alpha) - typically 0.05 or 0.10.
    /// The prediction set targets 1 - alpha coverage; must lie in [0, 1].
    pub alpha: f32,
    /// Size of calibration set (as fraction of total data), in [0, 1].
    /// NOTE(review): validated in `ConformalPredictor::new` but not otherwise
    /// consumed by the code visible here — presumably used by callers when
    /// splitting data; confirm.
    pub calibration_fraction: f32,
    /// Non-conformity measure used for both calibration and prediction.
    pub nonconformity_measure: NonconformityMeasure,
}
|
||||
|
||||
impl Default for ConformalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
alpha: 0.1, // 90% coverage
|
||||
calibration_fraction: 0.2,
|
||||
nonconformity_measure: NonconformityMeasure::Distance,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Type of non-conformity measure
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NonconformityMeasure {
    /// Use the result's raw distance score as non-conformity
    /// (lower = more conforming).
    Distance,
    /// Use inverse rank, 1 / (rank + 1), as non-conformity
    /// (higher = nearer the top of the result list).
    InverseRank,
    /// Use distance normalized by the average distance over all results
    /// (distance / avg_distance).
    NormalizedDistance,
}
|
||||
|
||||
/// Prediction set with conformal guarantees
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionSet {
    /// Results included in the prediction set
    pub results: Vec<SearchResult>,
    /// Conformal threshold used to build the set
    pub threshold: f32,
    /// Confidence level (1 - alpha)
    pub confidence: f32,
    /// Coverage guarantee.
    /// NOTE(review): always assigned the same value as `confidence` in this
    /// file; kept as a separate field, presumably for API clarity — confirm.
    pub coverage_guarantee: f32,
}
|
||||
|
||||
/// Conformal predictor for vector search
#[derive(Debug, Clone)]
pub struct ConformalPredictor {
    /// Configuration (alpha, calibration fraction, measure)
    pub config: ConformalConfig,
    /// Non-conformity scores collected during calibration
    pub calibration_scores: Vec<f32>,
    /// Conformal threshold: the (1 - alpha) quantile of the calibration
    /// scores; `None` until `calibrate` has been run
    pub threshold: Option<f32>,
}
|
||||
|
||||
impl ConformalPredictor {
|
||||
/// Create a new conformal predictor
|
||||
pub fn new(config: ConformalConfig) -> Result<Self> {
|
||||
if !(0.0..=1.0).contains(&config.alpha) {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"Alpha must be in [0, 1], got {}",
|
||||
config.alpha
|
||||
)));
|
||||
}
|
||||
|
||||
if !(0.0..=1.0).contains(&config.calibration_fraction) {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"Calibration fraction must be in [0, 1], got {}",
|
||||
config.calibration_fraction
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
calibration_scores: Vec::new(),
|
||||
threshold: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Calibrate on a set of validation examples
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `validation_queries` - Query vectors for calibration
|
||||
/// * `true_neighbors` - Ground truth neighbors for each query
|
||||
/// * `search_fn` - Function to perform search
|
||||
pub fn calibrate<F>(
|
||||
&mut self,
|
||||
validation_queries: &[Vec<f32>],
|
||||
true_neighbors: &[Vec<VectorId>],
|
||||
search_fn: F,
|
||||
) -> Result<()>
|
||||
where
|
||||
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
|
||||
{
|
||||
if validation_queries.len() != true_neighbors.len() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Number of queries must match number of true neighbor sets".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if validation_queries.is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Calibration set cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut all_scores = Vec::new();
|
||||
|
||||
// Compute non-conformity scores for calibration set
|
||||
for (query, true_ids) in validation_queries.iter().zip(true_neighbors) {
|
||||
// Search for neighbors
|
||||
let results = search_fn(query, 100)?; // Fetch more results
|
||||
|
||||
// Compute non-conformity scores for true neighbors
|
||||
for true_id in true_ids {
|
||||
let score = self.compute_nonconformity_score(&results, true_id)?;
|
||||
all_scores.push(score);
|
||||
}
|
||||
}
|
||||
|
||||
self.calibration_scores = all_scores;
|
||||
|
||||
// Compute threshold as (1 - alpha) quantile
|
||||
self.compute_threshold()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute conformal threshold from calibration scores
|
||||
fn compute_threshold(&mut self) -> Result<()> {
|
||||
if self.calibration_scores.is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"No calibration scores available".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut sorted_scores = self.calibration_scores.clone();
|
||||
sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
|
||||
// Compute (1 - alpha) quantile
|
||||
let n = sorted_scores.len();
|
||||
let quantile_index = ((1.0 - self.config.alpha) * (n as f32 + 1.0)).ceil() as usize;
|
||||
let quantile_index = quantile_index.min(n - 1);
|
||||
|
||||
self.threshold = Some(sorted_scores[quantile_index]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute non-conformity score for a specific result
|
||||
fn compute_nonconformity_score(
|
||||
&self,
|
||||
results: &[SearchResult],
|
||||
target_id: &VectorId,
|
||||
) -> Result<f32> {
|
||||
match self.config.nonconformity_measure {
|
||||
NonconformityMeasure::Distance => {
|
||||
// Use distance score directly
|
||||
results
|
||||
.iter()
|
||||
.find(|r| &r.id == target_id)
|
||||
.map(|r| r.score)
|
||||
.ok_or_else(|| {
|
||||
RuvectorError::VectorNotFound(format!(
|
||||
"Target {} not in results",
|
||||
target_id
|
||||
))
|
||||
})
|
||||
}
|
||||
NonconformityMeasure::InverseRank => {
|
||||
// Use inverse rank: 1 / (rank + 1)
|
||||
let rank = results
|
||||
.iter()
|
||||
.position(|r| &r.id == target_id)
|
||||
.ok_or_else(|| {
|
||||
RuvectorError::VectorNotFound(format!(
|
||||
"Target {} not in results",
|
||||
target_id
|
||||
))
|
||||
})?;
|
||||
Ok(1.0 / (rank as f32 + 1.0))
|
||||
}
|
||||
NonconformityMeasure::NormalizedDistance => {
|
||||
// Normalize by average distance
|
||||
let target_score = results
|
||||
.iter()
|
||||
.find(|r| &r.id == target_id)
|
||||
.map(|r| r.score)
|
||||
.ok_or_else(|| {
|
||||
RuvectorError::VectorNotFound(format!(
|
||||
"Target {} not in results",
|
||||
target_id
|
||||
))
|
||||
})?;
|
||||
|
||||
// Guard against empty results
|
||||
if results.is_empty() {
|
||||
return Ok(target_score);
|
||||
}
|
||||
|
||||
let avg_score = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
|
||||
|
||||
Ok(if avg_score > 0.0 {
|
||||
target_score / avg_score
|
||||
} else {
|
||||
target_score
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make prediction with conformal guarantee
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `search_fn` - Function to perform search
|
||||
///
|
||||
/// # Returns
|
||||
/// Prediction set with coverage guarantee
|
||||
pub fn predict<F>(&self, query: &[f32], search_fn: F) -> Result<PredictionSet>
|
||||
where
|
||||
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
|
||||
{
|
||||
let threshold = self.threshold.ok_or_else(|| {
|
||||
RuvectorError::InvalidParameter("Predictor not calibrated yet".to_string())
|
||||
})?;
|
||||
|
||||
// Perform search with large k
|
||||
let results = search_fn(query, 1000)?;
|
||||
|
||||
// Select results based on non-conformity threshold
|
||||
let prediction_set: Vec<SearchResult> = match self.config.nonconformity_measure {
|
||||
NonconformityMeasure::Distance => {
|
||||
// Include all results with distance <= threshold
|
||||
results
|
||||
.into_iter()
|
||||
.filter(|r| r.score <= threshold)
|
||||
.collect()
|
||||
}
|
||||
NonconformityMeasure::InverseRank => {
|
||||
// Include top-k results where k is determined by threshold
|
||||
let k = (1.0 / threshold).ceil() as usize;
|
||||
results.into_iter().take(k).collect()
|
||||
}
|
||||
NonconformityMeasure::NormalizedDistance => {
|
||||
// Guard against empty results
|
||||
if results.is_empty() {
|
||||
return Ok(PredictionSet {
|
||||
results: vec![],
|
||||
threshold,
|
||||
confidence: 1.0 - self.config.alpha,
|
||||
coverage_guarantee: 1.0 - self.config.alpha,
|
||||
});
|
||||
}
|
||||
|
||||
let avg_score = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
|
||||
let adjusted_threshold = threshold * avg_score;
|
||||
results
|
||||
.into_iter()
|
||||
.filter(|r| r.score <= adjusted_threshold)
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(PredictionSet {
|
||||
results: prediction_set,
|
||||
threshold,
|
||||
confidence: 1.0 - self.config.alpha,
|
||||
coverage_guarantee: 1.0 - self.config.alpha,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compute adaptive top-k based on uncertainty
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `search_fn` - Function to perform search
|
||||
///
|
||||
/// # Returns
|
||||
/// Number of results to return based on uncertainty
|
||||
pub fn adaptive_top_k<F>(&self, query: &[f32], search_fn: F) -> Result<usize>
|
||||
where
|
||||
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
|
||||
{
|
||||
let prediction_set = self.predict(query, search_fn)?;
|
||||
Ok(prediction_set.results.len())
|
||||
}
|
||||
|
||||
/// Get calibration statistics
|
||||
pub fn get_statistics(&self) -> Option<CalibrationStats> {
|
||||
if self.calibration_scores.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let n = self.calibration_scores.len() as f32;
|
||||
let mean = self.calibration_scores.iter().sum::<f32>() / n;
|
||||
let variance = self
|
||||
.calibration_scores
|
||||
.iter()
|
||||
.map(|&s| (s - mean).powi(2))
|
||||
.sum::<f32>()
|
||||
/ n;
|
||||
let std = variance.sqrt();
|
||||
|
||||
let mut sorted = self.calibration_scores.clone();
|
||||
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
|
||||
Some(CalibrationStats {
|
||||
num_samples: self.calibration_scores.len(),
|
||||
mean,
|
||||
std,
|
||||
min: sorted.first().copied().unwrap(),
|
||||
max: sorted.last().copied().unwrap(),
|
||||
median: sorted[sorted.len() / 2],
|
||||
threshold: self.threshold,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Calibration statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationStats {
    /// Number of calibration samples
    pub num_samples: usize,
    /// Mean non-conformity score
    pub mean: f32,
    /// Standard deviation (population form, i.e. divided by n)
    pub std: f32,
    /// Minimum score
    pub min: f32,
    /// Maximum score
    pub max: f32,
    /// Median score (upper median for even sample counts)
    pub median: f32,
    /// Conformal threshold, if the predictor has been calibrated
    pub threshold: Option<f32>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `SearchResult` with the given id and score
    /// (zero vector, no metadata).
    fn create_search_result(id: &str, score: f32) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            score,
            vector: Some(vec![0.0; 10]),
            metadata: None,
        }
    }

    /// Deterministic stand-in for an index search: returns
    /// `doc_0 .. doc_{k-1}` with increasing scores 0.0, 0.1, 0.2, ...
    fn mock_search_fn(_query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        let hits = (0..k)
            .map(|rank| create_search_result(&format!("doc_{}", rank), rank as f32 * 0.1))
            .collect();
        Ok(hits)
    }

    #[test]
    fn test_conformal_config_validation() {
        let valid = ConformalConfig {
            alpha: 0.1,
            ..Default::default()
        };
        assert!(ConformalPredictor::new(valid).is_ok());

        let alpha_out_of_range = ConformalConfig {
            alpha: 1.5,
            ..Default::default()
        };
        assert!(ConformalPredictor::new(alpha_out_of_range).is_err());
    }

    #[test]
    fn test_conformal_calibration() {
        let mut predictor = ConformalPredictor::new(ConformalConfig::default()).unwrap();

        // Three queries with known relevant documents.
        let cal_queries = vec![vec![1.0; 10], vec![2.0; 10], vec![3.0; 10]];
        let cal_truth = vec![
            vec!["doc_0".to_string(), "doc_1".to_string()],
            vec!["doc_0".to_string()],
            vec!["doc_1".to_string(), "doc_2".to_string()],
        ];

        predictor
            .calibrate(&cal_queries, &cal_truth, mock_search_fn)
            .unwrap();

        assert!(!predictor.calibration_scores.is_empty());
        assert!(predictor.threshold.is_some());
    }

    #[test]
    fn test_conformal_prediction() {
        let config = ConformalConfig {
            alpha: 0.1,
            calibration_fraction: 0.2,
            nonconformity_measure: NonconformityMeasure::Distance,
        };
        let mut predictor = ConformalPredictor::new(config).unwrap();

        // Calibrate on two queries.
        let cal_queries = vec![vec![1.0; 10], vec![2.0; 10]];
        let cal_truth = vec![vec!["doc_0".to_string()], vec!["doc_1".to_string()]];
        predictor
            .calibrate(&cal_queries, &cal_truth, mock_search_fn)
            .unwrap();

        // Predict for a fresh query.
        let query = vec![1.5; 10];
        let prediction = predictor.predict(&query, mock_search_fn).unwrap();

        assert!(!prediction.results.is_empty());
        assert_eq!(prediction.confidence, 0.9);
        assert!(prediction.threshold > 0.0);
    }

    #[test]
    fn test_nonconformity_distance() {
        let config = ConformalConfig {
            nonconformity_measure: NonconformityMeasure::Distance,
            ..Default::default()
        };
        let predictor = ConformalPredictor::new(config).unwrap();

        let hits = vec![
            create_search_result("doc_0", 0.1),
            create_search_result("doc_1", 0.3),
            create_search_result("doc_2", 0.5),
        ];

        // Distance measure returns the target's raw score.
        let score = predictor
            .compute_nonconformity_score(&hits, &"doc_1".to_string())
            .unwrap();
        assert!((score - 0.3).abs() < 0.01);
    }

    #[test]
    fn test_nonconformity_inverse_rank() {
        let config = ConformalConfig {
            nonconformity_measure: NonconformityMeasure::InverseRank,
            ..Default::default()
        };
        let predictor = ConformalPredictor::new(config).unwrap();

        let hits = vec![
            create_search_result("doc_0", 0.1),
            create_search_result("doc_1", 0.3),
            create_search_result("doc_2", 0.5),
        ];

        // doc_1 sits at rank 1, so the score is 1 / (1 + 1) = 0.5.
        let score = predictor
            .compute_nonconformity_score(&hits, &"doc_1".to_string())
            .unwrap();
        assert!((score - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_calibration_stats() {
        let mut predictor = ConformalPredictor::new(ConformalConfig::default()).unwrap();

        // Inject known scores directly instead of running calibration.
        predictor.calibration_scores = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        predictor.threshold = Some(0.4);

        let stats = predictor.get_statistics().unwrap();
        assert_eq!(stats.num_samples, 5);
        assert!((stats.mean - 0.3).abs() < 0.01);
        assert!((stats.min - 0.1).abs() < 0.01);
        assert!((stats.max - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_adaptive_top_k() {
        let mut predictor = ConformalPredictor::new(ConformalConfig::default()).unwrap();

        let cal_queries = vec![vec![1.0; 10], vec![2.0; 10]];
        let cal_truth = vec![vec!["doc_0".to_string()], vec!["doc_1".to_string()]];
        predictor
            .calibrate(&cal_queries, &cal_truth, mock_search_fn)
            .unwrap();

        // The adaptive k is the prediction-set size; must be positive here.
        let query = vec![1.5; 10];
        let k = predictor.adaptive_top_k(&query, mock_search_fn).unwrap();
        assert!(k > 0);
    }
}
|
||||
363
crates/ruvector-core/src/advanced_features/filtered_search.rs
Normal file
363
crates/ruvector-core/src/advanced_features/filtered_search.rs
Normal file
@@ -0,0 +1,363 @@
|
||||
//! Filtered Search with Automatic Strategy Selection
|
||||
//!
|
||||
//! Supports two filtering strategies:
|
||||
//! - Pre-filtering: Apply metadata filters before graph traversal
|
||||
//! - Post-filtering: Traverse graph then apply filters
|
||||
//! - Automatic strategy selection based on filter selectivity
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::types::{SearchResult, VectorId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Filter strategy selection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FilterStrategy {
    /// Apply filters before search (efficient for highly selective filters)
    PreFilter,
    /// Apply filters after search (efficient for low selectivity)
    PostFilter,
    /// Automatically select strategy based on estimated selectivity
    /// (see `FilteredSearch::auto_select_strategy`)
    Auto,
}
|
||||
|
||||
/// Filter expression for metadata
///
/// Expressions form a small boolean algebra over JSON metadata fields;
/// see `evaluate` for how absent fields are treated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterExpression {
    /// Equality check: field == value
    Eq(String, serde_json::Value),
    /// Not equal: field != value (also true when the field is absent)
    Ne(String, serde_json::Value),
    /// Greater than: field > value
    Gt(String, serde_json::Value),
    /// Greater than or equal: field >= value
    Gte(String, serde_json::Value),
    /// Less than: field < value
    Lt(String, serde_json::Value),
    /// Less than or equal: field <= value
    Lte(String, serde_json::Value),
    /// In list: field in [values]
    In(String, Vec<serde_json::Value>),
    /// Not in list: field not in [values] (also true when the field is absent)
    NotIn(String, Vec<serde_json::Value>),
    /// Range check, inclusive on both ends: min <= field <= max
    Range(String, serde_json::Value, serde_json::Value),
    /// Logical AND of all sub-expressions
    And(Vec<FilterExpression>),
    /// Logical OR of any sub-expression
    Or(Vec<FilterExpression>),
    /// Logical NOT
    Not(Box<FilterExpression>),
}
|
||||
|
||||
impl FilterExpression {
|
||||
/// Evaluate filter against metadata
|
||||
pub fn evaluate(&self, metadata: &HashMap<String, serde_json::Value>) -> bool {
|
||||
match self {
|
||||
FilterExpression::Eq(field, value) => metadata.get(field) == Some(value),
|
||||
FilterExpression::Ne(field, value) => metadata.get(field) != Some(value),
|
||||
FilterExpression::Gt(field, value) => {
|
||||
if let Some(field_value) = metadata.get(field) {
|
||||
compare_values(field_value, value) > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
FilterExpression::Gte(field, value) => {
|
||||
if let Some(field_value) = metadata.get(field) {
|
||||
compare_values(field_value, value) >= 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
FilterExpression::Lt(field, value) => {
|
||||
if let Some(field_value) = metadata.get(field) {
|
||||
compare_values(field_value, value) < 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
FilterExpression::Lte(field, value) => {
|
||||
if let Some(field_value) = metadata.get(field) {
|
||||
compare_values(field_value, value) <= 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
FilterExpression::In(field, values) => {
|
||||
if let Some(field_value) = metadata.get(field) {
|
||||
values.contains(field_value)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
FilterExpression::NotIn(field, values) => {
|
||||
if let Some(field_value) = metadata.get(field) {
|
||||
!values.contains(field_value)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
FilterExpression::Range(field, min, max) => {
|
||||
if let Some(field_value) = metadata.get(field) {
|
||||
compare_values(field_value, min) >= 0 && compare_values(field_value, max) <= 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
FilterExpression::And(exprs) => exprs.iter().all(|e| e.evaluate(metadata)),
|
||||
FilterExpression::Or(exprs) => exprs.iter().any(|e| e.evaluate(metadata)),
|
||||
FilterExpression::Not(expr) => !expr.evaluate(metadata),
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate selectivity of filter (0.0 = very selective, 1.0 = not selective)
|
||||
#[allow(clippy::only_used_in_recursion)]
|
||||
pub fn estimate_selectivity(&self, total_vectors: usize) -> f32 {
|
||||
match self {
|
||||
FilterExpression::Eq(_, _) => 0.1, // Equality is typically selective
|
||||
FilterExpression::Ne(_, _) => 0.9, // Not equal is less selective
|
||||
FilterExpression::In(_, values) => (values.len() as f32) / 100.0,
|
||||
FilterExpression::NotIn(_, values) => 1.0 - (values.len() as f32) / 100.0,
|
||||
FilterExpression::Range(_, _, _) => 0.3, // Ranges are moderately selective
|
||||
FilterExpression::Gt(_, _) | FilterExpression::Gte(_, _) => 0.5,
|
||||
FilterExpression::Lt(_, _) | FilterExpression::Lte(_, _) => 0.5,
|
||||
FilterExpression::And(exprs) => {
|
||||
// AND is more selective (multiply selectivities)
|
||||
exprs
|
||||
.iter()
|
||||
.map(|e| e.estimate_selectivity(total_vectors))
|
||||
.product()
|
||||
}
|
||||
FilterExpression::Or(exprs) => {
|
||||
// OR is less selective (sum selectivities, capped at 1.0)
|
||||
exprs
|
||||
.iter()
|
||||
.map(|e| e.estimate_selectivity(total_vectors))
|
||||
.sum::<f32>()
|
||||
.min(1.0)
|
||||
}
|
||||
FilterExpression::Not(expr) => 1.0 - expr.estimate_selectivity(total_vectors),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filtered search implementation
#[derive(Debug, Clone)]
pub struct FilteredSearch {
    /// Filter expression applied to candidate vectors
    pub filter: FilterExpression,
    /// Strategy for applying the filter (`Auto` is resolved per search)
    pub strategy: FilterStrategy,
    /// Metadata store: vector id -> metadata map
    pub metadata_store: HashMap<VectorId, HashMap<String, serde_json::Value>>,
}
|
||||
|
||||
impl FilteredSearch {
|
||||
/// Create a new filtered search instance
|
||||
pub fn new(
|
||||
filter: FilterExpression,
|
||||
strategy: FilterStrategy,
|
||||
metadata_store: HashMap<VectorId, HashMap<String, serde_json::Value>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
filter,
|
||||
strategy,
|
||||
metadata_store,
|
||||
}
|
||||
}
|
||||
|
||||
/// Automatically select strategy based on filter selectivity
|
||||
pub fn auto_select_strategy(&self) -> FilterStrategy {
|
||||
let selectivity = self.filter.estimate_selectivity(self.metadata_store.len());
|
||||
|
||||
// If filter is highly selective (< 20%), use pre-filtering
|
||||
// Otherwise use post-filtering
|
||||
if selectivity < 0.2 {
|
||||
FilterStrategy::PreFilter
|
||||
} else {
|
||||
FilterStrategy::PostFilter
|
||||
}
|
||||
}
|
||||
|
||||
/// Get list of vector IDs that pass the filter (for pre-filtering)
|
||||
pub fn get_filtered_ids(&self) -> Vec<VectorId> {
|
||||
self.metadata_store
|
||||
.iter()
|
||||
.filter(|(_, metadata)| self.filter.evaluate(metadata))
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Apply filter to search results (for post-filtering)
|
||||
pub fn filter_results(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
|
||||
results
|
||||
.into_iter()
|
||||
.filter(|result| {
|
||||
if let Some(metadata) = result.metadata.as_ref() {
|
||||
self.filter.evaluate(metadata)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Apply filtered search with automatic strategy selection
|
||||
pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
|
||||
where
|
||||
F: Fn(&[f32], usize, Option<&[VectorId]>) -> Result<Vec<SearchResult>>,
|
||||
{
|
||||
let strategy = match self.strategy {
|
||||
FilterStrategy::Auto => self.auto_select_strategy(),
|
||||
other => other,
|
||||
};
|
||||
|
||||
match strategy {
|
||||
FilterStrategy::PreFilter => {
|
||||
// Get filtered IDs first
|
||||
let filtered_ids = self.get_filtered_ids();
|
||||
|
||||
if filtered_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Search only within filtered IDs
|
||||
// We may need to fetch more results to get k after filtering
|
||||
let fetch_k = (k as f32 * 1.5).ceil() as usize;
|
||||
search_fn(query, fetch_k, Some(&filtered_ids))
|
||||
}
|
||||
FilterStrategy::PostFilter => {
|
||||
// Search first, then filter
|
||||
// Fetch more results to ensure we get k after filtering
|
||||
let fetch_k = (k as f32 * 2.0).ceil() as usize;
|
||||
let results = search_fn(query, fetch_k, None)?;
|
||||
|
||||
// Apply filter
|
||||
let filtered = self.filter_results(results);
|
||||
|
||||
// Return top-k
|
||||
Ok(filtered.into_iter().take(k).collect())
|
||||
}
|
||||
FilterStrategy::Auto => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare JSON values
|
||||
fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> i32 {
|
||||
use serde_json::Value;
|
||||
|
||||
match (a, b) {
|
||||
(Value::Number(a), Value::Number(b)) => {
|
||||
let a_f64 = a.as_f64().unwrap_or(0.0);
|
||||
let b_f64 = b.as_f64().unwrap_or(0.0);
|
||||
if a_f64 < b_f64 {
|
||||
-1
|
||||
} else if a_f64 > b_f64 {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
(Value::String(a), Value::String(b)) => a.cmp(b) as i32,
|
||||
(Value::Bool(a), Value::Bool(b)) => a.cmp(b) as i32,
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_filter_eq() {
        let mut meta = HashMap::new();
        meta.insert("category".to_string(), json!("electronics"));

        let matching = FilterExpression::Eq("category".to_string(), json!("electronics"));
        assert!(matching.evaluate(&meta));

        let non_matching = FilterExpression::Eq("category".to_string(), json!("books"));
        assert!(!non_matching.evaluate(&meta));
    }

    #[test]
    fn test_filter_range() {
        let mut meta = HashMap::new();
        meta.insert("price".to_string(), json!(50.0));

        let inside = FilterExpression::Range("price".to_string(), json!(10.0), json!(100.0));
        assert!(inside.evaluate(&meta));

        let outside = FilterExpression::Range("price".to_string(), json!(60.0), json!(100.0));
        assert!(!outside.evaluate(&meta));
    }

    #[test]
    fn test_filter_and() {
        let mut meta = HashMap::new();
        meta.insert("category".to_string(), json!("electronics"));
        meta.insert("price".to_string(), json!(50.0));

        // Both conjuncts hold, so the AND holds.
        let both = FilterExpression::And(vec![
            FilterExpression::Eq("category".to_string(), json!("electronics")),
            FilterExpression::Lt("price".to_string(), json!(100.0)),
        ]);
        assert!(both.evaluate(&meta));
    }

    #[test]
    fn test_filter_or() {
        let mut meta = HashMap::new();
        meta.insert("category".to_string(), json!("electronics"));

        // Only the second branch matches, which is enough for OR.
        let either = FilterExpression::Or(vec![
            FilterExpression::Eq("category".to_string(), json!("books")),
            FilterExpression::Eq("category".to_string(), json!("electronics")),
        ]);
        assert!(either.evaluate(&meta));
    }

    #[test]
    fn test_filter_in() {
        let mut meta = HashMap::new();
        meta.insert("tag".to_string(), json!("popular"));

        let membership = FilterExpression::In(
            "tag".to_string(),
            vec![json!("popular"), json!("trending"), json!("new")],
        );
        assert!(membership.evaluate(&meta));
    }

    #[test]
    fn test_selectivity_estimation() {
        // Equality should be estimated as selective, inequality as broad.
        let eq = FilterExpression::Eq("field".to_string(), json!("value"));
        assert!(eq.estimate_selectivity(1000) < 0.5);

        let ne = FilterExpression::Ne("field".to_string(), json!("value"));
        assert!(ne.estimate_selectivity(1000) > 0.5);
    }

    #[test]
    fn test_auto_strategy_selection() {
        // 100 vectors, each carrying a unique integer id in its metadata.
        let mut store = HashMap::new();
        for i in 0..100 {
            let mut meta = HashMap::new();
            meta.insert("id".to_string(), json!(i));
            store.insert(format!("vec_{}", i), meta);
        }

        // Highly selective equality filter -> pre-filter.
        let narrow = FilterExpression::Eq("id".to_string(), json!(42));
        let search = FilteredSearch::new(narrow, FilterStrategy::Auto, store.clone());
        assert_eq!(search.auto_select_strategy(), FilterStrategy::PreFilter);

        // Filter that matches everything -> post-filter.
        let broad = FilterExpression::Gte("id".to_string(), json!(0));
        let search = FilteredSearch::new(broad, FilterStrategy::Auto, store);
        assert_eq!(search.auto_select_strategy(), FilterStrategy::PostFilter);
    }
}
|
||||
444
crates/ruvector-core/src/advanced_features/hybrid_search.rs
Normal file
444
crates/ruvector-core/src/advanced_features/hybrid_search.rs
Normal file
@@ -0,0 +1,444 @@
|
||||
//! Hybrid Search: Combining Vector Similarity and Keyword Matching
|
||||
//!
|
||||
//! Implements hybrid search by combining:
|
||||
//! - Vector similarity search (semantic)
|
||||
//! - BM25 keyword matching (lexical)
|
||||
//! - Weighted combination of scores
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::types::{SearchResult, VectorId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Configuration for hybrid search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfig {
    /// Weight for vector similarity (alpha).
    /// NOTE(review): nothing visible here enforces that the two weights
    /// sum to 1.0 — confirm whether callers rely on that.
    pub vector_weight: f32,
    /// Weight for keyword matching (beta)
    pub keyword_weight: f32,
    /// Strategy used to normalize the two score distributions before mixing
    pub normalization: NormalizationStrategy,
}
|
||||
|
||||
impl Default for HybridConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vector_weight: 0.7,
|
||||
keyword_weight: 0.3,
|
||||
normalization: NormalizationStrategy::MinMax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Score normalization strategy
///
/// Determines how a list of raw scores is rescaled before vector and keyword
/// scores are combined (see `normalize_scores`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormalizationStrategy {
    /// Min-max normalization: (x - min) / (max - min)
    MinMax,
    /// Z-score normalization: (x - mean) / std
    ZScore,
    /// No normalization
    None,
}
|
||||
|
||||
/// Simple BM25 implementation for keyword matching
///
/// Holds the inverted index and corpus statistics needed to score queries
/// against indexed documents. Call `index_document` for each document, then
/// `build_idf` once before scoring — otherwise every term's IDF defaults to
/// 0 and all scores are 0.
#[derive(Debug, Clone)]
pub struct BM25 {
    /// IDF scores for terms
    pub idf: HashMap<String, f32>,
    /// Average document length
    // Maintained incrementally by `index_document`; 0.0 until the first
    // document is indexed.
    pub avg_doc_len: f32,
    /// Document lengths
    pub doc_lengths: HashMap<VectorId, usize>,
    /// Inverted index: term -> set of doc IDs
    pub inverted_index: HashMap<String, HashSet<VectorId>>,
    /// BM25 parameters
    // `k1` controls term-frequency saturation, `b` controls document-length
    // normalization (standard values: k1 = 1.5, b = 0.75).
    pub k1: f32,
    pub b: f32,
}
|
||||
|
||||
impl BM25 {
|
||||
/// Create a new BM25 instance
|
||||
pub fn new(k1: f32, b: f32) -> Self {
|
||||
Self {
|
||||
idf: HashMap::new(),
|
||||
avg_doc_len: 0.0,
|
||||
doc_lengths: HashMap::new(),
|
||||
inverted_index: HashMap::new(),
|
||||
k1,
|
||||
b,
|
||||
}
|
||||
}
|
||||
|
||||
/// Index a document
|
||||
pub fn index_document(&mut self, doc_id: VectorId, text: &str) {
|
||||
let terms = tokenize(text);
|
||||
self.doc_lengths.insert(doc_id.clone(), terms.len());
|
||||
|
||||
// Update inverted index
|
||||
for term in terms {
|
||||
self.inverted_index
|
||||
.entry(term)
|
||||
.or_default()
|
||||
.insert(doc_id.clone());
|
||||
}
|
||||
|
||||
// Update average document length
|
||||
let total_len: usize = self.doc_lengths.values().sum();
|
||||
self.avg_doc_len = total_len as f32 / self.doc_lengths.len() as f32;
|
||||
}
|
||||
|
||||
/// Build IDF scores after indexing all documents
|
||||
pub fn build_idf(&mut self) {
|
||||
let num_docs = self.doc_lengths.len() as f32;
|
||||
|
||||
for (term, doc_set) in &self.inverted_index {
|
||||
let doc_freq = doc_set.len() as f32;
|
||||
let idf = ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
|
||||
self.idf.insert(term.clone(), idf);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute BM25 score for a query against a document
|
||||
pub fn score(&self, query: &str, doc_id: &VectorId, doc_text: &str) -> f32 {
|
||||
let query_terms = tokenize(query);
|
||||
let doc_terms = tokenize(doc_text);
|
||||
let doc_len = self.doc_lengths.get(doc_id).copied().unwrap_or(0) as f32;
|
||||
|
||||
// Count term frequencies in document
|
||||
let mut term_freq: HashMap<String, f32> = HashMap::new();
|
||||
for term in doc_terms {
|
||||
*term_freq.entry(term).or_insert(0.0) += 1.0;
|
||||
}
|
||||
|
||||
// Calculate BM25 score
|
||||
let mut score = 0.0;
|
||||
for term in query_terms {
|
||||
let idf = self.idf.get(&term).copied().unwrap_or(0.0);
|
||||
let tf = term_freq.get(&term).copied().unwrap_or(0.0);
|
||||
|
||||
let numerator = tf * (self.k1 + 1.0);
|
||||
let denominator = tf + self.k1 * (1.0 - self.b + self.b * (doc_len / self.avg_doc_len));
|
||||
|
||||
score += idf * (numerator / denominator);
|
||||
}
|
||||
|
||||
score
|
||||
}
|
||||
|
||||
/// Get all documents containing at least one query term
|
||||
pub fn get_candidate_docs(&self, query: &str) -> HashSet<VectorId> {
|
||||
let query_terms = tokenize(query);
|
||||
let mut candidates = HashSet::new();
|
||||
|
||||
for term in query_terms {
|
||||
if let Some(doc_set) = self.inverted_index.get(&term) {
|
||||
candidates.extend(doc_set.iter().cloned());
|
||||
}
|
||||
}
|
||||
|
||||
candidates
|
||||
}
|
||||
}
|
||||
|
||||
/// Hybrid search combining vector and keyword matching
///
/// Wraps a BM25 index plus the raw document texts; the caller supplies the
/// vector-similarity search as a closure at query time, so this type has no
/// dependency on a particular vector index.
#[derive(Debug, Clone)]
pub struct HybridSearch {
    /// Configuration
    pub config: HybridConfig,
    /// BM25 index for keyword matching
    pub bm25: BM25,
    /// Document texts for BM25 scoring
    // Kept in full because BM25::score re-tokenizes the document text.
    pub doc_texts: HashMap<VectorId, String>,
}
|
||||
|
||||
impl HybridSearch {
|
||||
/// Create a new hybrid search instance
|
||||
pub fn new(config: HybridConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
bm25: BM25::new(1.5, 0.75), // Standard BM25 parameters
|
||||
doc_texts: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Index a document with both vector and text
|
||||
pub fn index_document(&mut self, doc_id: VectorId, text: String) {
|
||||
self.bm25.index_document(doc_id.clone(), &text);
|
||||
self.doc_texts.insert(doc_id, text);
|
||||
}
|
||||
|
||||
/// Finalize indexing (build IDF scores)
|
||||
pub fn finalize_indexing(&mut self) {
|
||||
self.bm25.build_idf();
|
||||
}
|
||||
|
||||
/// Perform hybrid search
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query_vector` - Query vector for semantic search
|
||||
/// * `query_text` - Query text for keyword matching
|
||||
/// * `k` - Number of results to return
|
||||
/// * `vector_search_fn` - Function to perform vector similarity search
|
||||
///
|
||||
/// # Returns
|
||||
/// Combined and reranked search results
|
||||
pub fn search<F>(
|
||||
&self,
|
||||
query_vector: &[f32],
|
||||
query_text: &str,
|
||||
k: usize,
|
||||
vector_search_fn: F,
|
||||
) -> Result<Vec<SearchResult>>
|
||||
where
|
||||
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
|
||||
{
|
||||
// Get vector similarity results
|
||||
let vector_results = vector_search_fn(query_vector, k * 2)?;
|
||||
|
||||
// Get keyword matching candidates
|
||||
let keyword_candidates = self.bm25.get_candidate_docs(query_text);
|
||||
|
||||
// Compute BM25 scores for all candidates
|
||||
let mut bm25_scores: HashMap<VectorId, f32> = HashMap::new();
|
||||
for doc_id in &keyword_candidates {
|
||||
if let Some(doc_text) = self.doc_texts.get(doc_id) {
|
||||
let score = self.bm25.score(query_text, doc_id, doc_text);
|
||||
bm25_scores.insert(doc_id.clone(), score);
|
||||
}
|
||||
}
|
||||
|
||||
// Combine results
|
||||
let mut combined_results: HashMap<VectorId, CombinedScore> = HashMap::new();
|
||||
|
||||
// Add vector results
|
||||
for result in vector_results {
|
||||
combined_results.insert(
|
||||
result.id.clone(),
|
||||
CombinedScore {
|
||||
id: result.id.clone(),
|
||||
vector_score: Some(result.score),
|
||||
keyword_score: bm25_scores.get(&result.id).copied(),
|
||||
vector: result.vector,
|
||||
metadata: result.metadata,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Add keyword-only results
|
||||
for (doc_id, bm25_score) in bm25_scores {
|
||||
combined_results
|
||||
.entry(doc_id.clone())
|
||||
.or_insert(CombinedScore {
|
||||
id: doc_id,
|
||||
vector_score: None,
|
||||
keyword_score: Some(bm25_score),
|
||||
vector: None,
|
||||
metadata: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Normalize and combine scores
|
||||
let normalized_results =
|
||||
self.normalize_and_combine(combined_results.into_values().collect())?;
|
||||
|
||||
// Sort by combined score (descending)
|
||||
let mut sorted_results = normalized_results;
|
||||
sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
|
||||
|
||||
// Return top-k
|
||||
Ok(sorted_results.into_iter().take(k).collect())
|
||||
}
|
||||
|
||||
/// Normalize and combine scores
|
||||
fn normalize_and_combine(&self, results: Vec<CombinedScore>) -> Result<Vec<SearchResult>> {
|
||||
let mut vector_scores: Vec<f32> = results.iter().filter_map(|r| r.vector_score).collect();
|
||||
let mut keyword_scores: Vec<f32> = results.iter().filter_map(|r| r.keyword_score).collect();
|
||||
|
||||
// Normalize scores
|
||||
normalize_scores(&mut vector_scores, self.config.normalization);
|
||||
normalize_scores(&mut keyword_scores, self.config.normalization);
|
||||
|
||||
// Create lookup maps
|
||||
let mut vector_map: HashMap<VectorId, f32> = HashMap::new();
|
||||
let mut keyword_map: HashMap<VectorId, f32> = HashMap::new();
|
||||
|
||||
for (result, &norm_score) in results.iter().zip(&vector_scores) {
|
||||
if result.vector_score.is_some() {
|
||||
vector_map.insert(result.id.clone(), norm_score);
|
||||
}
|
||||
}
|
||||
|
||||
for (result, &norm_score) in results.iter().zip(&keyword_scores) {
|
||||
if result.keyword_score.is_some() {
|
||||
keyword_map.insert(result.id.clone(), norm_score);
|
||||
}
|
||||
}
|
||||
|
||||
// Combine scores
|
||||
let combined: Vec<SearchResult> = results
|
||||
.into_iter()
|
||||
.map(|result| {
|
||||
let vector_norm = vector_map.get(&result.id).copied().unwrap_or(0.0);
|
||||
let keyword_norm = keyword_map.get(&result.id).copied().unwrap_or(0.0);
|
||||
|
||||
let combined_score = self.config.vector_weight * vector_norm
|
||||
+ self.config.keyword_weight * keyword_norm;
|
||||
|
||||
SearchResult {
|
||||
id: result.id,
|
||||
score: combined_score,
|
||||
vector: result.vector,
|
||||
metadata: result.metadata,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(combined)
|
||||
}
|
||||
}
|
||||
|
||||
/// Combined score holder
///
/// Intermediate fusion record: a document may arrive from the vector search,
/// the keyword search, or both, so each score component is optional.
#[derive(Debug, Clone)]
struct CombinedScore {
    id: VectorId,
    // Raw vector-search score; None if the doc was only a keyword hit.
    vector_score: Option<f32>,
    // Raw BM25 score; None if the doc was only a vector hit.
    keyword_score: Option<f32>,
    vector: Option<Vec<f32>>,
    metadata: Option<HashMap<String, serde_json::Value>>,
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
/// Lowercase, whitespace-split, and strip surrounding punctuation, keeping
/// only tokens longer than two characters.
///
/// FIX: the length filter now runs on the *trimmed* token. Previously it ran
/// before punctuation stripping, so inputs like `"a!!"` (raw length 3) slipped
/// through as the one-character token `"a"`, defeating the short-token filter.
fn tokenize(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split_whitespace()
        .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()).to_string())
        .filter(|s| s.len() > 2) // also drops empty tokens
        .collect()
}
|
||||
|
||||
fn normalize_scores(scores: &mut [f32], strategy: NormalizationStrategy) {
|
||||
if scores.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
match strategy {
|
||||
NormalizationStrategy::MinMax => {
|
||||
let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
|
||||
let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
let range = max - min;
|
||||
|
||||
if range > 0.0 {
|
||||
for score in scores.iter_mut() {
|
||||
*score = (*score - min) / range;
|
||||
}
|
||||
}
|
||||
}
|
||||
NormalizationStrategy::ZScore => {
|
||||
let mean = scores.iter().sum::<f32>() / scores.len() as f32;
|
||||
let variance =
|
||||
scores.iter().map(|&s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
|
||||
let std = variance.sqrt();
|
||||
|
||||
if std > 0.0 {
|
||||
for score in scores.iter_mut() {
|
||||
*score = (*score - mean) / std;
|
||||
}
|
||||
}
|
||||
}
|
||||
NormalizationStrategy::None => {}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Tokenization: alphanumeric tokens longer than two characters, with
    // surrounding punctuation stripped.
    #[test]
    fn test_tokenize() {
        let text = "The quick brown fox jumps over the lazy dog!";
        let tokens = tokenize(text);
        assert!(tokens.contains(&"quick".to_string()));
        assert!(tokens.contains(&"brown".to_string()));
        assert!(tokens.contains(&"the".to_string())); // "the" is 3 chars, passes > 2 filter
        assert!(!tokens.contains(&"a".to_string())); // 1 char, too short
    }

    // Indexing populates document lengths; build_idf() then yields an IDF
    // entry for every indexed term.
    #[test]
    fn test_bm25_indexing() {
        let mut bm25 = BM25::new(1.5, 0.75);

        bm25.index_document("doc1".to_string(), "rust programming language");
        bm25.index_document("doc2".to_string(), "python programming tutorial");
        bm25.build_idf();

        assert_eq!(bm25.doc_lengths.len(), 2);
        assert!(bm25.idf.contains_key("rust"));
        assert!(bm25.idf.contains_key("programming"));
    }

    // A document matching more query terms should receive a higher score.
    #[test]
    fn test_bm25_scoring() {
        let mut bm25 = BM25::new(1.5, 0.75);

        bm25.index_document("doc1".to_string(), "rust programming language");
        bm25.index_document("doc2".to_string(), "python programming tutorial");
        bm25.index_document("doc3".to_string(), "rust systems programming");
        bm25.build_idf();

        let score1 = bm25.score(
            "rust programming",
            &"doc1".to_string(),
            "rust programming language",
        );
        let score2 = bm25.score(
            "rust programming",
            &"doc2".to_string(),
            "python programming tutorial",
        );

        // doc1 should score higher (contains both terms)
        assert!(score1 > score2);
    }

    // index_document keeps the raw text and the BM25 statistics in sync.
    #[test]
    fn test_hybrid_search_initialization() {
        let config = HybridConfig::default();
        let mut hybrid = HybridSearch::new(config);

        hybrid.index_document("doc1".to_string(), "rust vector database".to_string());
        hybrid.index_document("doc2".to_string(), "python machine learning".to_string());
        hybrid.finalize_indexing();

        assert_eq!(hybrid.doc_texts.len(), 2);
        assert_eq!(hybrid.bm25.doc_lengths.len(), 2);
    }

    // Min-max normalization maps min -> 0, max -> 1, midpoint -> 0.5.
    #[test]
    fn test_normalize_minmax() {
        let mut scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        normalize_scores(&mut scores, NormalizationStrategy::MinMax);

        assert!((scores[0] - 0.0).abs() < 0.01);
        assert!((scores[4] - 1.0).abs() < 0.01);
        assert!((scores[2] - 0.5).abs() < 0.01);
    }

    // Candidate retrieval is a union over query terms: any document sharing
    // at least one term is returned.
    #[test]
    fn test_bm25_candidate_retrieval() {
        let mut bm25 = BM25::new(1.5, 0.75);

        bm25.index_document("doc1".to_string(), "rust programming");
        bm25.index_document("doc2".to_string(), "python programming");
        bm25.index_document("doc3".to_string(), "java development");
        bm25.build_idf();

        let candidates = bm25.get_candidate_docs("rust programming");
        assert!(candidates.contains(&"doc1".to_string()));
        assert!(candidates.contains(&"doc2".to_string())); // Contains "programming"
        assert!(!candidates.contains(&"doc3".to_string()));
    }
}
|
||||
336
crates/ruvector-core/src/advanced_features/mmr.rs
Normal file
336
crates/ruvector-core/src/advanced_features/mmr.rs
Normal file
@@ -0,0 +1,336 @@
|
||||
//! Maximal Marginal Relevance (MMR) for Diversity-Aware Search
|
||||
//!
|
||||
//! Implements MMR algorithm to balance relevance and diversity in search results:
|
||||
//! MMR = λ × Similarity(query, doc) - (1-λ) × max Similarity(doc, selected_docs)
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::types::{DistanceMetric, SearchResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for MMR search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MMRConfig {
    /// Lambda parameter: balance between relevance (1.0) and diversity (0.0)
    /// - λ = 1.0: Pure relevance (standard similarity search)
    /// - λ = 0.5: Equal balance
    /// - λ = 0.0: Pure diversity
    // Validated to be in [0, 1] by MMRSearch::new.
    pub lambda: f32,
    /// Distance metric for similarity computation
    pub metric: DistanceMetric,
    /// Fetch multiplier for initial candidates (fetch k * multiplier results)
    // Over-fetching gives the reranker enough candidates to trade relevance
    // for diversity.
    pub fetch_multiplier: f32,
}

impl Default for MMRConfig {
    /// Balanced defaults: equal relevance/diversity weighting, cosine
    /// distance, and a 2x candidate over-fetch.
    fn default() -> Self {
        Self {
            lambda: 0.5,
            metric: DistanceMetric::Cosine,
            fetch_multiplier: 2.0,
        }
    }
}
|
||||
|
||||
/// MMR search implementation
///
/// Stateless reranker: all behavior is driven by the validated `MMRConfig`.
#[derive(Debug, Clone)]
pub struct MMRSearch {
    /// Configuration
    pub config: MMRConfig,
}
|
||||
|
||||
impl MMRSearch {
|
||||
/// Create a new MMR search instance
|
||||
pub fn new(config: MMRConfig) -> Result<Self> {
|
||||
if !(0.0..=1.0).contains(&config.lambda) {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"Lambda must be in [0, 1], got {}",
|
||||
config.lambda
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self { config })
|
||||
}
|
||||
|
||||
/// Perform MMR-based reranking of search results
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `candidates` - Initial search results (sorted by relevance)
|
||||
/// * `k` - Number of diverse results to return
|
||||
///
|
||||
/// # Returns
|
||||
/// Reranked results optimizing for both relevance and diversity
|
||||
pub fn rerank(
|
||||
&self,
|
||||
query: &[f32],
|
||||
candidates: Vec<SearchResult>,
|
||||
k: usize,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
if candidates.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
if k == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
if k >= candidates.len() {
|
||||
return Ok(candidates);
|
||||
}
|
||||
|
||||
let mut selected: Vec<SearchResult> = Vec::with_capacity(k);
|
||||
let mut remaining = candidates;
|
||||
|
||||
// Iteratively select documents maximizing MMR
|
||||
for _ in 0..k {
|
||||
if remaining.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Compute MMR score for each remaining candidate
|
||||
let mut best_idx = 0;
|
||||
let mut best_mmr = f32::NEG_INFINITY;
|
||||
|
||||
for (idx, candidate) in remaining.iter().enumerate() {
|
||||
let mmr_score = self.compute_mmr_score(query, candidate, &selected)?;
|
||||
|
||||
if mmr_score > best_mmr {
|
||||
best_mmr = mmr_score;
|
||||
best_idx = idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Move best candidate to selected set
|
||||
let best = remaining.remove(best_idx);
|
||||
selected.push(best);
|
||||
}
|
||||
|
||||
Ok(selected)
|
||||
}
|
||||
|
||||
/// Compute MMR score for a candidate
|
||||
fn compute_mmr_score(
|
||||
&self,
|
||||
_query: &[f32],
|
||||
candidate: &SearchResult,
|
||||
selected: &[SearchResult],
|
||||
) -> Result<f32> {
|
||||
let candidate_vec = candidate.vector.as_ref().ok_or_else(|| {
|
||||
RuvectorError::InvalidParameter("Candidate vector not available".to_string())
|
||||
})?;
|
||||
|
||||
// Relevance: similarity to query (convert distance to similarity)
|
||||
let relevance = self.distance_to_similarity(candidate.score);
|
||||
|
||||
// Diversity: max similarity to already selected documents
|
||||
let max_similarity = if selected.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
selected
|
||||
.iter()
|
||||
.filter_map(|s| s.vector.as_ref())
|
||||
.map(|selected_vec| {
|
||||
let dist = compute_distance(candidate_vec, selected_vec, self.config.metric);
|
||||
self.distance_to_similarity(dist)
|
||||
})
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap_or(0.0)
|
||||
};
|
||||
|
||||
// MMR = λ × relevance - (1-λ) × max_similarity
|
||||
let mmr = self.config.lambda * relevance - (1.0 - self.config.lambda) * max_similarity;
|
||||
|
||||
Ok(mmr)
|
||||
}
|
||||
|
||||
/// Convert distance to similarity (higher is better)
|
||||
fn distance_to_similarity(&self, distance: f32) -> f32 {
|
||||
match self.config.metric {
|
||||
DistanceMetric::Cosine => 1.0 - distance,
|
||||
DistanceMetric::Euclidean => 1.0 / (1.0 + distance),
|
||||
DistanceMetric::Manhattan => 1.0 / (1.0 + distance),
|
||||
DistanceMetric::DotProduct => -distance, // Dot product is already similarity-like
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform end-to-end MMR search
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - Query vector
|
||||
/// * `k` - Number of diverse results to return
|
||||
/// * `search_fn` - Function to perform initial similarity search
|
||||
///
|
||||
/// # Returns
|
||||
/// Diverse search results
|
||||
pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
|
||||
where
|
||||
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
|
||||
{
|
||||
// Fetch more candidates than needed
|
||||
let fetch_k = (k as f32 * self.config.fetch_multiplier).ceil() as usize;
|
||||
let candidates = search_fn(query, fetch_k)?;
|
||||
|
||||
// Rerank using MMR
|
||||
self.rerank(query, candidates, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function
|
||||
/// Dispatch to the concrete distance function for the given metric.
fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
    match metric {
        DistanceMetric::Euclidean => euclidean_distance(a, b),
        DistanceMetric::Cosine => cosine_distance(a, b),
        DistanceMetric::Manhattan => manhattan_distance(a, b),
        DistanceMetric::DotProduct => dot_product_distance(a, b),
    }
}
|
||||
|
||||
/// Euclidean (L2) distance between two equal-length vectors.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| {
        let delta = x - y;
        acc + delta * delta
    });
    sum_of_squares.sqrt()
}
|
||||
|
||||
/// Cosine distance: `1 − cos(a, b)`; in [0, 2].
///
/// Returns 1.0 (orthogonal) when either vector has zero norm, so the value
/// is always defined.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        1.0
    } else {
        1.0 - dot / denom
    }
}
|
||||
|
||||
/// Manhattan (L1) distance: sum of absolute per-component differences.
fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        total += (x - y).abs();
    }
    total
}
|
||||
|
||||
/// Negated dot product, so that smaller values mean "closer" like the other
/// distance functions.
fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience builder for a SearchResult with an attached vector
    // (rerank requires stored vectors).
    fn create_search_result(id: &str, score: f32, vector: Vec<f32>) -> SearchResult {
        SearchResult {
            id: id.to_string(),
            score,
            vector: Some(vector),
            metadata: None,
        }
    }

    // Constructor accepts lambda in [0, 1] and rejects values outside it.
    #[test]
    fn test_mmr_config_validation() {
        let config = MMRConfig {
            lambda: 0.5,
            ..Default::default()
        };
        assert!(MMRSearch::new(config).is_ok());

        let invalid_config = MMRConfig {
            lambda: 1.5,
            ..Default::default()
        };
        assert!(MMRSearch::new(invalid_config).is_err());
    }

    // Balanced lambda: the top result stays the most relevant one, but at
    // least one dissimilar document is promoted into the top-k.
    #[test]
    fn test_mmr_reranking() {
        let config = MMRConfig {
            lambda: 0.5,
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };

        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        // Create candidates with varying similarity
        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]), // Very similar to query
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // Similar to doc1 and query
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.5]), // Different from doc1
            create_search_result("doc4", 0.6, vec![0.0, 1.0, 0.0]), // Very different
        ];

        let results = mmr.rerank(&query, candidates, 3).unwrap();

        assert_eq!(results.len(), 3);
        // First result should be most relevant
        assert_eq!(results[0].id, "doc1");
        // MMR should promote diversity, so doc3 or doc4 should appear
        assert!(results.iter().any(|r| r.id == "doc3" || r.id == "doc4"));
    }

    // lambda = 1.0 degenerates to plain relevance ordering.
    #[test]
    fn test_mmr_pure_relevance() {
        let config = MMRConfig {
            lambda: 1.0, // Pure relevance
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };

        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.85, 0.1, 0.05]),
            create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.0]),
        ];

        let results = mmr.rerank(&query, candidates, 2).unwrap();

        // With lambda=1.0, should just preserve relevance order
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[1].id, "doc2");
    }

    // lambda = 0.0 maximizes diversity: two near-duplicates should not both
    // make it into the result set.
    #[test]
    fn test_mmr_pure_diversity() {
        let config = MMRConfig {
            lambda: 0.0, // Pure diversity
            metric: DistanceMetric::Euclidean,
            fetch_multiplier: 2.0,
        };

        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let candidates = vec![
            create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
            create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // Very similar to doc1
            create_search_result("doc3", 0.5, vec![0.0, 1.0, 0.0]), // Very different
        ];

        let results = mmr.rerank(&query, candidates, 2).unwrap();

        // With lambda=0.0, should maximize diversity
        assert_eq!(results.len(), 2);
        // Should not select both doc1 and doc2 (they're too similar)
        let has_both_similar =
            results.iter().any(|r| r.id == "doc1") && results.iter().any(|r| r.id == "doc2");
        assert!(!has_both_similar);
    }

    // Empty candidate set yields an empty result, not an error.
    #[test]
    fn test_mmr_empty_candidates() {
        let config = MMRConfig::default();
        let mmr = MMRSearch::new(config).unwrap();
        let query = vec![1.0, 0.0, 0.0];

        let results = mmr.rerank(&query, Vec::new(), 5).unwrap();
        assert!(results.is_empty());
    }
}
|
||||
@@ -0,0 +1,549 @@
|
||||
//! Enhanced Product Quantization with Precomputed Lookup Tables
|
||||
//!
|
||||
//! Provides 8-16x compression with 90-95% recall through:
|
||||
//! - K-means clustering for codebook training
|
||||
//! - Precomputed lookup tables for fast distance calculation
|
||||
//! - Asymmetric distance computation (ADC)
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::types::DistanceMetric;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Configuration for Enhanced Product Quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQConfig {
    /// Number of subspaces to split vector into
    // Vector dimensionality must be divisible by this (checked in
    // EnhancedPQ::new).
    pub num_subspaces: usize,
    /// Codebook size per subspace (typically 256)
    // Must fit in a u8 code, i.e. at most 256 (checked in validate()).
    pub codebook_size: usize,
    /// Number of k-means iterations for training
    pub num_iterations: usize,
    /// Distance metric for codebook training
    pub metric: DistanceMetric,
}

impl Default for PQConfig {
    /// Common defaults: 8 subspaces x 256 centroids (one byte per subspace),
    /// 20 k-means iterations, Euclidean distance.
    fn default() -> Self {
        Self {
            num_subspaces: 8,
            codebook_size: 256,
            num_iterations: 20,
            metric: DistanceMetric::Euclidean,
        }
    }
}
|
||||
|
||||
impl PQConfig {
|
||||
/// Validate the configuration
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.codebook_size > 256 {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"Codebook size {} exceeds u8 maximum of 256",
|
||||
self.codebook_size
|
||||
)));
|
||||
}
|
||||
if self.num_subspaces == 0 {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Number of subspaces must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Precomputed lookup table for fast distance computation
///
/// Built once per query; ADC then scores each encoded vector with a handful
/// of table lookups instead of full distance computations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LookupTable {
    /// Table: [subspace][centroid] -> distance to query subvector
    pub tables: Vec<Vec<f32>>,
}
|
||||
|
||||
impl LookupTable {
|
||||
/// Create a new lookup table for a query vector
|
||||
pub fn new(query: &[f32], codebooks: &[Vec<Vec<f32>>], metric: DistanceMetric) -> Self {
|
||||
let num_subspaces = codebooks.len();
|
||||
let mut tables = Vec::with_capacity(num_subspaces);
|
||||
|
||||
for (subspace_idx, codebook) in codebooks.iter().enumerate() {
|
||||
let subspace_dim = query.len() / num_subspaces;
|
||||
let start = subspace_idx * subspace_dim;
|
||||
let end = start + subspace_dim;
|
||||
let query_subvector = &query[start..end];
|
||||
|
||||
// Compute distance from query subvector to each centroid
|
||||
let distances: Vec<f32> = codebook
|
||||
.iter()
|
||||
.map(|centroid| compute_distance(query_subvector, centroid, metric))
|
||||
.collect();
|
||||
|
||||
tables.push(distances);
|
||||
}
|
||||
|
||||
Self { tables }
|
||||
}
|
||||
|
||||
/// Compute distance to a quantized vector using the lookup table
|
||||
#[inline]
|
||||
pub fn distance(&self, codes: &[u8]) -> f32 {
|
||||
codes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(subspace_idx, &code)| self.tables[subspace_idx][code as usize])
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
/// Enhanced Product Quantization with lookup tables
///
/// Lifecycle: construct with `new`, call `train` on representative vectors,
/// then `add_quantized` / `search`. Searching before `train` fails with an
/// "not trained" error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedPQ {
    /// Configuration
    pub config: PQConfig,
    /// Trained codebooks: [subspace][centroid_id][dimensions]
    // Empty until train() succeeds.
    pub codebooks: Vec<Vec<Vec<f32>>>,
    /// Dimensions of original vectors
    pub dimensions: usize,
    /// Quantized vectors storage: id -> codes
    // One u8 code per subspace, so each vector costs num_subspaces bytes.
    pub quantized_vectors: HashMap<String, Vec<u8>>,
}
|
||||
|
||||
impl EnhancedPQ {
|
||||
/// Create a new Enhanced PQ instance
|
||||
pub fn new(dimensions: usize, config: PQConfig) -> Result<Self> {
|
||||
config.validate()?;
|
||||
|
||||
if dimensions == 0 {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Dimensions must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if dimensions % config.num_subspaces != 0 {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"Dimensions {} must be divisible by num_subspaces {}",
|
||||
dimensions, config.num_subspaces
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
codebooks: Vec::new(),
|
||||
dimensions,
|
||||
quantized_vectors: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Train codebooks on a set of vectors using k-means clustering
|
||||
pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
|
||||
if training_vectors.is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Training set cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if training_vectors[0].is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Training vectors cannot have zero dimensions".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate dimensions
|
||||
for vec in training_vectors {
|
||||
if vec.len() != self.dimensions {
|
||||
return Err(RuvectorError::DimensionMismatch {
|
||||
expected: self.dimensions,
|
||||
actual: vec.len(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let subspace_dim = self.dimensions / self.config.num_subspaces;
|
||||
let mut codebooks = Vec::with_capacity(self.config.num_subspaces);
|
||||
|
||||
// Train a codebook for each subspace
|
||||
for subspace_idx in 0..self.config.num_subspaces {
|
||||
let start = subspace_idx * subspace_dim;
|
||||
let end = start + subspace_dim;
|
||||
|
||||
// Extract subspace vectors
|
||||
let subspace_vectors: Vec<Vec<f32>> = training_vectors
|
||||
.iter()
|
||||
.map(|v| v[start..end].to_vec())
|
||||
.collect();
|
||||
|
||||
// Run k-means clustering
|
||||
let codebook = kmeans_clustering(
|
||||
&subspace_vectors,
|
||||
self.config.codebook_size,
|
||||
self.config.num_iterations,
|
||||
self.config.metric,
|
||||
)?;
|
||||
|
||||
codebooks.push(codebook);
|
||||
}
|
||||
|
||||
self.codebooks = codebooks;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encode a vector into PQ codes
|
||||
pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
|
||||
if vector.len() != self.dimensions {
|
||||
return Err(RuvectorError::DimensionMismatch {
|
||||
expected: self.dimensions,
|
||||
actual: vector.len(),
|
||||
});
|
||||
}
|
||||
|
||||
if self.codebooks.is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Codebooks not trained yet".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let subspace_dim = self.dimensions / self.config.num_subspaces;
|
||||
let mut codes = Vec::with_capacity(self.config.num_subspaces);
|
||||
|
||||
for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
|
||||
let start = subspace_idx * subspace_dim;
|
||||
let end = start + subspace_dim;
|
||||
let subvector = &vector[start..end];
|
||||
|
||||
// Find nearest centroid (quantization)
|
||||
let code = find_nearest_centroid(subvector, codebook, self.config.metric)?;
|
||||
codes.push(code);
|
||||
}
|
||||
|
||||
Ok(codes)
|
||||
}
|
||||
|
||||
/// Add a quantized vector
|
||||
pub fn add_quantized(&mut self, id: String, vector: &[f32]) -> Result<()> {
|
||||
let codes = self.encode(vector)?;
|
||||
self.quantized_vectors.insert(id, codes);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a lookup table for fast distance computation
|
||||
pub fn create_lookup_table(&self, query: &[f32]) -> Result<LookupTable> {
|
||||
if query.len() != self.dimensions {
|
||||
return Err(RuvectorError::DimensionMismatch {
|
||||
expected: self.dimensions,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
if self.codebooks.is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Codebooks not trained yet".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(LookupTable::new(query, &self.codebooks, self.config.metric))
|
||||
}
|
||||
|
||||
/// Search for nearest neighbors using ADC (Asymmetric Distance Computation)
|
||||
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
|
||||
let lookup_table = self.create_lookup_table(query)?;
|
||||
|
||||
// Compute distances using lookup table
|
||||
let mut distances: Vec<(String, f32)> = self
|
||||
.quantized_vectors
|
||||
.iter()
|
||||
.map(|(id, codes)| (id.clone(), lookup_table.distance(codes)))
|
||||
.collect();
|
||||
|
||||
// Sort by distance (ascending)
|
||||
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
|
||||
// Return top-k
|
||||
Ok(distances.into_iter().take(k).collect())
|
||||
}
|
||||
|
||||
/// Reconstruct approximate vector from codes
|
||||
pub fn reconstruct(&self, codes: &[u8]) -> Result<Vec<f32>> {
|
||||
if codes.len() != self.config.num_subspaces {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"Expected {} codes, got {}",
|
||||
self.config.num_subspaces,
|
||||
codes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut result = Vec::with_capacity(self.dimensions);
|
||||
|
||||
for (subspace_idx, &code) in codes.iter().enumerate() {
|
||||
let centroid = &self.codebooks[subspace_idx][code as usize];
|
||||
result.extend_from_slice(centroid);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get compression ratio
|
||||
pub fn compression_ratio(&self) -> f32 {
|
||||
let original_bytes = self.dimensions * 4; // f32 = 4 bytes
|
||||
let compressed_bytes = self.config.num_subspaces; // 1 byte per subspace
|
||||
original_bytes as f32 / compressed_bytes as f32
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
|
||||
match metric {
|
||||
DistanceMetric::Euclidean => euclidean_squared(a, b).sqrt(),
|
||||
DistanceMetric::Cosine => {
|
||||
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
1.0
|
||||
} else {
|
||||
1.0 - (dot / (norm_a * norm_b))
|
||||
}
|
||||
}
|
||||
DistanceMetric::DotProduct => {
|
||||
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||
-dot // Negative for minimization
|
||||
}
|
||||
DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sum of squared component differences between `a` and `b` (no final sqrt).
/// Pairs are consumed up to the shorter slice's length.
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        acc += diff * diff;
    }
    acc
}
|
||||
|
||||
fn find_nearest_centroid(
|
||||
vector: &[f32],
|
||||
codebook: &[Vec<f32>],
|
||||
metric: DistanceMetric,
|
||||
) -> Result<u8> {
|
||||
codebook
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by(|(_, a), (_, b)| {
|
||||
let dist_a = compute_distance(vector, a, metric);
|
||||
let dist_b = compute_distance(vector, b, metric);
|
||||
dist_a.partial_cmp(&dist_b).unwrap()
|
||||
})
|
||||
.map(|(idx, _)| idx as u8)
|
||||
.ok_or_else(|| RuvectorError::Internal("Empty codebook".to_string()))
|
||||
}
|
||||
|
||||
fn kmeans_clustering(
|
||||
vectors: &[Vec<f32>],
|
||||
k: usize,
|
||||
iterations: usize,
|
||||
metric: DistanceMetric,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
if vectors.is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Cannot cluster empty vector set".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if vectors[0].is_empty() {
|
||||
return Err(RuvectorError::InvalidParameter(
|
||||
"Cannot cluster vectors with zero dimensions".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if k > vectors.len() {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"k ({}) cannot be larger than number of vectors ({})",
|
||||
k,
|
||||
vectors.len()
|
||||
)));
|
||||
}
|
||||
|
||||
if k > 256 {
|
||||
return Err(RuvectorError::InvalidParameter(format!(
|
||||
"k ({}) exceeds u8 maximum of 256 for codebook size",
|
||||
k
|
||||
)));
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let dim = vectors[0].len();
|
||||
|
||||
// Initialize centroids using k-means++
|
||||
let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
|
||||
centroids.push(vectors.choose(&mut rng).unwrap().clone());
|
||||
|
||||
while centroids.len() < k {
|
||||
let distances: Vec<f32> = vectors
|
||||
.iter()
|
||||
.map(|v| {
|
||||
centroids
|
||||
.iter()
|
||||
.map(|c| compute_distance(v, c, metric))
|
||||
.min_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap_or(f32::MAX)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let total: f32 = distances.iter().sum();
|
||||
let mut rand_val = rand::random::<f32>() * total;
|
||||
|
||||
for (i, &dist) in distances.iter().enumerate() {
|
||||
rand_val -= dist;
|
||||
if rand_val <= 0.0 {
|
||||
centroids.push(vectors[i].clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback if we didn't select anything
|
||||
if centroids.len() < k && centroids.len() == centroids.len() {
|
||||
centroids.push(vectors.choose(&mut rng).unwrap().clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Lloyd's algorithm
|
||||
for _ in 0..iterations {
|
||||
let mut assignments: Vec<Vec<Vec<f32>>> = vec![Vec::new(); k];
|
||||
|
||||
// Assignment step
|
||||
for vector in vectors {
|
||||
let nearest = centroids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by(|(_, a), (_, b)| {
|
||||
let dist_a = compute_distance(vector, a, metric);
|
||||
let dist_b = compute_distance(vector, b, metric);
|
||||
dist_a.partial_cmp(&dist_b).unwrap()
|
||||
})
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(0);
|
||||
|
||||
assignments[nearest].push(vector.clone());
|
||||
}
|
||||
|
||||
// Update step
|
||||
for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
|
||||
if !assigned.is_empty() {
|
||||
*centroid = vec![0.0; dim];
|
||||
|
||||
for vector in assigned {
|
||||
for (i, &v) in vector.iter().enumerate() {
|
||||
centroid[i] += v;
|
||||
}
|
||||
}
|
||||
|
||||
let count = assigned.len() as f32;
|
||||
for v in centroid.iter_mut() {
|
||||
*v /= count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(centroids)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Default config: 8 subspaces x 256 centroids (one u8 code per subspace).
    #[test]
    fn test_pq_config_default() {
        let config = PQConfig::default();
        assert_eq!(config.num_subspaces, 8);
        assert_eq!(config.codebook_size, 256);
    }

    // Construction stores dimensions and config verbatim.
    #[test]
    fn test_enhanced_pq_creation() {
        let config = PQConfig {
            num_subspaces: 4,
            codebook_size: 16,
            num_iterations: 10,
            metric: DistanceMetric::Euclidean,
        };

        let pq = EnhancedPQ::new(128, config).unwrap();
        assert_eq!(pq.dimensions, 128);
        assert_eq!(pq.config.num_subspaces, 4);
    }

    // End-to-end: train produces one codebook per subspace, and encode
    // emits one u8 code per subspace.
    #[test]
    fn test_pq_training_and_encoding() {
        let config = PQConfig {
            num_subspaces: 2,
            codebook_size: 4,
            num_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };

        let mut pq = EnhancedPQ::new(4, config).unwrap();

        // Generate training data (at least codebook_size vectors so k-means
        // has enough points to seed 4 centroids per subspace)
        let training_data = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2.0, 3.0, 4.0, 5.0],
            vec![3.0, 4.0, 5.0, 6.0],
            vec![4.0, 5.0, 6.0, 7.0],
            vec![5.0, 6.0, 7.0, 8.0],
        ];

        pq.train(&training_data).unwrap();
        assert_eq!(pq.codebooks.len(), 2);

        // Test encoding
        let vector = vec![2.5, 3.5, 4.5, 5.5];
        let codes = pq.encode(&vector).unwrap();
        assert_eq!(codes.len(), 2);
    }

    // The ADC lookup table has one sub-table per subspace, each with one
    // entry per centroid.
    #[test]
    fn test_lookup_table_creation() {
        let config = PQConfig {
            num_subspaces: 2,
            codebook_size: 4,
            num_iterations: 5,
            metric: DistanceMetric::Euclidean,
        };

        let mut pq = EnhancedPQ::new(4, config).unwrap();

        let training_data = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2.0, 3.0, 4.0, 5.0],
            vec![3.0, 4.0, 5.0, 6.0],
            vec![4.0, 5.0, 6.0, 7.0],
        ];

        pq.train(&training_data).unwrap();

        let query = vec![2.5, 3.5, 4.5, 5.5];
        let lookup_table = pq.create_lookup_table(&query).unwrap();

        assert_eq!(lookup_table.tables.len(), 2);
        assert_eq!(lookup_table.tables[0].len(), 4);
    }

    // compression_ratio = raw f32 bytes / one code byte per subspace.
    #[test]
    fn test_compression_ratio() {
        let config = PQConfig {
            num_subspaces: 8,
            codebook_size: 256,
            num_iterations: 10,
            metric: DistanceMetric::Euclidean,
        };

        let pq = EnhancedPQ::new(128, config).unwrap();
        let ratio = pq.compression_ratio();
        assert_eq!(ratio, 64.0); // 128 * 4 / 8 = 64
    }
}
|
||||
1447
crates/ruvector-core/src/agenticdb.rs
Normal file
1447
crates/ruvector-core/src/agenticdb.rs
Normal file
File diff suppressed because it is too large
Load Diff
704
crates/ruvector-core/src/arena.rs
Normal file
704
crates/ruvector-core/src/arena.rs
Normal file
@@ -0,0 +1,704 @@
|
||||
//! Arena allocator for batch operations
|
||||
//!
|
||||
//! This module provides arena-based memory allocation to reduce allocation
|
||||
//! overhead in hot paths and improve memory locality.
|
||||
//!
|
||||
//! ## Features (ADR-001)
|
||||
//!
|
||||
//! - **Cache-aligned allocations**: All allocations are aligned to cache line boundaries (64 bytes)
|
||||
//! - **Bump allocation**: O(1) allocation with minimal overhead
|
||||
//! - **Batch deallocation**: Free all allocations at once via `reset()`
|
||||
//! - **Thread-local arenas**: Per-thread allocation without synchronization
|
||||
|
||||
use std::alloc::{alloc, dealloc, Layout};
|
||||
use std::cell::RefCell;
|
||||
use std::ptr;
|
||||
|
||||
/// Cache line size (typically 64 bytes on modern CPUs)
///
/// Used as the alignment for every arena chunk and vector allocation in
/// this module.
pub const CACHE_LINE_SIZE: usize = 64;
|
||||
|
||||
/// Arena allocator for temporary allocations
///
/// Use this for batch operations where many temporary allocations
/// are needed and can be freed all at once.
pub struct Arena {
    // Backing chunks, grown on demand; RefCell allows allocating through
    // a shared `&self` reference (single-threaded interior mutability).
    chunks: RefCell<Vec<Chunk>>,
    // Minimum byte size of each newly allocated chunk.
    chunk_size: usize,
}
|
||||
|
||||
/// One backing allocation owned by the arena (bump-allocated).
struct Chunk {
    // Base pointer of the allocation (64-byte aligned; see `Arena::alloc_raw`).
    data: *mut u8,
    // Total bytes available in this chunk.
    capacity: usize,
    // Bytes handed out so far — the bump offset.
    used: usize,
}
|
||||
|
||||
impl Arena {
|
||||
/// Create a new arena with the specified chunk size
|
||||
pub fn new(chunk_size: usize) -> Self {
|
||||
Self {
|
||||
chunks: RefCell::new(Vec::new()),
|
||||
chunk_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an arena with a default 1MB chunk size
|
||||
pub fn with_default_chunk_size() -> Self {
|
||||
Self::new(1024 * 1024) // 1MB
|
||||
}
|
||||
|
||||
/// Allocate a buffer of the specified size
|
||||
pub fn alloc_vec<T>(&self, count: usize) -> ArenaVec<T> {
|
||||
let size = count * std::mem::size_of::<T>();
|
||||
let align = std::mem::align_of::<T>();
|
||||
|
||||
let ptr = self.alloc_raw(size, align);
|
||||
|
||||
ArenaVec {
|
||||
ptr: ptr as *mut T,
|
||||
len: 0,
|
||||
capacity: count,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate raw bytes with specified alignment
|
||||
fn alloc_raw(&self, size: usize, align: usize) -> *mut u8 {
|
||||
// SECURITY: Validate alignment is a power of 2 and size is reasonable
|
||||
assert!(
|
||||
align > 0 && align.is_power_of_two(),
|
||||
"Alignment must be a power of 2"
|
||||
);
|
||||
assert!(size > 0, "Cannot allocate zero bytes");
|
||||
assert!(size <= isize::MAX as usize, "Allocation size too large");
|
||||
|
||||
let mut chunks = self.chunks.borrow_mut();
|
||||
|
||||
// Try to allocate from the last chunk
|
||||
if let Some(chunk) = chunks.last_mut() {
|
||||
// Align the current position
|
||||
let current = chunk.used;
|
||||
let aligned = (current + align - 1) & !(align - 1);
|
||||
|
||||
// SECURITY: Check for overflow in alignment calculation
|
||||
if aligned < current {
|
||||
panic!("Alignment calculation overflow");
|
||||
}
|
||||
|
||||
let needed = aligned
|
||||
.checked_add(size)
|
||||
.expect("Arena allocation size overflow");
|
||||
|
||||
if needed <= chunk.capacity {
|
||||
chunk.used = needed;
|
||||
return unsafe {
|
||||
// SECURITY: Verify pointer arithmetic doesn't overflow
|
||||
let ptr = chunk.data.add(aligned);
|
||||
debug_assert!(ptr as usize >= chunk.data as usize, "Pointer underflow");
|
||||
ptr
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Need a new chunk
|
||||
let chunk_size = self.chunk_size.max(size + align);
|
||||
let layout = Layout::from_size_align(chunk_size, 64).unwrap();
|
||||
let data = unsafe { alloc(layout) };
|
||||
|
||||
let aligned = align;
|
||||
let chunk = Chunk {
|
||||
data,
|
||||
capacity: chunk_size,
|
||||
used: aligned + size,
|
||||
};
|
||||
|
||||
let ptr = unsafe { data.add(aligned) };
|
||||
chunks.push(chunk);
|
||||
|
||||
ptr
|
||||
}
|
||||
|
||||
/// Reset the arena, allowing reuse of allocated memory
|
||||
pub fn reset(&self) {
|
||||
let mut chunks = self.chunks.borrow_mut();
|
||||
for chunk in chunks.iter_mut() {
|
||||
chunk.used = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get total allocated bytes
|
||||
pub fn allocated_bytes(&self) -> usize {
|
||||
let chunks = self.chunks.borrow();
|
||||
chunks.iter().map(|c| c.capacity).sum()
|
||||
}
|
||||
|
||||
/// Get used bytes
|
||||
pub fn used_bytes(&self) -> usize {
|
||||
let chunks = self.chunks.borrow();
|
||||
chunks.iter().map(|c| c.used).sum()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Arena {
    // Frees every chunk; all ArenaVecs handed out become dangling.
    fn drop(&mut self) {
        let chunks = self.chunks.borrow();
        for chunk in chunks.iter() {
            // Layout must mirror the one used in `alloc_raw` (size = chunk
            // capacity, 64-byte alignment) for `dealloc` to be sound.
            let layout = Layout::from_size_align(chunk.capacity, 64).unwrap();
            // SAFETY: `chunk.data` was returned by `alloc` with this exact layout.
            unsafe {
                dealloc(chunk.data, layout);
            }
        }
    }
}
|
||||
|
||||
/// Vector allocated from an arena
///
/// Fixed-capacity: `push` beyond `capacity` panics. The backing memory is
/// owned by the arena, not by this type (no Drop impl here).
pub struct ArenaVec<T> {
    // Start of the arena-owned buffer.
    ptr: *mut T,
    // Number of initialized elements.
    len: usize,
    // Maximum number of elements the buffer can hold.
    capacity: usize,
    // Marks logical ownership of `T` for variance/drop-check purposes.
    _phantom: std::marker::PhantomData<T>,
}
|
||||
|
||||
impl<T> ArenaVec<T> {
|
||||
/// Push an element (panics if capacity exceeded)
|
||||
pub fn push(&mut self, value: T) {
|
||||
// SECURITY: Bounds check before pointer arithmetic
|
||||
assert!(self.len < self.capacity, "ArenaVec capacity exceeded");
|
||||
assert!(!self.ptr.is_null(), "ArenaVec pointer is null");
|
||||
|
||||
unsafe {
|
||||
// Additional safety: verify the pointer offset is within bounds
|
||||
let offset_ptr = self.ptr.add(self.len);
|
||||
debug_assert!(
|
||||
offset_ptr as usize >= self.ptr as usize,
|
||||
"Pointer arithmetic overflow"
|
||||
);
|
||||
ptr::write(offset_ptr, value);
|
||||
}
|
||||
self.len += 1;
|
||||
}
|
||||
|
||||
/// Get length
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
/// Get capacity
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.capacity
|
||||
}
|
||||
|
||||
/// Get as slice
|
||||
pub fn as_slice(&self) -> &[T] {
|
||||
// SECURITY: Bounds check before creating slice
|
||||
assert!(self.len <= self.capacity, "Length exceeds capacity");
|
||||
assert!(!self.ptr.is_null(), "Cannot create slice from null pointer");
|
||||
|
||||
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
|
||||
}
|
||||
|
||||
/// Get as mutable slice
|
||||
pub fn as_mut_slice(&mut self) -> &mut [T] {
|
||||
// SECURITY: Bounds check before creating slice
|
||||
assert!(self.len <= self.capacity, "Length exceeds capacity");
|
||||
assert!(!self.ptr.is_null(), "Cannot create slice from null pointer");
|
||||
|
||||
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for ArenaVec<T> {
|
||||
type Target = [T];
|
||||
|
||||
fn deref(&self) -> &[T] {
|
||||
self.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for ArenaVec<T> {
|
||||
fn deref_mut(&mut self) -> &mut [T] {
|
||||
self.as_mut_slice()
|
||||
}
|
||||
}
|
||||
|
||||
// Thread-local arena for per-thread allocations without synchronization.
thread_local! {
    static THREAD_ARENA: RefCell<Arena> = RefCell::new(Arena::with_default_chunk_size());
}

// An accessor function is intentionally not provided: a `RefCell::borrow()`
// guard cannot escape the `with` closure's lifetime.
// Use `THREAD_ARENA.with(|arena| { ... })` directly instead.
/*
pub fn thread_arena() -> impl std::ops::Deref<Target = Arena> {
    THREAD_ARENA.with(|arena| {
        arena.borrow()
    })
}
*/
|
||||
|
||||
/// Cache-aligned vector storage for SIMD operations (ADR-001)
///
/// Ensures vectors are aligned to cache line boundaries (64 bytes) for
/// optimal SIMD operations and minimal cache misses.
#[repr(C, align(64))]
pub struct CacheAlignedVec {
    // Heap buffer aligned to CACHE_LINE_SIZE; null when capacity == 0.
    data: *mut f32,
    // Number of initialized f32 elements.
    len: usize,
    // Allocated capacity in f32 elements (fixed; this type never grows).
    capacity: usize,
}
|
||||
|
||||
impl CacheAlignedVec {
|
||||
/// Create a new cache-aligned vector with the given capacity
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if memory allocation fails. For fallible allocation,
|
||||
/// use `try_with_capacity`.
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self::try_with_capacity(capacity).expect("Failed to allocate cache-aligned memory")
|
||||
}
|
||||
|
||||
/// Try to create a new cache-aligned vector with the given capacity
|
||||
///
|
||||
/// Returns `None` if memory allocation fails.
|
||||
pub fn try_with_capacity(capacity: usize) -> Option<Self> {
|
||||
// Handle zero capacity case
|
||||
if capacity == 0 {
|
||||
return Some(Self {
|
||||
data: std::ptr::null_mut(),
|
||||
len: 0,
|
||||
capacity: 0,
|
||||
});
|
||||
}
|
||||
|
||||
// Allocate cache-line aligned memory
|
||||
let layout =
|
||||
Layout::from_size_align(capacity * std::mem::size_of::<f32>(), CACHE_LINE_SIZE).ok()?;
|
||||
|
||||
let data = unsafe { alloc(layout) as *mut f32 };
|
||||
|
||||
// SECURITY: Check for allocation failure
|
||||
if data.is_null() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
data,
|
||||
len: 0,
|
||||
capacity,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from an existing slice, copying data to cache-aligned storage
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if memory allocation fails. For fallible allocation,
|
||||
/// use `try_from_slice`.
|
||||
pub fn from_slice(slice: &[f32]) -> Self {
|
||||
Self::try_from_slice(slice).expect("Failed to allocate cache-aligned memory for slice")
|
||||
}
|
||||
|
||||
/// Try to create from an existing slice, copying data to cache-aligned storage
|
||||
///
|
||||
/// Returns `None` if memory allocation fails.
|
||||
pub fn try_from_slice(slice: &[f32]) -> Option<Self> {
|
||||
let mut vec = Self::try_with_capacity(slice.len())?;
|
||||
if !slice.is_empty() {
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(slice.as_ptr(), vec.data, slice.len());
|
||||
}
|
||||
}
|
||||
vec.len = slice.len();
|
||||
Some(vec)
|
||||
}
|
||||
|
||||
/// Push an element
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if capacity is exceeded or if the vector has zero capacity.
|
||||
pub fn push(&mut self, value: f32) {
|
||||
assert!(
|
||||
self.len < self.capacity,
|
||||
"CacheAlignedVec capacity exceeded"
|
||||
);
|
||||
assert!(
|
||||
!self.data.is_null(),
|
||||
"Cannot push to zero-capacity CacheAlignedVec"
|
||||
);
|
||||
unsafe {
|
||||
*self.data.add(self.len) = value;
|
||||
}
|
||||
self.len += 1;
|
||||
}
|
||||
|
||||
/// Get length
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
/// Get capacity
|
||||
#[inline]
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.capacity
|
||||
}
|
||||
|
||||
/// Get as slice
|
||||
#[inline]
|
||||
pub fn as_slice(&self) -> &[f32] {
|
||||
if self.len == 0 {
|
||||
// SAFETY: Empty slice doesn't require valid pointer
|
||||
return &[];
|
||||
}
|
||||
// SAFETY: data is valid for len elements when len > 0
|
||||
unsafe { std::slice::from_raw_parts(self.data, self.len) }
|
||||
}
|
||||
|
||||
/// Get as mutable slice
|
||||
#[inline]
|
||||
pub fn as_mut_slice(&mut self) -> &mut [f32] {
|
||||
if self.len == 0 {
|
||||
// SAFETY: Empty slice doesn't require valid pointer
|
||||
return &mut [];
|
||||
}
|
||||
// SAFETY: data is valid for len elements when len > 0
|
||||
unsafe { std::slice::from_raw_parts_mut(self.data, self.len) }
|
||||
}
|
||||
|
||||
/// Get raw pointer (for SIMD operations)
|
||||
#[inline]
|
||||
pub fn as_ptr(&self) -> *const f32 {
|
||||
self.data
|
||||
}
|
||||
|
||||
/// Get mutable raw pointer (for SIMD operations)
|
||||
#[inline]
|
||||
pub fn as_mut_ptr(&mut self) -> *mut f32 {
|
||||
self.data
|
||||
}
|
||||
|
||||
/// Check if properly aligned for SIMD
|
||||
///
|
||||
/// Returns `true` for zero-capacity vectors (considered trivially aligned).
|
||||
#[inline]
|
||||
pub fn is_aligned(&self) -> bool {
|
||||
if self.data.is_null() {
|
||||
// Zero-capacity vectors are considered aligned
|
||||
return self.capacity == 0;
|
||||
}
|
||||
(self.data as usize) % CACHE_LINE_SIZE == 0
|
||||
}
|
||||
|
||||
/// Clear the vector (sets len to 0, doesn't deallocate)
|
||||
pub fn clear(&mut self) {
|
||||
self.len = 0;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CacheAlignedVec {
    fn drop(&mut self) {
        // Null data means the zero-capacity case: nothing was allocated.
        if !self.data.is_null() && self.capacity > 0 {
            // Must mirror the layout used at allocation time exactly.
            let layout = Layout::from_size_align(
                self.capacity * std::mem::size_of::<f32>(),
                CACHE_LINE_SIZE,
            )
            .expect("Invalid layout");

            // SAFETY: `data` was obtained from `alloc` with this same layout.
            unsafe {
                dealloc(self.data as *mut u8, layout);
            }
        }
    }
}
|
||||
|
||||
impl std::ops::Deref for CacheAlignedVec {
|
||||
type Target = [f32];
|
||||
|
||||
fn deref(&self) -> &[f32] {
|
||||
self.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::DerefMut for CacheAlignedVec {
|
||||
fn deref_mut(&mut self) -> &mut [f32] {
|
||||
self.as_mut_slice()
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: `CacheAlignedVec` uniquely owns its heap allocation; the raw
// pointer is never aliased or handed out with shared mutability, so moving
// the value between threads (Send) and sharing `&self` across threads
// (Sync, read-only access) are sound.
unsafe impl Send for CacheAlignedVec {}
unsafe impl Sync for CacheAlignedVec {}
|
||||
|
||||
/// Batch vector allocator for processing multiple vectors (ADR-001)
///
/// Allocates contiguous, cache-aligned storage for a batch of vectors,
/// enabling efficient SIMD processing and minimal cache misses.
pub struct BatchVectorAllocator {
    // One contiguous buffer of `dimensions * capacity` f32s, row-major
    // (vector i occupies offsets [i*dimensions, (i+1)*dimensions)).
    data: *mut f32,
    // Elements per vector.
    dimensions: usize,
    // Maximum number of vectors the buffer can hold (fixed; no growth).
    capacity: usize,
    // Number of vectors currently stored.
    count: usize,
}
|
||||
|
||||
impl BatchVectorAllocator {
|
||||
/// Create allocator for vectors of given dimensions
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if memory allocation fails. For fallible allocation,
|
||||
/// use `try_new`.
|
||||
pub fn new(dimensions: usize, initial_capacity: usize) -> Self {
|
||||
Self::try_new(dimensions, initial_capacity)
|
||||
.expect("Failed to allocate batch vector storage")
|
||||
}
|
||||
|
||||
/// Try to create allocator for vectors of given dimensions
|
||||
///
|
||||
/// Returns `None` if memory allocation fails.
|
||||
pub fn try_new(dimensions: usize, initial_capacity: usize) -> Option<Self> {
|
||||
// Handle zero capacity case
|
||||
if dimensions == 0 || initial_capacity == 0 {
|
||||
return Some(Self {
|
||||
data: std::ptr::null_mut(),
|
||||
dimensions,
|
||||
capacity: initial_capacity,
|
||||
count: 0,
|
||||
});
|
||||
}
|
||||
|
||||
let total_floats = dimensions * initial_capacity;
|
||||
|
||||
let layout =
|
||||
Layout::from_size_align(total_floats * std::mem::size_of::<f32>(), CACHE_LINE_SIZE)
|
||||
.ok()?;
|
||||
|
||||
let data = unsafe { alloc(layout) as *mut f32 };
|
||||
|
||||
// SECURITY: Check for allocation failure
|
||||
if data.is_null() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
data,
|
||||
dimensions,
|
||||
capacity: initial_capacity,
|
||||
count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a vector, returns its index
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the allocator is full, dimensions mismatch, or allocator has zero capacity.
|
||||
pub fn add(&mut self, vector: &[f32]) -> usize {
|
||||
assert_eq!(vector.len(), self.dimensions, "Vector dimension mismatch");
|
||||
assert!(self.count < self.capacity, "Batch allocator full");
|
||||
assert!(
|
||||
!self.data.is_null(),
|
||||
"Cannot add to zero-capacity BatchVectorAllocator"
|
||||
);
|
||||
|
||||
let offset = self.count * self.dimensions;
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(vector.as_ptr(), self.data.add(offset), self.dimensions);
|
||||
}
|
||||
|
||||
let index = self.count;
|
||||
self.count += 1;
|
||||
index
|
||||
}
|
||||
|
||||
/// Get a vector by index
|
||||
pub fn get(&self, index: usize) -> &[f32] {
|
||||
assert!(index < self.count, "Index out of bounds");
|
||||
let offset = index * self.dimensions;
|
||||
unsafe { std::slice::from_raw_parts(self.data.add(offset), self.dimensions) }
|
||||
}
|
||||
|
||||
/// Get mutable vector by index
|
||||
pub fn get_mut(&mut self, index: usize) -> &mut [f32] {
|
||||
assert!(index < self.count, "Index out of bounds");
|
||||
let offset = index * self.dimensions;
|
||||
unsafe { std::slice::from_raw_parts_mut(self.data.add(offset), self.dimensions) }
|
||||
}
|
||||
|
||||
/// Get raw pointer to vector at index (for SIMD)
|
||||
#[inline]
|
||||
pub fn ptr_at(&self, index: usize) -> *const f32 {
|
||||
assert!(index < self.count, "Index out of bounds");
|
||||
let offset = index * self.dimensions;
|
||||
unsafe { self.data.add(offset) }
|
||||
}
|
||||
|
||||
/// Number of vectors stored
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.count
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.count == 0
|
||||
}
|
||||
|
||||
/// Dimensions per vector
|
||||
#[inline]
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
/// Reset allocator (keeps memory)
|
||||
pub fn clear(&mut self) {
|
||||
self.count = 0;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for BatchVectorAllocator {
    fn drop(&mut self) {
        // Null data means the zero-dimension/zero-capacity case: no allocation.
        if !self.data.is_null() {
            // Must mirror the layout used at allocation time exactly.
            let layout = Layout::from_size_align(
                self.dimensions * self.capacity * std::mem::size_of::<f32>(),
                CACHE_LINE_SIZE,
            )
            .expect("Invalid layout");

            // SAFETY: `data` was obtained from `alloc` with this same layout.
            unsafe {
                dealloc(self.data as *mut u8, layout);
            }
        }
    }
}
|
||||
|
||||
// SAFETY: `BatchVectorAllocator` uniquely owns its heap allocation; the raw
// pointer is never aliased, so moving the value between threads (Send) and
// sharing `&self` across threads (Sync, read-only access) are sound.
unsafe impl Send for BatchVectorAllocator {}
unsafe impl Sync for BatchVectorAllocator {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Basic bump allocation: pushed values are readable back via indexing.
    #[test]
    fn test_arena_alloc() {
        let arena = Arena::new(1024);

        let mut vec1 = arena.alloc_vec::<f32>(10);
        vec1.push(1.0);
        vec1.push(2.0);
        vec1.push(3.0);

        assert_eq!(vec1.len(), 3);
        assert_eq!(vec1[0], 1.0);
        assert_eq!(vec1[1], 2.0);
        assert_eq!(vec1[2], 3.0);
    }

    // Several allocations of different element types coexist in the arena.
    #[test]
    fn test_arena_multiple_allocs() {
        let arena = Arena::new(1024);

        let vec1 = arena.alloc_vec::<u32>(100);
        let vec2 = arena.alloc_vec::<u64>(50);
        let vec3 = arena.alloc_vec::<f32>(200);

        assert_eq!(vec1.capacity(), 100);
        assert_eq!(vec2.capacity(), 50);
        assert_eq!(vec3.capacity(), 200);
    }

    // reset() rewinds the bump offsets without freeing chunk memory.
    #[test]
    fn test_arena_reset() {
        let arena = Arena::new(1024);

        {
            let _vec1 = arena.alloc_vec::<f32>(100);
            let _vec2 = arena.alloc_vec::<f32>(100);
        }

        let used_before = arena.used_bytes();
        arena.reset();
        let used_after = arena.used_bytes();

        assert!(used_after < used_before);
    }

    // Allocation is 64-byte aligned and push/as_slice round-trip values.
    #[test]
    fn test_cache_aligned_vec() {
        let mut vec = CacheAlignedVec::with_capacity(100);

        // Check alignment
        assert!(vec.is_aligned(), "Vector should be cache-aligned");

        // Test push
        for i in 0..50 {
            vec.push(i as f32);
        }
        assert_eq!(vec.len(), 50);

        // Test slice access
        let slice = vec.as_slice();
        assert_eq!(slice[0], 0.0);
        assert_eq!(slice[49], 49.0);
    }

    // from_slice copies into aligned storage without altering contents.
    #[test]
    fn test_cache_aligned_vec_from_slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let aligned = CacheAlignedVec::from_slice(&data);

        assert!(aligned.is_aligned());
        assert_eq!(aligned.len(), 5);
        assert_eq!(aligned.as_slice(), &data[..]);
    }

    // add() returns sequential indices and get() retrieves exact copies.
    #[test]
    fn test_batch_vector_allocator() {
        let mut allocator = BatchVectorAllocator::new(4, 10);

        let v1 = vec![1.0, 2.0, 3.0, 4.0];
        let v2 = vec![5.0, 6.0, 7.0, 8.0];

        let idx1 = allocator.add(&v1);
        let idx2 = allocator.add(&v2);

        assert_eq!(idx1, 0);
        assert_eq!(idx2, 1);
        assert_eq!(allocator.len(), 2);

        // Test retrieval
        assert_eq!(allocator.get(0), &v1[..]);
        assert_eq!(allocator.get(1), &v2[..]);
    }

    // clear() resets the count but keeps capacity usable.
    #[test]
    fn test_batch_allocator_clear() {
        let mut allocator = BatchVectorAllocator::new(3, 5);

        allocator.add(&[1.0, 2.0, 3.0]);
        allocator.add(&[4.0, 5.0, 6.0]);

        assert_eq!(allocator.len(), 2);

        allocator.clear();
        assert_eq!(allocator.len(), 0);

        // Should be able to add again
        allocator.add(&[7.0, 8.0, 9.0]);
        assert_eq!(allocator.len(), 1);
    }
}
|
||||
436
crates/ruvector-core/src/cache_optimized.rs
Normal file
436
crates/ruvector-core/src/cache_optimized.rs
Normal file
@@ -0,0 +1,436 @@
|
||||
//! Cache-optimized data structures using Structure-of-Arrays (SoA) layout
|
||||
//!
|
||||
//! This module provides cache-friendly layouts for vector storage to minimize
|
||||
//! cache misses and improve memory access patterns.
|
||||
|
||||
use std::alloc::{alloc, dealloc, Layout};
|
||||
use std::ptr;
|
||||
|
||||
/// Cache line size (typically 64 bytes on modern CPUs)
///
/// Used as the alignment for the SoA storage allocation below.
const CACHE_LINE_SIZE: usize = 64;
|
||||
|
||||
/// Structure-of-Arrays layout for vectors
///
/// Instead of storing vectors as Vec<Vec<f32>>, we store all components
/// separately to improve cache locality during SIMD operations.
#[repr(align(64))] // Align to cache line boundary
pub struct SoAVectorStorage {
    /// Number of vectors
    count: usize,
    /// Dimensions per vector
    dimensions: usize,
    /// Capacity (allocated vectors; rounded up to a power of two in `new`)
    capacity: usize,
    /// Storage for each dimension separately
    /// Layout: [dim0_vec0, dim0_vec1, ..., dim0_vecN, dim1_vec0, ...]
    /// (dimension-major: component d of vector i lives at d * capacity + i)
    data: *mut f32,
}
|
||||
|
||||
impl SoAVectorStorage {
|
||||
/// Maximum allowed dimensions to prevent overflow
|
||||
const MAX_DIMENSIONS: usize = 65536;
|
||||
/// Maximum allowed capacity to prevent overflow
|
||||
const MAX_CAPACITY: usize = 1 << 24; // ~16M vectors
|
||||
|
||||
/// Create a new SoA vector storage
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if dimensions or capacity exceed safe limits or would cause overflow.
|
||||
pub fn new(dimensions: usize, initial_capacity: usize) -> Self {
|
||||
// Security: Validate inputs to prevent integer overflow
|
||||
assert!(
|
||||
dimensions > 0 && dimensions <= Self::MAX_DIMENSIONS,
|
||||
"dimensions must be between 1 and {}",
|
||||
Self::MAX_DIMENSIONS
|
||||
);
|
||||
assert!(
|
||||
initial_capacity <= Self::MAX_CAPACITY,
|
||||
"initial_capacity exceeds maximum of {}",
|
||||
Self::MAX_CAPACITY
|
||||
);
|
||||
|
||||
let capacity = initial_capacity.next_power_of_two();
|
||||
|
||||
// Security: Use checked arithmetic to prevent overflow
|
||||
let total_elements = dimensions
|
||||
.checked_mul(capacity)
|
||||
.expect("dimensions * capacity overflow");
|
||||
let total_bytes = total_elements
|
||||
.checked_mul(std::mem::size_of::<f32>())
|
||||
.expect("total size overflow");
|
||||
|
||||
let layout =
|
||||
Layout::from_size_align(total_bytes, CACHE_LINE_SIZE).expect("invalid memory layout");
|
||||
|
||||
let data = unsafe { alloc(layout) as *mut f32 };
|
||||
|
||||
// Zero initialize
|
||||
unsafe {
|
||||
ptr::write_bytes(data, 0, total_elements);
|
||||
}
|
||||
|
||||
Self {
|
||||
count: 0,
|
||||
dimensions,
|
||||
capacity,
|
||||
data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a vector to the storage
|
||||
pub fn push(&mut self, vector: &[f32]) {
|
||||
assert_eq!(vector.len(), self.dimensions);
|
||||
|
||||
if self.count >= self.capacity {
|
||||
self.grow();
|
||||
}
|
||||
|
||||
// Store each dimension separately
|
||||
for (dim_idx, &value) in vector.iter().enumerate() {
|
||||
let offset = dim_idx * self.capacity + self.count;
|
||||
unsafe {
|
||||
*self.data.add(offset) = value;
|
||||
}
|
||||
}
|
||||
|
||||
self.count += 1;
|
||||
}
|
||||
|
||||
/// Get a vector by index (copies to output buffer)
|
||||
pub fn get(&self, index: usize, output: &mut [f32]) {
|
||||
assert!(index < self.count);
|
||||
assert_eq!(output.len(), self.dimensions);
|
||||
|
||||
for (dim_idx, out) in output.iter_mut().enumerate().take(self.dimensions) {
|
||||
let offset = dim_idx * self.capacity + index;
|
||||
*out = unsafe { *self.data.add(offset) };
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a slice of a specific dimension across all vectors
|
||||
/// This allows efficient SIMD operations on a single dimension
|
||||
pub fn dimension_slice(&self, dim_idx: usize) -> &[f32] {
|
||||
assert!(dim_idx < self.dimensions);
|
||||
let offset = dim_idx * self.capacity;
|
||||
unsafe { std::slice::from_raw_parts(self.data.add(offset), self.count) }
|
||||
}
|
||||
|
||||
/// Get a mutable slice of a specific dimension
|
||||
pub fn dimension_slice_mut(&mut self, dim_idx: usize) -> &mut [f32] {
|
||||
assert!(dim_idx < self.dimensions);
|
||||
let offset = dim_idx * self.capacity;
|
||||
unsafe { std::slice::from_raw_parts_mut(self.data.add(offset), self.count) }
|
||||
}
|
||||
|
||||
/// Number of vectors stored
|
||||
pub fn len(&self) -> usize {
|
||||
self.count
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.count == 0
|
||||
}
|
||||
|
||||
/// Dimensions per vector
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
/// Grow the storage capacity
|
||||
fn grow(&mut self) {
|
||||
let new_capacity = self.capacity * 2;
|
||||
|
||||
// Security: Use checked arithmetic to prevent overflow
|
||||
let new_total_elements = self
|
||||
.dimensions
|
||||
.checked_mul(new_capacity)
|
||||
.expect("dimensions * new_capacity overflow");
|
||||
let new_total_bytes = new_total_elements
|
||||
.checked_mul(std::mem::size_of::<f32>())
|
||||
.expect("total size overflow in grow");
|
||||
|
||||
let new_layout = Layout::from_size_align(new_total_bytes, CACHE_LINE_SIZE)
|
||||
.expect("invalid memory layout in grow");
|
||||
|
||||
let new_data = unsafe { alloc(new_layout) as *mut f32 };
|
||||
|
||||
// Copy old data dimension by dimension
|
||||
for dim_idx in 0..self.dimensions {
|
||||
let old_offset = dim_idx * self.capacity;
|
||||
let new_offset = dim_idx * new_capacity;
|
||||
|
||||
unsafe {
|
||||
ptr::copy_nonoverlapping(
|
||||
self.data.add(old_offset),
|
||||
new_data.add(new_offset),
|
||||
self.count,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Deallocate old data
|
||||
let old_layout = Layout::from_size_align(
|
||||
self.dimensions * self.capacity * std::mem::size_of::<f32>(),
|
||||
CACHE_LINE_SIZE,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
unsafe {
|
||||
dealloc(self.data as *mut u8, old_layout);
|
||||
}
|
||||
|
||||
self.data = new_data;
|
||||
self.capacity = new_capacity;
|
||||
}
|
||||
|
||||
/// Compute distance from query to all stored vectors using dimension-wise operations
|
||||
/// This takes advantage of the SoA layout for better cache utilization
|
||||
#[inline(always)]
|
||||
pub fn batch_euclidean_distances(&self, query: &[f32], output: &mut [f32]) {
|
||||
assert_eq!(query.len(), self.dimensions);
|
||||
assert_eq!(output.len(), self.count);
|
||||
|
||||
// Use SIMD-optimized version for larger batches
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
{
|
||||
if self.count >= 16 {
|
||||
unsafe { self.batch_euclidean_distances_neon(query, output) };
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
if self.count >= 32 && is_x86_feature_detected!("avx2") {
|
||||
unsafe { self.batch_euclidean_distances_avx2(query, output) };
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar fallback
|
||||
self.batch_euclidean_distances_scalar(query, output);
|
||||
}
|
||||
|
||||
/// Scalar implementation of batch euclidean distances
|
||||
#[inline(always)]
|
||||
fn batch_euclidean_distances_scalar(&self, query: &[f32], output: &mut [f32]) {
|
||||
// Initialize output with zeros
|
||||
output.fill(0.0);
|
||||
|
||||
// Process dimension by dimension for cache-friendly access
|
||||
for dim_idx in 0..self.dimensions {
|
||||
let dim_slice = self.dimension_slice(dim_idx);
|
||||
// Safety: dim_idx is bounded by self.dimensions which is validated in constructor
|
||||
let query_val = unsafe { *query.get_unchecked(dim_idx) };
|
||||
|
||||
// Compute squared differences for this dimension
|
||||
// Use unchecked access since vec_idx is bounded by self.count
|
||||
for vec_idx in 0..self.count {
|
||||
let diff = unsafe { *dim_slice.get_unchecked(vec_idx) } - query_val;
|
||||
unsafe { *output.get_unchecked_mut(vec_idx) += diff * diff };
|
||||
}
|
||||
}
|
||||
|
||||
// Take square root
|
||||
for distance in output.iter_mut() {
|
||||
*distance = distance.sqrt();
|
||||
}
|
||||
}
|
||||
|
||||
/// NEON-optimized batch euclidean distances
|
||||
///
|
||||
/// # Safety
|
||||
/// Caller must ensure query.len() == self.dimensions and output.len() == self.count
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
#[inline(always)]
|
||||
unsafe fn batch_euclidean_distances_neon(&self, query: &[f32], output: &mut [f32]) {
|
||||
use std::arch::aarch64::*;
|
||||
|
||||
let out_ptr = output.as_mut_ptr();
|
||||
let query_ptr = query.as_ptr();
|
||||
|
||||
// Initialize output with zeros
|
||||
let chunks = self.count / 4;
|
||||
|
||||
// Zero initialize using SIMD
|
||||
let zero = vdupq_n_f32(0.0);
|
||||
for i in 0..chunks {
|
||||
let idx = i * 4;
|
||||
vst1q_f32(out_ptr.add(idx), zero);
|
||||
}
|
||||
for i in (chunks * 4)..self.count {
|
||||
*output.get_unchecked_mut(i) = 0.0;
|
||||
}
|
||||
|
||||
// Process dimension by dimension for cache-friendly access
|
||||
for dim_idx in 0..self.dimensions {
|
||||
let dim_slice = self.dimension_slice(dim_idx);
|
||||
let dim_ptr = dim_slice.as_ptr();
|
||||
let query_val = vdupq_n_f32(*query_ptr.add(dim_idx));
|
||||
|
||||
// SIMD processing of 4 vectors at a time
|
||||
for i in 0..chunks {
|
||||
let idx = i * 4;
|
||||
let dim_vals = vld1q_f32(dim_ptr.add(idx));
|
||||
let out_vals = vld1q_f32(out_ptr.add(idx));
|
||||
|
||||
let diff = vsubq_f32(dim_vals, query_val);
|
||||
let result = vfmaq_f32(out_vals, diff, diff);
|
||||
|
||||
vst1q_f32(out_ptr.add(idx), result);
|
||||
}
|
||||
|
||||
// Handle remainder with bounds-check elimination
|
||||
let query_val_scalar = *query_ptr.add(dim_idx);
|
||||
for i in (chunks * 4)..self.count {
|
||||
let diff = *dim_slice.get_unchecked(i) - query_val_scalar;
|
||||
*output.get_unchecked_mut(i) += diff * diff;
|
||||
}
|
||||
}
|
||||
|
||||
// Take square root using SIMD vsqrtq_f32
|
||||
for i in 0..chunks {
|
||||
let idx = i * 4;
|
||||
let vals = vld1q_f32(out_ptr.add(idx));
|
||||
let sqrt_vals = vsqrtq_f32(vals);
|
||||
vst1q_f32(out_ptr.add(idx), sqrt_vals);
|
||||
}
|
||||
for i in (chunks * 4)..self.count {
|
||||
*output.get_unchecked_mut(i) = output.get_unchecked(i).sqrt();
|
||||
}
|
||||
}
|
||||
|
||||
/// AVX2-optimized batch euclidean distances
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2")]
|
||||
unsafe fn batch_euclidean_distances_avx2(&self, query: &[f32], output: &mut [f32]) {
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
let chunks = self.count / 8;
|
||||
|
||||
// Zero initialize using SIMD
|
||||
let zero = _mm256_setzero_ps();
|
||||
for i in 0..chunks {
|
||||
let idx = i * 8;
|
||||
_mm256_storeu_ps(output.as_mut_ptr().add(idx), zero);
|
||||
}
|
||||
for out in output.iter_mut().take(self.count).skip(chunks * 8) {
|
||||
*out = 0.0;
|
||||
}
|
||||
|
||||
// Process dimension by dimension
|
||||
for (dim_idx, &q_val) in query.iter().enumerate().take(self.dimensions) {
|
||||
let dim_slice = self.dimension_slice(dim_idx);
|
||||
let query_val = _mm256_set1_ps(q_val);
|
||||
|
||||
// SIMD processing of 8 vectors at a time
|
||||
for i in 0..chunks {
|
||||
let idx = i * 8;
|
||||
let dim_vals = _mm256_loadu_ps(dim_slice.as_ptr().add(idx));
|
||||
let out_vals = _mm256_loadu_ps(output.as_ptr().add(idx));
|
||||
|
||||
let diff = _mm256_sub_ps(dim_vals, query_val);
|
||||
let sq = _mm256_mul_ps(diff, diff);
|
||||
let result = _mm256_add_ps(out_vals, sq);
|
||||
|
||||
_mm256_storeu_ps(output.as_mut_ptr().add(idx), result);
|
||||
}
|
||||
|
||||
// Handle remainder
|
||||
for i in (chunks * 8)..self.count {
|
||||
let diff = dim_slice[i] - query[dim_idx];
|
||||
output[i] += diff * diff;
|
||||
}
|
||||
}
|
||||
|
||||
// Take square root (no SIMD sqrt in basic AVX2, use scalar)
|
||||
for distance in output.iter_mut() {
|
||||
*distance = distance.sqrt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Feature detection helper for x86_64
|
||||
// Feature detection helper for x86_64: only "avx2" is recognized; every other
// feature string conservatively reports unavailable.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn is_x86_feature_detected_helper(feature: &str) -> bool {
    if feature == "avx2" {
        is_x86_feature_detected!("avx2")
    } else {
        false
    }
}
|
||||
|
||||
impl Drop for SoAVectorStorage {
    /// Releases the single heap allocation backing all dimension rows.
    fn drop(&mut self) {
        // Reconstruct the exact layout used at (re)allocation time. The
        // unchecked multiplication cannot overflow here: these same factors
        // passed checked_mul when the buffer was allocated in new()/grow().
        let layout = Layout::from_size_align(
            self.dimensions * self.capacity * std::mem::size_of::<f32>(),
            CACHE_LINE_SIZE,
        )
        .unwrap();

        // SAFETY: `self.data` was allocated with exactly this layout, is not
        // freed anywhere else, and Drop runs at most once.
        unsafe {
            dealloc(self.data as *mut u8, layout);
        }
    }
}
|
||||
|
||||
// SAFETY: SoAVectorStorage exclusively owns its heap buffer (the raw pointer
// is never shared outside &self/&mut self borrows), so moving the value to
// another thread is sound.
unsafe impl Send for SoAVectorStorage {}
// SAFETY: all &self methods only read through the pointer; mutation requires
// &mut self, so concurrent shared access performs no unsynchronized writes.
unsafe impl Sync for SoAVectorStorage {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_soa_storage() {
        // Round-trip: vectors pushed in SoA form come back intact via get().
        let mut store = SoAVectorStorage::new(3, 4);
        store.push(&[1.0, 2.0, 3.0]);
        store.push(&[4.0, 5.0, 6.0]);
        assert_eq!(store.len(), 2);

        let mut buf = vec![0.0; 3];
        store.get(0, &mut buf);
        assert_eq!(buf, vec![1.0, 2.0, 3.0]);
        store.get(1, &mut buf);
        assert_eq!(buf, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_dimension_slice() {
        // Each dimension row holds that component of every stored vector.
        let mut store = SoAVectorStorage::new(3, 4);
        for v in [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] {
            store.push(&v);
        }

        assert_eq!(store.dimension_slice(0), &[1.0, 4.0, 7.0]);
        assert_eq!(store.dimension_slice(1), &[2.0, 5.0, 8.0]);
    }

    #[test]
    fn test_batch_distances() {
        // Unit basis vectors: distance to self is 0, to the others sqrt(2).
        let mut store = SoAVectorStorage::new(3, 4);
        for v in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] {
            store.push(&v);
        }

        let query = vec![1.0, 0.0, 0.0];
        let mut distances = vec![0.0; 3];
        store.batch_euclidean_distances(&query, &mut distances);

        assert!((distances[0] - 0.0).abs() < 0.001);
        assert!((distances[1] - 1.414).abs() < 0.01);
        assert!((distances[2] - 1.414).abs() < 0.01);
    }
}
|
||||
167
crates/ruvector-core/src/distance.rs
Normal file
167
crates/ruvector-core/src/distance.rs
Normal file
@@ -0,0 +1,167 @@
|
||||
//! SIMD-optimized distance metrics
|
||||
//! Uses SimSIMD when available (native), falls back to pure Rust for WASM
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::types::DistanceMetric;
|
||||
|
||||
/// Calculate distance between two vectors using the specified metric
|
||||
#[inline]
|
||||
pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> Result<f32> {
|
||||
if a.len() != b.len() {
|
||||
return Err(RuvectorError::DimensionMismatch {
|
||||
expected: a.len(),
|
||||
actual: b.len(),
|
||||
});
|
||||
}
|
||||
|
||||
match metric {
|
||||
DistanceMetric::Euclidean => Ok(euclidean_distance(a, b)),
|
||||
DistanceMetric::Cosine => Ok(cosine_distance(a, b)),
|
||||
DistanceMetric::DotProduct => Ok(dot_product_distance(a, b)),
|
||||
DistanceMetric::Manhattan => Ok(manhattan_distance(a, b)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Euclidean (L2) distance
///
/// Uses SimSIMD on native builds with the `simd` feature; otherwise a
/// portable scalar implementation (the only option on wasm32).
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
    {
        (simsimd::SpatialSimilarity::sqeuclidean(a, b)
            .expect("SimSIMD euclidean failed")
            .sqrt()) as f32
    }
    #[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
    {
        // Portable fallback: accumulate squared differences, then sqrt.
        let sum_sq: f32 = a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| {
                let d = x - y;
                d * d
            })
            .sum();
        sum_sq.sqrt()
    }
}
|
||||
|
||||
/// Cosine distance (1 - cosine_similarity)
///
/// Degenerate (near-zero-norm) inputs yield the maximum dissimilarity of 1.0
/// on the portable path.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
    {
        simsimd::SpatialSimilarity::cosine(a, b).expect("SimSIMD cosine failed") as f32
    }
    #[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
    {
        // Portable fallback: single pass accumulating dot product and both
        // squared norms in element order (matches the separate-pass sums).
        let mut dot = 0.0f32;
        let mut norm_a_sq = 0.0f32;
        let mut norm_b_sq = 0.0f32;
        for (x, y) in a.iter().zip(b.iter()) {
            dot += x * y;
            norm_a_sq += x * x;
            norm_b_sq += y * y;
        }
        let norm_a = norm_a_sq.sqrt();
        let norm_b = norm_b_sq.sqrt();
        if norm_a > 1e-8 && norm_b > 1e-8 {
            1.0 - (dot / (norm_a * norm_b))
        } else {
            1.0
        }
    }
}
|
||||
|
||||
/// Dot product distance (negative for maximization)
///
/// Negated so that larger dot products (more similar vectors) map to smaller
/// "distances", letting nearest-neighbor code minimize uniformly.
#[inline]
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
    {
        let dot = simsimd::SpatialSimilarity::dot(a, b).expect("SimSIMD dot product failed");
        (-dot) as f32
    }
    #[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
    {
        // Portable fallback: negate the accumulated dot product.
        -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
    }
}
|
||||
|
||||
/// Manhattan (L1) distance: sum of absolute component differences.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(0.0, |acc, (x, y)| acc + (x - y).abs())
}
|
||||
|
||||
/// Batch distance calculation optimized with Rayon (native) or sequential (WASM)
|
||||
pub fn batch_distances(
|
||||
query: &[f32],
|
||||
vectors: &[Vec<f32>],
|
||||
metric: DistanceMetric,
|
||||
) -> Result<Vec<f32>> {
|
||||
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
|
||||
{
|
||||
use rayon::prelude::*;
|
||||
vectors
|
||||
.par_iter()
|
||||
.map(|v| distance(query, v, metric))
|
||||
.collect()
|
||||
}
|
||||
#[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
|
||||
{
|
||||
// Sequential fallback for WASM
|
||||
vectors.iter().map(|v| distance(query, v, metric)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_euclidean_distance() {
        let d = euclidean_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((d - 5.196).abs() < 0.01); // sqrt(27)
    }

    #[test]
    fn test_cosine_distance() {
        // Identical vectors: distance ~0.
        let d = cosine_distance(&[1.0, 2.0, 3.0], &[1.0, 2.0, 3.0]);
        assert!(
            d < 0.01,
            "Identical vectors should have ~0 distance, got {}",
            d
        );

        // Opposite vectors: distance close to 2.
        let d = cosine_distance(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]);
        assert!(
            d > 1.5,
            "Opposite vectors should have high distance, got {}",
            d
        );
    }

    #[test]
    fn test_dot_product_distance() {
        let d = dot_product_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((d + 32.0).abs() < 0.01); // -(4 + 10 + 18) = -32
    }

    #[test]
    fn test_manhattan_distance() {
        let d = manhattan_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((d - 9.0).abs() < 0.01); // |1-4| + |2-5| + |3-6| = 9
    }

    #[test]
    fn test_dimension_mismatch() {
        // Mismatched lengths must surface as an error, not a panic.
        let result = distance(&[1.0, 2.0], &[1.0, 2.0, 3.0], DistanceMetric::Euclidean);
        assert!(result.is_err());
    }
}
|
||||
415
crates/ruvector-core/src/embeddings.rs
Normal file
415
crates/ruvector-core/src/embeddings.rs
Normal file
@@ -0,0 +1,415 @@
|
||||
//! Text Embedding Providers
|
||||
//!
|
||||
//! This module provides a pluggable embedding system for AgenticDB.
|
||||
//!
|
||||
//! ## Available Providers
|
||||
//!
|
||||
//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
|
||||
//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
|
||||
//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding, ApiEmbedding};
|
||||
//! use ruvector_core::AgenticDB;
|
||||
//!
|
||||
//! // Default: Hash-based (fast, but not semantic)
|
||||
//! let hash_provider = HashEmbedding::new(384);
|
||||
//! let embedding = hash_provider.embed("hello world")?;
|
||||
//!
|
||||
//! // API-based (requires API key)
|
||||
//! let api_provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
|
||||
//! let embedding = api_provider.embed("hello world")?;
|
||||
//! # Ok::<(), Box<dyn std::error::Error>>(())
|
||||
//! ```
|
||||
|
||||
use crate::error::Result;
|
||||
#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
|
||||
use crate::error::RuvectorError;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Trait for text embedding providers
///
/// Implementors map arbitrary text to fixed-length `f32` vectors. The
/// `Send + Sync` supertraits make providers shareable across threads behind
/// an `Arc` (see `BoxedEmbeddingProvider`).
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding vector for the given text
    ///
    /// The returned vector's length should equal `dimensions()`.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Get the dimensionality of embeddings produced by this provider
    fn dimensions(&self) -> usize;

    /// Get a description of this provider (for logging/debugging)
    fn name(&self) -> &str;
}
|
||||
|
||||
/// Hash-based embedding provider (placeholder, not semantic)
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
    // Length of every vector returned by `embed`.
    dimensions: usize,
}
|
||||
|
||||
impl HashEmbedding {
|
||||
/// Create a new hash-based embedding provider
|
||||
pub fn new(dimensions: usize) -> Self {
|
||||
Self { dimensions }
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingProvider for HashEmbedding {
    /// Folds the text's bytes into `dimensions` buckets and L2-normalizes.
    /// Deterministic for a given input; NOT semantically meaningful.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // Fix: `i % self.dimensions` below panics with a mod-by-zero when the
        // provider was constructed with `HashEmbedding::new(0)`. A zero-dim
        // provider can only ever produce the empty vector, so return it early.
        if self.dimensions == 0 {
            return Ok(Vec::new());
        }

        let mut embedding = vec![0.0; self.dimensions];
        let bytes = text.as_bytes();

        // Accumulate each byte (scaled to [0, 1]) into its round-robin bucket.
        for (i, byte) in bytes.iter().enumerate() {
            embedding[i % self.dimensions] += (*byte as f32) / 255.0;
        }

        // Normalize to unit length; the all-zero vector (empty text) is left
        // as-is to avoid dividing by zero.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        Ok(embedding)
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "HashEmbedding (placeholder)"
    }
}
|
||||
|
||||
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using the API-based providers (simpler, always up-to-date)
/// 2. Using ONNX Runtime with pre-exported models
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
    use super::*;

    /// Candle-based embedding provider stub
    ///
    /// This is a placeholder. For real implementation:
    /// 1. Add candle dependencies for your specific model type
    /// 2. Implement model loading and inference
    /// 3. Handle tokenization appropriately
    ///
    /// Example structure:
    /// ```rust,ignore
    /// pub struct CandleEmbedding {
    ///     model: YourModelType,
    ///     tokenizer: Tokenizer,
    ///     device: Device,
    ///     dimensions: usize,
    /// }
    /// ```
    pub struct CandleEmbedding {
        // NOTE(review): these fields are currently unreachable — the only
        // constructor (`from_pretrained`) always errors, so no instance can be
        // built. They document the shape a real implementation would carry.
        dimensions: usize,
        model_id: String,
    }

    impl CandleEmbedding {
        /// Create a stub candle embedding provider
        ///
        /// **This is not a real implementation!**
        /// For production, implement with actual model loading.
        /// Always returns `RuvectorError::ModelLoadError`.
        ///
        /// # Example
        /// ```rust,no_run
        /// # #[cfg(feature = "real-embeddings")]
        /// # {
        /// use ruvector_core::embeddings::candle::CandleEmbedding;
        ///
        /// // This returns an error - real implementation required
        /// let result = CandleEmbedding::from_pretrained(
        ///     "sentence-transformers/all-MiniLM-L6-v2",
        ///     false
        /// );
        /// assert!(result.is_err());
        /// # }
        /// ```
        pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
            Err(RuvectorError::ModelLoadError(format!(
                "Candle embedding support is a stub. Please:\n\
                 1. Use ApiEmbedding for production (recommended)\n\
                 2. Or implement CandleEmbedding for model: {}\n\
                 3. See docs for ONNX Runtime integration examples",
                model_id
            )))
        }
    }

    impl EmbeddingProvider for CandleEmbedding {
        // Unreachable in practice (no constructor succeeds); kept so the stub
        // satisfies the trait and documents the intended contract.
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(RuvectorError::ModelInferenceError(
                "Candle embedding not implemented - use ApiEmbedding instead".to_string(),
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn name(&self) -> &str {
            "CandleEmbedding (stub - not implemented)"
        }
    }
}
|
||||
|
||||
#[cfg(feature = "real-embeddings")]
|
||||
pub use candle::CandleEmbedding;
|
||||
|
||||
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Full URL of the embeddings endpoint.
    endpoint: String,
    // Model identifier forwarded in the request body.
    model: String,
    // Advertised embedding length; NOTE(review): not validated against the
    // length actually returned by the API.
    dimensions: usize,
    // Blocking HTTP client, reused across calls (connection pooling).
    client: reqwest::blocking::Client,
}
|
||||
|
||||
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
    /// Create a new API embedding provider
    ///
    /// # Arguments
    /// * `api_key` - API key for authentication
    /// * `endpoint` - API endpoint URL
    /// * `model` - Model identifier
    /// * `dimensions` - Expected embedding dimensions
    pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
        // A single blocking client is kept for the provider's lifetime so
        // connections can be reused.
        let client = reqwest::blocking::Client::new();
        Self {
            api_key,
            endpoint,
            model,
            dimensions,
            client,
        }
    }

    /// Create OpenAI embedding provider
    ///
    /// # Models
    /// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
    /// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
    /// - `text-embedding-ada-002` - 1536 dimensions (legacy)
    pub fn openai(api_key: &str, model: &str) -> Self {
        // Only the "large" model differs; small and ada-002 are both 1536-d.
        let dimensions = if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        };

        Self::new(
            api_key.to_string(),
            "https://api.openai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }

    /// Create Cohere embedding provider
    ///
    /// # Models
    /// - `embed-english-v3.0` - 1024 dimensions
    /// - `embed-multilingual-v3.0` - 1024 dimensions
    pub fn cohere(api_key: &str, model: &str) -> Self {
        Self::new(
            api_key.to_string(),
            "https://api.cohere.ai/v1/embed".to_string(),
            model.to_string(),
            1024,
        )
    }

    /// Create Voyage AI embedding provider
    ///
    /// # Models
    /// - `voyage-2` - 1024 dimensions
    /// - `voyage-large-2` - 1536 dimensions
    pub fn voyage(api_key: &str, model: &str) -> Self {
        let dimensions = match model.contains("large") {
            true => 1536,
            false => 1024,
        };

        Self::new(
            api_key.to_string(),
            "https://api.voyageai.com/v1/embeddings".to_string(),
            model.to_string(),
            dimensions,
        )
    }
}
|
||||
|
||||
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
    /// POSTs `{"input": text, "model": model}` to the configured endpoint and
    /// extracts the first embedding from either an OpenAI-shaped or a
    /// Cohere-shaped JSON response.
    ///
    /// Blocking call — do not invoke from an async runtime thread.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let request_body = serde_json::json!({
            "input": text,
            "model": self.model,
        });

        let response = self
            .client
            .post(&self.endpoint)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .map_err(|e| {
                RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            // `status()` must be copied out before `text()`, which consumes
            // the response body.
            let status = response.status();
            let error_text = response
                .text()
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(RuvectorError::ModelInferenceError(format!(
                "API returned error {}: {}",
                status, error_text
            )));
        }

        let response_json: serde_json::Value = response.json().map_err(|e| {
            RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
        })?;

        // Handle different API response formats. The "data" key is probed
        // first, so a response containing both keys is treated as OpenAI-style.
        let embedding = if let Some(data) = response_json.get("data") {
            // OpenAI format: {"data": [{"embedding": [...]}]}
            data.as_array()
                .and_then(|arr| arr.first())
                .and_then(|obj| obj.get("embedding"))
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
                })?
        } else if let Some(embeddings) = response_json.get("embeddings") {
            // Cohere format: {"embeddings": [[...]]}
            embeddings
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|emb| emb.as_array())
                .ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
                })?
        } else {
            return Err(RuvectorError::ModelInferenceError(
                "Unknown API response format".to_string(),
            ));
        };

        // Convert each JSON number to f32, failing fast on any non-numeric
        // entry. NOTE(review): the resulting length is not checked against
        // `self.dimensions` — confirm whether a mismatch should be an error.
        let embedding_vec: Result<Vec<f32>> = embedding
            .iter()
            .map(|v| {
                v.as_f64().map(|f| f as f32).ok_or_else(|| {
                    RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
                })
            })
            .collect();

        embedding_vec
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn name(&self) -> &str {
        "ApiEmbedding"
    }
}
|
||||
|
||||
/// Type-erased embedding provider for dynamic dispatch
///
/// `Arc` makes clones cheap and allows sharing one provider across threads.
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedding() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello world").unwrap();
        let emb2 = provider.embed("hello world").unwrap();

        assert_eq!(emb1.len(), 128);
        assert_eq!(emb1, emb2, "Same text should produce same embedding");

        // Check normalization (embed() L2-normalizes non-empty input)
        let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    #[test]
    fn test_hash_embedding_different_text() {
        let provider = HashEmbedding::new(128);

        let emb1 = provider.embed("hello").unwrap();
        let emb2 = provider.embed("world").unwrap();

        assert_ne!(
            emb1, emb2,
            "Different text should produce different embeddings"
        );
    }

    #[cfg(feature = "real-embeddings")]
    #[test]
    #[ignore] // Requires model download
    fn test_candle_embedding() {
        let provider =
            CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false)
                .unwrap();

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 384);

        // Check normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
    }

    // Fix: this test references `ApiEmbedding`, which only exists when the
    // `api-embeddings` feature is enabled; without the cfg gate the whole
    // test module fails to compile when that feature is off.
    #[cfg(feature = "api-embeddings")]
    #[test]
    #[ignore] // Requires API key
    fn test_api_embedding_openai() {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap();
        let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

        let embedding = provider.embed("hello world").unwrap();
        assert_eq!(embedding.len(), 1536);
    }
}
|
||||
113
crates/ruvector-core/src/error.rs
Normal file
113
crates/ruvector-core/src/error.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
//! Error types for Ruvector
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type alias for Ruvector operations
|
||||
pub type Result<T> = std::result::Result<T, RuvectorError>;
|
||||
|
||||
/// Main error type for Ruvector
|
||||
#[derive(Error, Debug)]
pub enum RuvectorError {
    /// Vector dimension mismatch between a query/insert and the index.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Expected dimension
        expected: usize,
        /// Actual dimension
        actual: usize,
    },

    /// Vector not found
    #[error("Vector not found: {0}")]
    VectorNotFound(String),

    /// Invalid parameter
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// Invalid input
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// Invalid dimension
    #[error("Invalid dimension: {0}")]
    InvalidDimension(String),

    /// Storage error
    #[error("Storage error: {0}")]
    StorageError(String),

    /// Model loading error
    #[error("Model loading error: {0}")]
    ModelLoadError(String),

    /// Model inference error
    #[error("Model inference error: {0}")]
    ModelInferenceError(String),

    /// Index error
    #[error("Index error: {0}")]
    IndexError(String),

    /// Serialization error
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// IO error (converted automatically via `?` thanks to `#[from]`).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// Database error; the redb error types below are funneled into this variant.
    #[error("Database error: {0}")]
    DatabaseError(String),

    /// Invalid path error
    #[error("Invalid path: {0}")]
    InvalidPath(String),

    /// Other errors
    #[error("Internal error: {0}")]
    Internal(String),
}
|
||||
|
||||
/// Convert the redb error family into [`RuvectorError::DatabaseError`].
///
/// redb exposes a distinct error type per failure domain (open, storage,
/// table, transaction, commit) plus a catch-all `redb::Error`. Every
/// conversion is structurally identical — stringify via `Display` — so a
/// small macro replaces six hand-written `impl From` blocks and keeps them
/// trivially in sync.
#[cfg(feature = "storage")]
macro_rules! impl_from_redb_error {
    ($($err:ty),+ $(,)?) => {
        $(
            impl From<$err> for RuvectorError {
                fn from(err: $err) -> Self {
                    RuvectorError::DatabaseError(err.to_string())
                }
            }
        )+
    };
}

#[cfg(feature = "storage")]
impl_from_redb_error!(
    redb::Error,
    redb::DatabaseError,
    redb::StorageError,
    redb::TableError,
    redb::TransactionError,
    redb::CommitError,
);
|
||||
36
crates/ruvector-core/src/index.rs
Normal file
36
crates/ruvector-core/src/index.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
//! Index structures for efficient vector search
|
||||
|
||||
pub mod flat;
|
||||
#[cfg(feature = "hnsw")]
|
||||
pub mod hnsw;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::types::{SearchResult, VectorId};
|
||||
|
||||
/// Trait for vector index implementations
|
||||
pub trait VectorIndex: Send + Sync {
|
||||
/// Add a vector to the index
|
||||
fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()>;
|
||||
|
||||
/// Add multiple vectors in batch
|
||||
fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
|
||||
for (id, vector) in entries {
|
||||
self.add(id, vector)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Search for k nearest neighbors
|
||||
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
|
||||
|
||||
/// Remove a vector from the index
|
||||
fn remove(&mut self, id: &VectorId) -> Result<bool>;
|
||||
|
||||
/// Get the number of vectors in the index
|
||||
fn len(&self) -> usize;
|
||||
|
||||
/// Check if the index is empty
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
108
crates/ruvector-core/src/index/flat.rs
Normal file
108
crates/ruvector-core/src/index/flat.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
//! Flat (brute-force) index for baseline and small datasets
|
||||
|
||||
use crate::distance::distance;
|
||||
use crate::error::Result;
|
||||
use crate::index::VectorIndex;
|
||||
use crate::types::{DistanceMetric, SearchResult, VectorId};
|
||||
use dashmap::DashMap;
|
||||
|
||||
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Flat index using brute-force search.
///
/// Every query scans all stored vectors (O(n·d) per search), so this serves
/// as a correctness baseline and for small datasets.
pub struct FlatIndex {
    // Id -> raw vector; DashMap permits concurrent reads during search.
    vectors: DashMap<VectorId, Vec<f32>>,
    // Metric applied to every pairwise comparison.
    metric: DistanceMetric,
    // Declared dimensionality; kept for API symmetry, not enforced on insert
    // (hence the leading underscore).
    _dimensions: usize,
}
|
||||
|
||||
impl FlatIndex {
|
||||
/// Create a new flat index
|
||||
pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
|
||||
Self {
|
||||
vectors: DashMap::new(),
|
||||
metric,
|
||||
_dimensions: dimensions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VectorIndex for FlatIndex {
|
||||
fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
|
||||
self.vectors.insert(id, vector);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
|
||||
// Distance calculation - parallel on native, sequential on WASM
|
||||
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
|
||||
let mut results: Vec<_> = self
|
||||
.vectors
|
||||
.iter()
|
||||
.par_bridge()
|
||||
.map(|entry| {
|
||||
let id = entry.key().clone();
|
||||
let vector = entry.value();
|
||||
let dist = distance(query, vector, self.metric)?;
|
||||
Ok((id, dist))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
#[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
|
||||
let mut results: Vec<_> = self
|
||||
.vectors
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
let id = entry.key().clone();
|
||||
let vector = entry.value();
|
||||
let dist = distance(query, vector, self.metric)?;
|
||||
Ok((id, dist))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
// Sort by distance and take top k
|
||||
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
results.truncate(k);
|
||||
|
||||
Ok(results
|
||||
.into_iter()
|
||||
.map(|(id, score)| SearchResult {
|
||||
id,
|
||||
score,
|
||||
vector: None,
|
||||
metadata: None,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn remove(&mut self, id: &VectorId) -> Result<bool> {
|
||||
Ok(self.vectors.remove(id).is_some())
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.vectors.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end: insert three axis-aligned unit vectors and check that an
    /// exact-match query ranks first with ~zero Euclidean distance.
    #[test]
    fn test_flat_index() -> Result<()> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);

        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0])?;
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0])?;

        let query = vec![1.0, 0.0, 0.0];
        let results = index.search(&query, 2)?;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "v1");
        assert!(results[0].score < 0.01);

        Ok(())
    }
}
|
||||
481
crates/ruvector-core/src/index/hnsw.rs
Normal file
481
crates/ruvector-core/src/index/hnsw.rs
Normal file
@@ -0,0 +1,481 @@
|
||||
//! HNSW (Hierarchical Navigable Small World) index implementation
|
||||
|
||||
use crate::distance::distance;
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::index::VectorIndex;
|
||||
use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
|
||||
use bincode::{Decode, Encode};
|
||||
use dashmap::DashMap;
|
||||
use hnsw_rs::prelude::*;
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Distance function wrapper for hnsw_rs.
///
/// Adapts this crate's `DistanceMetric` to the `Distance<f32>` trait that
/// the `hnsw_rs` graph expects.
struct DistanceFn {
    metric: DistanceMetric,
}
|
||||
|
||||
impl DistanceFn {
    /// Wrap a metric for use inside the HNSW graph.
    fn new(metric: DistanceMetric) -> Self {
        Self { metric }
    }
}
|
||||
|
||||
impl Distance<f32> for DistanceFn {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        // A failed distance computation (e.g. slice length mismatch) maps to
        // f32::MAX so the pair sorts as "infinitely far" instead of panicking
        // inside the graph traversal.
        distance(a, b, self.metric).unwrap_or(f32::MAX)
    }
}
|
||||
|
||||
/// HNSW index wrapper.
///
/// All mutable state lives behind a single `RwLock` so searches can proceed
/// concurrently while inserts take the write lock.
pub struct HnswIndex {
    inner: Arc<RwLock<HnswInner>>,
    // Build/search parameters (m, ef_construction, ef_search, max_elements).
    config: HnswConfig,
    metric: DistanceMetric,
    // Expected vector length; enforced on add/search.
    dimensions: usize,
}
|
||||
|
||||
// Interior state guarded by the RwLock in `HnswIndex`.
struct HnswInner {
    // The actual graph; addressed by dense usize indices, not VectorIds.
    hnsw: Hnsw<'static, f32, DistanceFn>,
    // Raw vectors, kept so the index can be re-serialized/rebuilt.
    vectors: DashMap<VectorId, Vec<f32>>,
    // Bidirectional mapping between external ids and internal graph indices.
    id_to_idx: DashMap<VectorId, usize>,
    idx_to_id: DashMap<usize, VectorId>,
    // Next internal index to hand out (monotonically increasing).
    next_idx: usize,
}
|
||||
|
||||
/// Serializable HNSW index state.
///
/// Only vectors and id mappings are persisted; the graph itself is rebuilt
/// from scratch on deserialization.
#[derive(Encode, Decode, Clone)]
pub struct HnswState {
    vectors: Vec<(String, Vec<f32>)>,
    id_to_idx: Vec<(String, usize)>,
    idx_to_id: Vec<(usize, String)>,
    next_idx: usize,
    config: SerializableHnswConfig,
    dimensions: usize,
    metric: SerializableDistanceMetric,
}
|
||||
|
||||
// bincode-friendly mirror of `HnswConfig` (the public config type does not
// derive Encode/Decode).
#[derive(Encode, Decode, Clone)]
struct SerializableHnswConfig {
    m: usize,
    ef_construction: usize,
    ef_search: usize,
    max_elements: usize,
}
|
||||
|
||||
// bincode-friendly mirror of `DistanceMetric`; converted both ways via the
// `From` impls below.
#[derive(Encode, Decode, Clone, Copy)]
enum SerializableDistanceMetric {
    Euclidean,
    Cosine,
    DotProduct,
    Manhattan,
}
|
||||
|
||||
impl From<DistanceMetric> for SerializableDistanceMetric {
    // Exhaustive 1:1 mapping; a new DistanceMetric variant becomes a
    // compile error here, which is intentional.
    fn from(metric: DistanceMetric) -> Self {
        match metric {
            DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
            DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
            DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
            DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
        }
    }
}
|
||||
|
||||
impl From<SerializableDistanceMetric> for DistanceMetric {
    // Inverse of the conversion above, also exhaustive.
    fn from(metric: SerializableDistanceMetric) -> Self {
        match metric {
            SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
            SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
            SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
            SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
        }
    }
}
|
||||
|
||||
impl HnswIndex {
|
||||
/// Create a new HNSW index
|
||||
pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
|
||||
let distance_fn = DistanceFn::new(metric);
|
||||
|
||||
// Create HNSW with configured parameters
|
||||
let hnsw = Hnsw::<f32, DistanceFn>::new(
|
||||
config.m,
|
||||
config.max_elements,
|
||||
dimensions,
|
||||
config.ef_construction,
|
||||
distance_fn,
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(RwLock::new(HnswInner {
|
||||
hnsw,
|
||||
vectors: DashMap::new(),
|
||||
id_to_idx: DashMap::new(),
|
||||
idx_to_id: DashMap::new(),
|
||||
next_idx: 0,
|
||||
})),
|
||||
config,
|
||||
metric,
|
||||
dimensions,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get configuration
|
||||
pub fn config(&self) -> &HnswConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Set efSearch parameter for query-time accuracy tuning
|
||||
pub fn set_ef_search(&mut self, _ef_search: usize) {
|
||||
// Note: hnsw_rs controls ef_search via the search method's knbn parameter
|
||||
// We store it in config and use it in search_with_ef
|
||||
}
|
||||
|
||||
/// Serialize the index to bytes using bincode
|
||||
pub fn serialize(&self) -> Result<Vec<u8>> {
|
||||
let inner = self.inner.read();
|
||||
|
||||
let state = HnswState {
|
||||
vectors: inner
|
||||
.vectors
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect(),
|
||||
id_to_idx: inner
|
||||
.id_to_idx
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), *entry.value()))
|
||||
.collect(),
|
||||
idx_to_id: inner
|
||||
.idx_to_id
|
||||
.iter()
|
||||
.map(|entry| (*entry.key(), entry.value().clone()))
|
||||
.collect(),
|
||||
next_idx: inner.next_idx,
|
||||
config: SerializableHnswConfig {
|
||||
m: self.config.m,
|
||||
ef_construction: self.config.ef_construction,
|
||||
ef_search: self.config.ef_search,
|
||||
max_elements: self.config.max_elements,
|
||||
},
|
||||
dimensions: self.dimensions,
|
||||
metric: self.metric.into(),
|
||||
};
|
||||
|
||||
bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
|
||||
RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
/// Deserialize the index from bytes using bincode
|
||||
pub fn deserialize(bytes: &[u8]) -> Result<Self> {
|
||||
let (state, _): (HnswState, usize) =
|
||||
bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
|
||||
RuvectorError::SerializationError(format!(
|
||||
"Failed to deserialize HNSW index: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let config = HnswConfig {
|
||||
m: state.config.m,
|
||||
ef_construction: state.config.ef_construction,
|
||||
ef_search: state.config.ef_search,
|
||||
max_elements: state.config.max_elements,
|
||||
};
|
||||
|
||||
let dimensions = state.dimensions;
|
||||
let metric: DistanceMetric = state.metric.into();
|
||||
|
||||
let distance_fn = DistanceFn::new(metric);
|
||||
let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
|
||||
config.m,
|
||||
config.max_elements,
|
||||
dimensions,
|
||||
config.ef_construction,
|
||||
distance_fn,
|
||||
);
|
||||
|
||||
// Rebuild the index by inserting all vectors
|
||||
let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
|
||||
let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
|
||||
|
||||
// Insert vectors into HNSW in order
|
||||
for entry in idx_to_id.iter() {
|
||||
let idx = *entry.key();
|
||||
let id = entry.value();
|
||||
if let Some(vector) = state.vectors.iter().find(|(vid, _)| vid == id) {
|
||||
// Use insert_data method with slice and idx
|
||||
hnsw.insert_data(&vector.1, idx);
|
||||
}
|
||||
}
|
||||
|
||||
let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(RwLock::new(HnswInner {
|
||||
hnsw,
|
||||
vectors: vectors_map,
|
||||
id_to_idx,
|
||||
idx_to_id,
|
||||
next_idx: state.next_idx,
|
||||
})),
|
||||
config,
|
||||
metric,
|
||||
dimensions,
|
||||
})
|
||||
}
|
||||
|
||||
/// Search with custom efSearch parameter
|
||||
pub fn search_with_ef(
|
||||
&self,
|
||||
query: &[f32],
|
||||
k: usize,
|
||||
ef_search: usize,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
if query.len() != self.dimensions {
|
||||
return Err(RuvectorError::DimensionMismatch {
|
||||
expected: self.dimensions,
|
||||
actual: query.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let inner = self.inner.read();
|
||||
|
||||
// Use HNSW search with custom ef parameter (knbn)
|
||||
let neighbors = inner.hnsw.search(query, k, ef_search);
|
||||
|
||||
Ok(neighbors
|
||||
.into_iter()
|
||||
.filter_map(|neighbor| {
|
||||
inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
|
||||
id: id.clone(),
|
||||
score: neighbor.distance,
|
||||
vector: None,
|
||||
metadata: None,
|
||||
})
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl VectorIndex for HnswIndex {
    /// Insert a vector: assign the next internal index, add it to the graph,
    /// then record the id<->index mappings.
    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
        if vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: vector.len(),
            });
        }

        let mut inner = self.inner.write();
        let idx = inner.next_idx;
        inner.next_idx += 1;

        // Insert into HNSW graph using insert_data
        inner.hnsw.insert_data(&vector, idx);

        // Store mappings
        inner.vectors.insert(id.clone(), vector);
        inner.id_to_idx.insert(id.clone(), idx);
        inner.idx_to_id.insert(idx, id);

        Ok(())
    }

    /// Batch insert: validate all dimensions up front so the batch is
    /// all-or-nothing before any graph mutation happens.
    fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
        // Validate all dimensions first
        for (_, vector) in &entries {
            if vector.len() != self.dimensions {
                return Err(RuvectorError::DimensionMismatch {
                    expected: self.dimensions,
                    actual: vector.len(),
                });
            }
        }

        let mut inner = self.inner.write();

        // Prepare batch data for insertion
        // First, assign indices and collect vector data
        let data_with_ids: Vec<_> = entries
            .iter()
            .enumerate()
            .map(|(i, (id, vector))| {
                let idx = inner.next_idx + i;
                (id.clone(), idx, vector.clone())
            })
            .collect();

        // Update next_idx
        inner.next_idx += entries.len();

        // Insert into HNSW sequentially
        // Note: Using sequential insertion to avoid Send requirements with RwLock guard
        // For large batches, consider restructuring to use hnsw_rs parallel_insert
        for (_id, idx, vector) in &data_with_ids {
            inner.hnsw.insert_data(vector, *idx);
        }

        // Store mappings
        for (id, idx, vector) in data_with_ids {
            inner.vectors.insert(id.clone(), vector);
            inner.id_to_idx.insert(id.clone(), idx);
            inner.idx_to_id.insert(idx, id);
        }

        Ok(())
    }

    /// Search using the ef_search value from the configuration.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        // Use configured ef_search
        self.search_with_ef(query, k, self.config.ef_search)
    }

    /// "Remove" a vector by dropping it from the id mappings.
    fn remove(&mut self, id: &VectorId) -> Result<bool> {
        let inner = self.inner.write();

        // Note: hnsw_rs doesn't support direct deletion
        // We remove from our mappings but the graph structure remains
        // This is a known limitation of HNSW
        let removed = inner.vectors.remove(id).is_some();

        if removed {
            if let Some((_, idx)) = inner.id_to_idx.remove(id) {
                inner.idx_to_id.remove(&idx);
            }
        }

        Ok(removed)
    }

    /// Number of live (non-removed) vectors.
    fn len(&self) -> usize {
        self.inner.read().vectors.len()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Uniform random vectors in [0, 1)^dimensions.
    fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        (0..count)
            .map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
            .collect()
    }

    // L2-normalize; zero vectors are returned unchanged to avoid div-by-zero.
    fn normalize_vector(v: &[f32]) -> Vec<f32> {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter().map(|x| x / norm).collect()
        } else {
            v.to_vec()
        }
    }

    #[test]
    fn test_hnsw_index_creation() -> Result<()> {
        let config = HnswConfig::default();
        let index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
        assert_eq!(index.len(), 0);
        Ok(())
    }

    /// Inserting 100 normalized vectors and querying with one of them should
    /// return that same vector as the nearest neighbour.
    #[test]
    fn test_hnsw_insert_and_search() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };

        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        // Insert a few vectors
        let vectors = generate_random_vectors(100, 128);
        for (i, vector) in vectors.iter().enumerate() {
            let normalized = normalize_vector(vector);
            index.add(format!("vec_{}", i), normalized)?;
        }

        assert_eq!(index.len(), 100);

        // Search for the first vector
        let query = normalize_vector(&vectors[0]);
        let results = index.search(&query, 10)?;

        assert!(!results.is_empty());
        assert_eq!(results[0].id, "vec_0");

        Ok(())
    }

    #[test]
    fn test_hnsw_batch_insert() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let vectors = generate_random_vectors(100, 128);
        let entries: Vec<_> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", i), normalize_vector(v)))
            .collect();

        index.add_batch(entries)?;
        assert_eq!(index.len(), 100);

        Ok(())
    }

    /// Round-trip: serialize, deserialize, and verify the restored index
    /// still answers queries.
    #[test]
    fn test_hnsw_serialization() -> Result<()> {
        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 50,
            max_elements: 1000,
        };

        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        // Insert vectors
        let vectors = generate_random_vectors(50, 128);
        for (i, vector) in vectors.iter().enumerate() {
            let normalized = normalize_vector(vector);
            index.add(format!("vec_{}", i), normalized)?;
        }

        // Serialize
        let bytes = index.serialize()?;

        // Deserialize
        let restored_index = HnswIndex::deserialize(&bytes)?;

        assert_eq!(restored_index.len(), 50);

        // Test search on restored index
        let query = normalize_vector(&vectors[0]);
        let results = restored_index.search(&query, 5)?;

        assert!(!results.is_empty());

        Ok(())
    }

    /// Adding a 64-dim vector to a 128-dim index must fail.
    #[test]
    fn test_dimension_mismatch() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;

        let result = index.add("test".to_string(), vec![1.0; 64]);
        assert!(result.is_err());

        Ok(())
    }
}
|
||||
142
crates/ruvector-core/src/lib.rs
Normal file
142
crates/ruvector-core/src/lib.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
//! # Ruvector Core
|
||||
//!
|
||||
//! High-performance Rust-native vector database with HNSW indexing and SIMD-optimized operations.
|
||||
//!
|
||||
//! ## Working Features (Tested & Benchmarked)
|
||||
//!
|
||||
//! - **HNSW Indexing**: Approximate nearest neighbor search with O(log n) complexity
|
||||
//! - **SIMD Distance**: SimSIMD-powered distance calculations (~16M ops/sec for 512-dim)
|
||||
//! - **Quantization**: Scalar (4x), Int4 (8x), Product (8-16x), and binary (32x) compression with distance support
|
||||
//! - **Persistence**: REDB-based storage with config persistence
|
||||
//! - **Search**: ~2.5K queries/sec on 10K vectors (benchmarked)
|
||||
//!
|
||||
//! ## ⚠️ Experimental/Incomplete Features - READ BEFORE USE
|
||||
//!
|
||||
//! - **AgenticDB**: ⚠️⚠️⚠️ **CRITICAL WARNING** ⚠️⚠️⚠️
|
||||
//! - Uses PLACEHOLDER hash-based embeddings, NOT real semantic embeddings
|
||||
//! - "dog" and "cat" will NOT be similar (different characters)
|
||||
//! - "dog" and "god" WILL be similar (same characters) - **This is wrong!**
|
||||
//! - **MUST integrate real embedding model for production** (ONNX, Candle, or API)
|
||||
//! - See [`agenticdb`] module docs and `/examples/onnx-embeddings` for integration
|
||||
//! - **Advanced Features**: Conformal prediction, hybrid search - functional but less tested
|
||||
//!
|
||||
//! ## What This Is NOT
|
||||
//!
|
||||
//! - This is NOT a complete RAG solution - you need external embedding models
|
||||
//! - Examples use mock embeddings for demonstration only
|
||||
|
||||
#![allow(missing_docs)]
|
||||
#![warn(clippy::all)]
|
||||
#![allow(clippy::incompatible_msrv)]
|
||||
|
||||
pub mod advanced_features;
|
||||
|
||||
// AgenticDB requires storage feature
|
||||
#[cfg(feature = "storage")]
|
||||
pub mod agenticdb;
|
||||
|
||||
pub mod distance;
|
||||
pub mod embeddings;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod quantization;
|
||||
|
||||
// Storage backends - conditional compilation based on features
|
||||
#[cfg(feature = "storage")]
|
||||
pub mod storage;
|
||||
|
||||
#[cfg(not(feature = "storage"))]
|
||||
pub mod storage_memory;
|
||||
|
||||
#[cfg(not(feature = "storage"))]
|
||||
pub use storage_memory as storage;
|
||||
|
||||
pub mod types;
|
||||
pub mod vector_db;
|
||||
|
||||
// Performance optimization modules
|
||||
pub mod arena;
|
||||
pub mod cache_optimized;
|
||||
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
|
||||
pub mod lockfree;
|
||||
pub mod simd_intrinsics;
|
||||
|
||||
/// Unified Memory Pool and Paging System (ADR-006)
|
||||
///
|
||||
/// High-performance paged memory management for LLM inference:
|
||||
/// - 2MB page-granular allocation with best-fit strategy
|
||||
/// - Reference-counted pinning with RAII guards
|
||||
/// - LRU eviction with hysteresis for thrash prevention
|
||||
/// - Multi-tenant isolation with Hot/Warm/Cold residency tiers
|
||||
pub mod memory;
|
||||
|
||||
/// Advanced techniques: hypergraphs, learned indexes, neural hashing, TDA (Phase 6)
|
||||
pub mod advanced;
|
||||
|
||||
// Re-exports
|
||||
pub use advanced_features::{
|
||||
ConformalConfig, ConformalPredictor, EnhancedPQ, FilterExpression, FilterStrategy,
|
||||
FilteredSearch, HybridConfig, HybridSearch, MMRConfig, MMRSearch, PQConfig, PredictionSet,
|
||||
BM25,
|
||||
};
|
||||
|
||||
#[cfg(feature = "storage")]
|
||||
pub use agenticdb::{
|
||||
AgenticDB, PolicyAction, PolicyEntry, PolicyMemoryStore, SessionStateIndex, SessionTurn,
|
||||
WitnessEntry, WitnessLog,
|
||||
};
|
||||
|
||||
#[cfg(feature = "api-embeddings")]
|
||||
pub use embeddings::ApiEmbedding;
|
||||
pub use embeddings::{BoxedEmbeddingProvider, EmbeddingProvider, HashEmbedding};
|
||||
|
||||
#[cfg(feature = "real-embeddings")]
|
||||
pub use embeddings::CandleEmbedding;
|
||||
|
||||
// Compile-time warning about AgenticDB limitations
|
||||
#[cfg(feature = "storage")]
|
||||
#[allow(deprecated, clippy::let_unit_value)]
|
||||
const _: () = {
|
||||
#[deprecated(
|
||||
since = "0.1.0",
|
||||
note = "AgenticDB uses placeholder hash-based embeddings. For semantic search, integrate a real embedding model (ONNX, Candle, or API). See /examples/onnx-embeddings for production setup."
|
||||
)]
|
||||
const AGENTICDB_EMBEDDING_WARNING: () = ();
|
||||
let _ = AGENTICDB_EMBEDDING_WARNING;
|
||||
};
|
||||
|
||||
pub use error::{Result, RuvectorError};
|
||||
pub use types::{DistanceMetric, SearchQuery, SearchResult, VectorEntry, VectorId};
|
||||
pub use vector_db::VectorDB;
|
||||
|
||||
// Quantization types (ADR-001)
|
||||
pub use quantization::{
|
||||
BinaryQuantized, Int4Quantized, ProductQuantized, QuantizedVector, ScalarQuantized,
|
||||
};
|
||||
|
||||
// Memory management types (ADR-001)
|
||||
pub use arena::{Arena, ArenaVec, BatchVectorAllocator, CacheAlignedVec, CACHE_LINE_SIZE};
|
||||
|
||||
// Lock-free structures (requires parallel feature)
|
||||
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
|
||||
pub use lockfree::{
|
||||
AtomicVectorPool, BatchItem, BatchResult, LockFreeBatchProcessor, LockFreeCounter,
|
||||
LockFreeStats, LockFreeWorkQueue, ObjectPool, PooledObject, PooledVector, StatsSnapshot,
|
||||
VectorPoolStats,
|
||||
};
|
||||
|
||||
// Cache-optimized storage
|
||||
pub use cache_optimized::SoAVectorStorage;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity check on the crate version injected at compile time.
    #[test]
    fn test_version() {
        // Verify version matches workspace - use dynamic check instead of hardcoded value
        let version = env!("CARGO_PKG_VERSION");
        assert!(!version.is_empty(), "Version should not be empty");
        assert!(version.starts_with("0.1."), "Version should be 0.1.x");
    }
}
|
||||
590
crates/ruvector-core/src/lockfree.rs
Normal file
590
crates/ruvector-core/src/lockfree.rs
Normal file
@@ -0,0 +1,590 @@
|
||||
//! Lock-free data structures for high-concurrency operations
|
||||
//!
|
||||
//! This module provides lock-free implementations of common data structures
|
||||
//! to minimize contention and improve scalability.
|
||||
//!
|
||||
//! Note: This module requires the `parallel` feature and is not available on WASM.
|
||||
|
||||
#![cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
|
||||
|
||||
use crossbeam::queue::{ArrayQueue, SegQueue};
|
||||
use crossbeam::utils::CachePadded;
|
||||
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Lock-free counter with cache padding to prevent false sharing.
///
/// `CachePadded` already pads the field to a cache line; the additional
/// `#[repr(align(64))]` aligns the struct itself as well.
#[repr(align(64))]
pub struct LockFreeCounter {
    value: CachePadded<AtomicU64>,
}
|
||||
|
||||
impl LockFreeCounter {
    /// Create a counter starting at `initial`.
    pub fn new(initial: u64) -> Self {
        Self {
            value: CachePadded::new(AtomicU64::new(initial)),
        }
    }

    /// Add 1 and return the *previous* value (fetch_add semantics).
    #[inline]
    pub fn increment(&self) -> u64 {
        self.value.fetch_add(1, Ordering::Relaxed)
    }

    /// Current value. Relaxed load: may be stale under concurrent writers.
    #[inline]
    pub fn get(&self) -> u64 {
        self.value.load(Ordering::Relaxed)
    }

    /// Add `delta` and return the *previous* value.
    #[inline]
    pub fn add(&self, delta: u64) -> u64 {
        self.value.fetch_add(delta, Ordering::Relaxed)
    }
}
|
||||
|
||||
/// Lock-free statistics collector.
///
/// Each counter sits on its own cache line (`CachePadded`) so writers on
/// different cores do not false-share.
pub struct LockFreeStats {
    queries: CachePadded<AtomicU64>,
    inserts: CachePadded<AtomicU64>,
    deletes: CachePadded<AtomicU64>,
    // Running sum of query latencies in nanoseconds; divided by `queries`
    // to produce the average in `snapshot`.
    total_latency_ns: CachePadded<AtomicU64>,
}
|
||||
|
||||
impl LockFreeStats {
    /// All counters start at zero.
    pub fn new() -> Self {
        Self {
            queries: CachePadded::new(AtomicU64::new(0)),
            inserts: CachePadded::new(AtomicU64::new(0)),
            deletes: CachePadded::new(AtomicU64::new(0)),
            total_latency_ns: CachePadded::new(AtomicU64::new(0)),
        }
    }

    /// Record one query and its latency in nanoseconds.
    #[inline]
    pub fn record_query(&self, latency_ns: u64) {
        self.queries.fetch_add(1, Ordering::Relaxed);
        self.total_latency_ns
            .fetch_add(latency_ns, Ordering::Relaxed);
    }

    /// Record one insert.
    #[inline]
    pub fn record_insert(&self) {
        self.inserts.fetch_add(1, Ordering::Relaxed);
    }

    /// Record one delete.
    #[inline]
    pub fn record_delete(&self) {
        self.deletes.fetch_add(1, Ordering::Relaxed);
    }

    /// Copy out the current counters.
    ///
    /// The individual loads are not performed atomically as a group, so under
    /// concurrent updates the snapshot may mix values from slightly different
    /// moments.
    pub fn snapshot(&self) -> StatsSnapshot {
        let queries = self.queries.load(Ordering::Relaxed);
        let total_latency = self.total_latency_ns.load(Ordering::Relaxed);

        StatsSnapshot {
            queries,
            inserts: self.inserts.load(Ordering::Relaxed),
            deletes: self.deletes.load(Ordering::Relaxed),
            // Guard against division by zero before any query is recorded.
            avg_latency_ns: if queries > 0 {
                total_latency / queries
            } else {
                0
            },
        }
    }
}
|
||||
|
||||
impl Default for LockFreeStats {
    // Same as `new()`: all counters zeroed.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Point-in-time copy of the counters in `LockFreeStats`.
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
    pub queries: u64,
    pub inserts: u64,
    pub deletes: u64,
    /// Mean query latency in nanoseconds (0 when no queries were recorded).
    pub avg_latency_ns: u64,
}
|
||||
|
||||
/// Lock-free object pool for reducing allocations
pub struct ObjectPool<T> {
    /// Idle objects available for reuse.
    queue: Arc<SegQueue<T>>,
    /// Constructor used when the pool is empty and still under capacity.
    factory: Arc<dyn Fn() -> T + Send + Sync>,
    /// Upper bound on the number of objects ever constructed.
    capacity: usize,
    /// Count of objects constructed so far (never decremented; returned
    /// objects stay alive inside `queue`).
    allocated: AtomicUsize,
}
|
||||
|
||||
impl<T> ObjectPool<T> {
    /// Create a pool that will construct at most `capacity` objects via `factory`.
    ///
    /// Objects are created lazily on `acquire`; nothing is pre-allocated here.
    pub fn new<F>(capacity: usize, factory: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        Self {
            queue: Arc::new(SegQueue::new()),
            factory: Arc::new(factory),
            capacity,
            allocated: AtomicUsize::new(0),
        }
    }

    /// Get an object from the pool or create a new one
    ///
    /// If the pool is empty and fewer than `capacity` objects exist, a new
    /// one is constructed; otherwise this spin-waits until another thread
    /// returns an object.
    ///
    /// NOTE(review): with `capacity == 0` the very first `acquire` spins
    /// forever, since no object can ever be constructed or returned —
    /// confirm callers never build a zero-capacity pool.
    pub fn acquire(&self) -> PooledObject<T> {
        let object = self.queue.pop().unwrap_or_else(|| {
            // Optimistically claim an allocation slot. fetch_add returns the
            // previous count, so at most `capacity` threads can win a slot.
            let current = self.allocated.fetch_add(1, Ordering::Relaxed);
            if current < self.capacity {
                (self.factory)()
            } else {
                // Over capacity: release the claimed slot and busy-wait.
                self.allocated.fetch_sub(1, Ordering::Relaxed);
                // Wait for an object to be returned
                loop {
                    if let Some(obj) = self.queue.pop() {
                        break obj;
                    }
                    std::hint::spin_loop();
                }
            }
        });

        PooledObject {
            object: Some(object),
            pool: Arc::clone(&self.queue),
        }
    }
}
|
||||
|
||||
/// RAII wrapper for pooled objects
///
/// On drop, the wrapped object is pushed back onto the owning pool's queue.
pub struct PooledObject<T> {
    /// Always `Some` until `drop` takes it; `Option` lets `Drop` move it out.
    object: Option<T>,
    /// Shared handle to the pool queue the object returns to.
    pool: Arc<SegQueue<T>>,
}
|
||||
|
||||
impl<T> PooledObject<T> {
    /// Borrow the pooled object immutably.
    ///
    /// The unwrap cannot fire through safe use: `object` is only `None`
    /// after `Drop` has taken it.
    pub fn get(&self) -> &T {
        self.object.as_ref().unwrap()
    }

    /// Borrow the pooled object mutably.
    pub fn get_mut(&mut self) -> &mut T {
        self.object.as_mut().unwrap()
    }
}
|
||||
|
||||
impl<T> Drop for PooledObject<T> {
    /// Return the object to the pool's queue instead of destroying it.
    fn drop(&mut self) {
        if let Some(object) = self.object.take() {
            self.pool.push(object);
        }
    }
}
|
||||
|
||||
impl<T> std::ops::Deref for PooledObject<T> {
    type Target = T;

    /// Transparent read access to the pooled object.
    fn deref(&self) -> &Self::Target {
        self.object.as_ref().unwrap()
    }
}
|
||||
|
||||
impl<T> std::ops::DerefMut for PooledObject<T> {
    /// Transparent mutable access to the pooled object.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.object.as_mut().unwrap()
    }
}
|
||||
|
||||
/// Lock-free ring buffer for work distribution
///
/// Thin wrapper over a bounded `ArrayQueue` exposing non-blocking push/pop.
pub struct LockFreeWorkQueue<T> {
    /// Fixed-capacity MPMC queue backing the work distribution.
    queue: ArrayQueue<T>,
}
|
||||
|
||||
impl<T> LockFreeWorkQueue<T> {
    /// Create a bounded queue holding at most `capacity` items.
    pub fn new(capacity: usize) -> Self {
        Self {
            queue: ArrayQueue::new(capacity),
        }
    }

    /// Push an item; returns `Err(item)` (handing it back) when full.
    #[inline]
    pub fn try_push(&self, item: T) -> Result<(), T> {
        self.queue.push(item)
    }

    /// Pop an item; `None` when the queue is empty.
    #[inline]
    pub fn try_pop(&self) -> Option<T> {
        self.queue.pop()
    }

    /// Current number of items (approximate under concurrent access).
    #[inline]
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Whether the queue is empty (approximate under concurrent access).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}
|
||||
|
||||
/// Atomic vector pool for lock-free vector operations (ADR-001)
///
/// Provides a pool of pre-allocated vectors that can be acquired and released
/// without locking, ideal for high-throughput batch operations.
pub struct AtomicVectorPool {
    /// Pool of available vectors
    pool: SegQueue<Vec<f32>>,
    /// Dimensions per vector
    dimensions: usize,
    /// Maximum pool size
    max_size: usize,
    /// Current pool size (approximate count of vectors cached in `pool`)
    size: AtomicUsize,
    /// Total allocations (every `acquire` call, hit or miss)
    total_allocations: AtomicU64,
    /// Pool hits (reused vectors)
    pool_hits: AtomicU64,
}
|
||||
|
||||
impl AtomicVectorPool {
|
||||
/// Create a new atomic vector pool
|
||||
pub fn new(dimensions: usize, initial_size: usize, max_size: usize) -> Self {
|
||||
let pool = SegQueue::new();
|
||||
|
||||
// Pre-allocate vectors
|
||||
for _ in 0..initial_size {
|
||||
pool.push(vec![0.0; dimensions]);
|
||||
}
|
||||
|
||||
Self {
|
||||
pool,
|
||||
dimensions,
|
||||
max_size,
|
||||
size: AtomicUsize::new(initial_size),
|
||||
total_allocations: AtomicU64::new(0),
|
||||
pool_hits: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire a vector from the pool (or allocate new one)
|
||||
pub fn acquire(&self) -> PooledVector<'_> {
|
||||
self.total_allocations.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let vec = if let Some(mut v) = self.pool.pop() {
|
||||
self.pool_hits.fetch_add(1, Ordering::Relaxed);
|
||||
// Clear the vector for reuse
|
||||
v.fill(0.0);
|
||||
v
|
||||
} else {
|
||||
// Allocate new vector
|
||||
vec![0.0; self.dimensions]
|
||||
};
|
||||
|
||||
PooledVector {
|
||||
vec: Some(vec),
|
||||
pool: self,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a vector to the pool
|
||||
fn return_to_pool(&self, vec: Vec<f32>) {
|
||||
let current_size = self.size.load(Ordering::Relaxed);
|
||||
if current_size < self.max_size {
|
||||
self.pool.push(vec);
|
||||
self.size.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
// If pool is full, vector is dropped
|
||||
}
|
||||
|
||||
/// Get pool statistics
|
||||
pub fn stats(&self) -> VectorPoolStats {
|
||||
let total = self.total_allocations.load(Ordering::Relaxed);
|
||||
let hits = self.pool_hits.load(Ordering::Relaxed);
|
||||
let hit_rate = if total > 0 {
|
||||
hits as f64 / total as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
VectorPoolStats {
|
||||
total_allocations: total,
|
||||
pool_hits: hits,
|
||||
hit_rate,
|
||||
current_size: self.size.load(Ordering::Relaxed),
|
||||
max_size: self.max_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get dimensions
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics for the vector pool
#[derive(Debug, Clone)]
pub struct VectorPoolStats {
    /// Total `acquire` calls so far.
    pub total_allocations: u64,
    /// Acquires that reused a pooled vector.
    pub pool_hits: u64,
    /// `pool_hits / total_allocations` (0.0 when no acquires yet).
    pub hit_rate: f64,
    /// Vectors currently cached in the pool (approximate).
    pub current_size: usize,
    /// Maximum number of cached vectors.
    pub max_size: usize,
}
|
||||
|
||||
/// RAII wrapper for pooled vectors
///
/// Derefs to `[f32]`; the vector is handed back to its pool on drop
/// (unless `detach` takes ownership).
pub struct PooledVector<'a> {
    /// Always `Some` until drop/`detach` moves the vector out.
    vec: Option<Vec<f32>>,
    /// Owning pool the vector is returned to.
    pool: &'a AtomicVectorPool,
}
|
||||
|
||||
impl<'a> PooledVector<'a> {
|
||||
/// Get as slice
|
||||
pub fn as_slice(&self) -> &[f32] {
|
||||
self.vec.as_ref().unwrap()
|
||||
}
|
||||
|
||||
/// Get as mutable slice
|
||||
pub fn as_mut_slice(&mut self) -> &mut [f32] {
|
||||
self.vec.as_mut().unwrap()
|
||||
}
|
||||
|
||||
/// Copy from source slice
|
||||
pub fn copy_from(&mut self, src: &[f32]) {
|
||||
let vec = self.vec.as_mut().unwrap();
|
||||
assert_eq!(vec.len(), src.len(), "Dimension mismatch");
|
||||
vec.copy_from_slice(src);
|
||||
}
|
||||
|
||||
/// Detach the vector from the pool (it won't be returned)
|
||||
pub fn detach(mut self) -> Vec<f32> {
|
||||
self.vec.take().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for PooledVector<'a> {
    /// Hand the vector back to the owning pool (unless `detach` took it).
    fn drop(&mut self) {
        if let Some(vec) = self.vec.take() {
            self.pool.return_to_pool(vec);
        }
    }
}
|
||||
|
||||
impl<'a> std::ops::Deref for PooledVector<'a> {
    type Target = [f32];

    /// Transparent read access as a `[f32]` slice.
    fn deref(&self) -> &[f32] {
        self.as_slice()
    }
}
|
||||
|
||||
impl<'a> std::ops::DerefMut for PooledVector<'a> {
    /// Transparent mutable access as a `[f32]` slice.
    fn deref_mut(&mut self) -> &mut [f32] {
        self.as_mut_slice()
    }
}
|
||||
|
||||
/// Lock-free batch processor for parallel vector operations (ADR-001)
///
/// Distributes work across multiple workers without contention.
pub struct LockFreeBatchProcessor {
    /// Work queue for pending items (bounded).
    work_queue: ArrayQueue<BatchItem>,
    /// Results queue (unbounded).
    results_queue: SegQueue<BatchResult>,
    /// Count of items submitted so far.
    pending: AtomicUsize,
    /// Count of results submitted so far.
    completed: AtomicUsize,
}
|
||||
|
||||
/// Item in the batch work queue
#[derive(Debug)]
pub struct BatchItem {
    /// Caller-assigned identifier, echoed back in the matching `BatchResult`.
    pub id: u64,
    /// Input vector payload.
    pub data: Vec<f32>,
}
|
||||
|
||||
/// Result from batch processing
///
/// `id` matches the `BatchItem` that produced this result.
// Consistency: sibling BatchItem derives Debug; results should be printable too.
#[derive(Debug)]
pub struct BatchResult {
    /// Identifier copied from the originating `BatchItem`.
    pub id: u64,
    /// Output vector payload.
    pub result: Vec<f32>,
}
|
||||
|
||||
impl LockFreeBatchProcessor {
|
||||
/// Create a new batch processor with given capacity
|
||||
pub fn new(capacity: usize) -> Self {
|
||||
Self {
|
||||
work_queue: ArrayQueue::new(capacity),
|
||||
results_queue: SegQueue::new(),
|
||||
pending: AtomicUsize::new(0),
|
||||
completed: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Submit a batch item for processing
|
||||
pub fn submit(&self, item: BatchItem) -> Result<(), BatchItem> {
|
||||
self.pending.fetch_add(1, Ordering::Relaxed);
|
||||
self.work_queue.push(item)
|
||||
}
|
||||
|
||||
/// Try to get a work item (for workers)
|
||||
pub fn try_get_work(&self) -> Option<BatchItem> {
|
||||
self.work_queue.pop()
|
||||
}
|
||||
|
||||
/// Submit a result (from workers)
|
||||
pub fn submit_result(&self, result: BatchResult) {
|
||||
self.completed.fetch_add(1, Ordering::Relaxed);
|
||||
self.results_queue.push(result);
|
||||
}
|
||||
|
||||
/// Collect all available results
|
||||
pub fn collect_results(&self) -> Vec<BatchResult> {
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = self.results_queue.pop() {
|
||||
results.push(result);
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
/// Get pending count
|
||||
pub fn pending(&self) -> usize {
|
||||
self.pending.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get completed count
|
||||
pub fn completed(&self) -> usize {
|
||||
self.completed.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Check if all work is done
|
||||
pub fn is_done(&self) -> bool {
|
||||
self.pending() == self.completed()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // 10 threads x 1000 increments must all be observed (no lost updates).
    #[test]
    fn test_lockfree_counter() {
        let counter = Arc::new(LockFreeCounter::new(0));
        let mut handles = vec![];

        for _ in 0..10 {
            let counter_clone = Arc::clone(&counter);
            handles.push(thread::spawn(move || {
                for _ in 0..1000 {
                    counter_clone.increment();
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(counter.get(), 10000);
    }

    // A released object should be handed out again with its capacity intact.
    #[test]
    fn test_object_pool() {
        let pool = ObjectPool::new(4, || Vec::<u8>::with_capacity(1024));

        let mut obj1 = pool.acquire();
        obj1.push(1);
        assert_eq!(obj1.len(), 1);

        drop(obj1);

        let obj2 = pool.acquire();
        // Object should be reused (but cleared state is not guaranteed)
        assert!(obj2.capacity() >= 1024);
    }

    #[test]
    fn test_stats_collector() {
        let stats = LockFreeStats::new();

        stats.record_query(1000);
        stats.record_query(2000);
        stats.record_insert();

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.queries, 2);
        assert_eq!(snapshot.inserts, 1);
        // (1000 + 2000) / 2 queries
        assert_eq!(snapshot.avg_latency_ns, 1500);
    }

    #[test]
    fn test_atomic_vector_pool() {
        let pool = AtomicVectorPool::new(4, 2, 10);

        // Acquire first vector
        let mut v1 = pool.acquire();
        v1.copy_from(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(v1.as_slice(), &[1.0, 2.0, 3.0, 4.0]);

        // Acquire second vector
        let mut v2 = pool.acquire();
        v2.copy_from(&[5.0, 6.0, 7.0, 8.0]);

        // Stats should show allocations
        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
    }

    #[test]
    fn test_vector_pool_reuse() {
        let pool = AtomicVectorPool::new(3, 1, 5);

        // Acquire and release
        {
            let mut v = pool.acquire();
            v.copy_from(&[1.0, 2.0, 3.0]);
        } // v is returned to pool here

        // Acquire again - should be a pool hit
        let _v2 = pool.acquire();

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 2);
        assert!(stats.pool_hits >= 1, "Should have at least one pool hit");
    }

    // Full submit -> worker -> result round trip on a single thread.
    #[test]
    fn test_batch_processor() {
        let processor = LockFreeBatchProcessor::new(10);

        // Submit work items
        processor
            .submit(BatchItem {
                id: 1,
                data: vec![1.0, 2.0],
            })
            .unwrap();
        processor
            .submit(BatchItem {
                id: 2,
                data: vec![3.0, 4.0],
            })
            .unwrap();

        assert_eq!(processor.pending(), 2);

        // Process work
        while let Some(item) = processor.try_get_work() {
            let result = BatchResult {
                id: item.id,
                result: item.data.iter().map(|x| x * 2.0).collect(),
            };
            processor.submit_result(result);
        }

        assert!(processor.is_done());
        assert_eq!(processor.completed(), 2);

        // Collect results
        let results = processor.collect_results();
        assert_eq!(results.len(), 2);
    }
}
|
||||
38
crates/ruvector-core/src/memory.rs
Normal file
38
crates/ruvector-core/src/memory.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
//! Memory management utilities for ruvector-core
|
||||
//!
|
||||
//! This module provides memory-efficient data structures and utilities
|
||||
//! for vector storage operations.
|
||||
|
||||
/// Memory pool for vector allocations.
///
/// NOTE(review): `allocated` is only ever read by the visible code — no
/// method here updates it, so the accounting looks like a placeholder;
/// confirm whether allocation tracking is implemented elsewhere.
#[derive(Debug, Default)]
pub struct MemoryPool {
    /// Total allocated bytes.
    allocated: usize,
    /// Maximum allocation limit.
    limit: Option<usize>,
}
|
||||
|
||||
impl MemoryPool {
|
||||
/// Create a new memory pool.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create a memory pool with a limit.
|
||||
pub fn with_limit(limit: usize) -> Self {
|
||||
Self {
|
||||
allocated: 0,
|
||||
limit: Some(limit),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get currently allocated bytes.
|
||||
pub fn allocated(&self) -> usize {
|
||||
self.allocated
|
||||
}
|
||||
|
||||
/// Get the allocation limit, if any.
|
||||
pub fn limit(&self) -> Option<usize> {
|
||||
self.limit
|
||||
}
|
||||
}
|
||||
934
crates/ruvector-core/src/quantization.rs
Normal file
934
crates/ruvector-core/src/quantization.rs
Normal file
@@ -0,0 +1,934 @@
|
||||
//! Quantization techniques for memory compression
|
||||
//!
|
||||
//! This module provides tiered quantization strategies as specified in ADR-001:
|
||||
//!
|
||||
//! | Quantization | Compression | Use Case |
|
||||
//! |--------------|-------------|----------|
|
||||
//! | Scalar (u8) | 4x | Warm data (40-80% access) |
|
||||
//! | Int4 | 8x | Cool data (10-40% access) |
|
||||
//! | Product | 8-16x | Cold data (1-10% access) |
|
||||
//! | Binary | 32x | Archive (<1% access) |
|
||||
//!
|
||||
//! ## Performance Optimizations v2
|
||||
//!
|
||||
//! - SIMD-accelerated distance calculations for scalar (int8) quantization
|
||||
//! - SIMD popcnt for binary hamming distance
|
||||
//! - 4x loop unrolling for better instruction-level parallelism
|
||||
//! - Separate accumulator strategy to reduce data dependencies
|
||||
|
||||
use crate::error::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Trait for quantized vector representations
///
/// `quantize` returns `Self`, so this trait is not object-safe; it is meant
/// for static dispatch over concrete quantizer types.
pub trait QuantizedVector: Send + Sync {
    /// Quantize a full-precision vector
    fn quantize(vector: &[f32]) -> Self;

    /// Calculate distance to another quantized vector
    ///
    /// Semantics are quantizer-specific (approximate Euclidean for scalar
    /// codes, Hamming for binary). Implementations index `other` by their
    /// own length, so both operands are assumed to have equal dimensions.
    fn distance(&self, other: &Self) -> f32;

    /// Reconstruct approximate full-precision vector
    fn reconstruct(&self) -> Vec<f32>;
}
|
||||
|
||||
/// Scalar quantization to int8 (4x compression)
///
/// Dequantization formula: `value = min + (code as f32) * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantized {
    /// Quantized values (int8)
    pub data: Vec<u8>,
    /// Minimum value for dequantization
    pub min: f32,
    /// Scale factor for dequantization
    pub scale: f32,
}
|
||||
|
||||
impl QuantizedVector for ScalarQuantized {
    /// Map each value onto 0..=255 over the vector's own [min, max] range.
    fn quantize(vector: &[f32]) -> Self {
        let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
        let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        // Handle edge case where all values are the same (scale = 0)
        let scale = if (max - min).abs() < f32::EPSILON {
            1.0 // Arbitrary non-zero scale when all values are identical
        } else {
            (max - min) / 255.0
        };

        let data = vector
            .iter()
            .map(|&v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
            .collect();

        Self { data, min, scale }
    }

    /// Approximate Euclidean distance computed on the quantized codes.
    ///
    /// NOTE(review): the vectors' `min` offsets are not applied, so accuracy
    /// degrades when the two vectors were quantized over different value
    /// ranges — confirm this approximation is acceptable for the callers.
    fn distance(&self, other: &Self) -> f32 {
        // Fast int8 distance calculation with SIMD optimization
        // Use i32 to avoid overflow: max diff is 255, and 255*255=65025 fits in i32

        // Scale handling: We use the average of both scales for balanced comparison.
        // Using max(scale) would bias toward the vector with larger range,
        // while average provides a more symmetric distance metric.
        // This ensures distance(a, b) ≈ distance(b, a) in the reconstructed space.
        let avg_scale = (self.scale + other.scale) / 2.0;

        // Use SIMD-optimized version for larger vectors
        // NOTE(review): both unsafe paths assume
        // self.data.len() == other.data.len() — confirm callers only compare
        // vectors of identical dimensionality.
        #[cfg(target_arch = "aarch64")]
        {
            if self.data.len() >= 16 {
                return unsafe { scalar_distance_neon(&self.data, &other.data) }.sqrt() * avg_scale;
            }
        }

        #[cfg(target_arch = "x86_64")]
        {
            if self.data.len() >= 32 && is_x86_feature_detected!("avx2") {
                return unsafe { scalar_distance_avx2(&self.data, &other.data) }.sqrt() * avg_scale;
            }
        }

        // Scalar fallback with 4x loop unrolling for better ILP
        scalar_distance_scalar(&self.data, &other.data).sqrt() * avg_scale
    }

    /// Invert the quantization: `min + code * scale` per element.
    fn reconstruct(&self) -> Vec<f32> {
        self.data
            .iter()
            .map(|&v| self.min + (v as f32) * self.scale)
            .collect()
    }
}
|
||||
|
||||
/// Product quantization (8-16x compression)
///
/// NOTE(review): `train` leaves `codes` empty; per-vector codes are produced
/// separately via `encode` — confirm callers store them alongside.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantized {
    /// Quantized codes (one per subspace)
    pub codes: Vec<u8>,
    /// Codebooks for each subspace
    /// (indexed as `codebooks[subspace][centroid][component]`)
    pub codebooks: Vec<Vec<Vec<f32>>>,
}
|
||||
|
||||
impl ProductQuantized {
|
||||
/// Train product quantization on a set of vectors
|
||||
pub fn train(
|
||||
vectors: &[Vec<f32>],
|
||||
num_subspaces: usize,
|
||||
codebook_size: usize,
|
||||
iterations: usize,
|
||||
) -> Result<Self> {
|
||||
if vectors.is_empty() {
|
||||
return Err(crate::error::RuvectorError::InvalidInput(
|
||||
"Cannot train on empty vector set".into(),
|
||||
));
|
||||
}
|
||||
if vectors[0].is_empty() {
|
||||
return Err(crate::error::RuvectorError::InvalidInput(
|
||||
"Cannot train on vectors with zero dimensions".into(),
|
||||
));
|
||||
}
|
||||
if codebook_size > 256 {
|
||||
return Err(crate::error::RuvectorError::InvalidParameter(format!(
|
||||
"Codebook size {} exceeds u8 maximum of 256",
|
||||
codebook_size
|
||||
)));
|
||||
}
|
||||
let dimensions = vectors[0].len();
|
||||
let subspace_dim = dimensions / num_subspaces;
|
||||
|
||||
let mut codebooks = Vec::with_capacity(num_subspaces);
|
||||
|
||||
// Train codebook for each subspace using k-means
|
||||
for subspace_idx in 0..num_subspaces {
|
||||
let start = subspace_idx * subspace_dim;
|
||||
let end = start + subspace_dim;
|
||||
|
||||
// Extract subspace vectors
|
||||
let subspace_vectors: Vec<Vec<f32>> =
|
||||
vectors.iter().map(|v| v[start..end].to_vec()).collect();
|
||||
|
||||
// Run k-means
|
||||
let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
|
||||
codebooks.push(codebook);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
codes: vec![],
|
||||
codebooks,
|
||||
})
|
||||
}
|
||||
|
||||
/// Quantize a vector using trained codebooks
|
||||
pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
|
||||
let num_subspaces = self.codebooks.len();
|
||||
let subspace_dim = vector.len() / num_subspaces;
|
||||
|
||||
let mut codes = Vec::with_capacity(num_subspaces);
|
||||
|
||||
for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
|
||||
let start = subspace_idx * subspace_dim;
|
||||
let end = start + subspace_dim;
|
||||
let subvector = &vector[start..end];
|
||||
|
||||
// Find nearest centroid
|
||||
let code = codebook
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by(|(_, a), (_, b)| {
|
||||
let dist_a = euclidean_squared(subvector, a);
|
||||
let dist_b = euclidean_squared(subvector, b);
|
||||
dist_a.partial_cmp(&dist_b).unwrap()
|
||||
})
|
||||
.map(|(idx, _)| idx as u8)
|
||||
.unwrap_or(0);
|
||||
|
||||
codes.push(code);
|
||||
}
|
||||
|
||||
codes
|
||||
}
|
||||
}
|
||||
|
||||
/// Int4 quantization (8x compression)
///
/// Quantizes f32 to 4-bit integers (0-15), packing 2 values per byte.
/// Provides 8x compression with better precision than binary.
///
/// Dequantization formula: `value = min + (code as f32) * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Int4Quantized {
    /// Packed 4-bit values (2 per byte: even index in the low nibble,
    /// odd index in the high nibble)
    pub data: Vec<u8>,
    /// Minimum value for dequantization
    pub min: f32,
    /// Scale factor for dequantization
    pub scale: f32,
    /// Number of dimensions
    pub dimensions: usize,
}
|
||||
|
||||
impl Int4Quantized {
|
||||
/// Quantize a vector to 4-bit representation
|
||||
pub fn quantize(vector: &[f32]) -> Self {
|
||||
let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
|
||||
let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
|
||||
|
||||
// Handle edge case where all values are the same
|
||||
let scale = if (max - min).abs() < f32::EPSILON {
|
||||
1.0
|
||||
} else {
|
||||
(max - min) / 15.0 // 4-bit gives 0-15 range
|
||||
};
|
||||
|
||||
let dimensions = vector.len();
|
||||
let num_bytes = dimensions.div_ceil(2);
|
||||
let mut data = vec![0u8; num_bytes];
|
||||
|
||||
for (i, &v) in vector.iter().enumerate() {
|
||||
let quantized = ((v - min) / scale).round().clamp(0.0, 15.0) as u8;
|
||||
let byte_idx = i / 2;
|
||||
if i % 2 == 0 {
|
||||
// Low nibble
|
||||
data[byte_idx] |= quantized;
|
||||
} else {
|
||||
// High nibble
|
||||
data[byte_idx] |= quantized << 4;
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
data,
|
||||
min,
|
||||
scale,
|
||||
dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate distance to another Int4 quantized vector
|
||||
pub fn distance(&self, other: &Self) -> f32 {
|
||||
assert_eq!(self.dimensions, other.dimensions);
|
||||
|
||||
// Use average scale for balanced comparison
|
||||
let avg_scale = (self.scale + other.scale) / 2.0;
|
||||
let _avg_min = (self.min + other.min) / 2.0;
|
||||
|
||||
let mut sum_sq = 0i32;
|
||||
|
||||
for i in 0..self.dimensions {
|
||||
let byte_idx = i / 2;
|
||||
let shift = if i % 2 == 0 { 0 } else { 4 };
|
||||
|
||||
let a = ((self.data[byte_idx] >> shift) & 0x0F) as i32;
|
||||
let b = ((other.data[byte_idx] >> shift) & 0x0F) as i32;
|
||||
let diff = a - b;
|
||||
sum_sq += diff * diff;
|
||||
}
|
||||
|
||||
(sum_sq as f32).sqrt() * avg_scale
|
||||
}
|
||||
|
||||
/// Reconstruct approximate full-precision vector
|
||||
pub fn reconstruct(&self) -> Vec<f32> {
|
||||
let mut result = Vec::with_capacity(self.dimensions);
|
||||
|
||||
for i in 0..self.dimensions {
|
||||
let byte_idx = i / 2;
|
||||
let shift = if i % 2 == 0 { 0 } else { 4 };
|
||||
let quantized = (self.data[byte_idx] >> shift) & 0x0F;
|
||||
result.push(self.min + (quantized as f32) * self.scale);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Get compression ratio (8x for Int4)
|
||||
pub fn compression_ratio() -> f32 {
|
||||
8.0 // f32 (4 bytes) -> 4 bits (0.5 bytes)
|
||||
}
|
||||
}
|
||||
|
||||
/// Binary quantization (32x compression)
///
/// Stores only the sign of each dimension (1 bit for "strictly positive").
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantized {
    /// Binary representation (1 bit per dimension, packed into bytes,
    /// LSB-first within each byte)
    pub bits: Vec<u8>,
    /// Number of dimensions
    pub dimensions: usize,
}
|
||||
|
||||
impl QuantizedVector for BinaryQuantized {
    /// One bit per dimension: set when the value is strictly positive.
    fn quantize(vector: &[f32]) -> Self {
        let dimensions = vector.len();
        let num_bytes = dimensions.div_ceil(8);
        let mut bits = vec![0u8; num_bytes];

        for (i, &v) in vector.iter().enumerate() {
            if v > 0.0 {
                let byte_idx = i / 8;
                let bit_idx = i % 8;
                bits[byte_idx] |= 1 << bit_idx;
            }
        }

        Self { bits, dimensions }
    }

    /// Hamming distance: the number of differing bits.
    fn distance(&self, other: &Self) -> f32 {
        // Hamming distance using SIMD-friendly operations
        Self::hamming_distance_fast(&self.bits, &other.bits) as f32
    }

    /// Reconstruct a sign vector: 1.0 for set bits, -1.0 for clear bits.
    /// (Original magnitudes are not recoverable from one bit.)
    fn reconstruct(&self) -> Vec<f32> {
        let mut result = Vec::with_capacity(self.dimensions);

        for i in 0..self.dimensions {
            let byte_idx = i / 8;
            let bit_idx = i % 8;
            let bit = (self.bits[byte_idx] >> bit_idx) & 1;
            result.push(if bit == 1 { 1.0 } else { -1.0 });
        }

        result
    }
}
|
||||
|
||||
impl BinaryQuantized {
    /// Fast hamming distance using SIMD-optimized operations
    ///
    /// Uses hardware POPCNT on x86_64 or NEON vcnt on ARM64 for optimal performance.
    /// Processes 16 bytes at a time on ARM64, 8 bytes at a time on x86_64.
    /// Falls back to 64-bit operations for remainders.
    ///
    /// NOTE(review): if `a` and `b` differ in length, the scalar fallback
    /// silently compares only the common prefix (zip truncates) — confirm
    /// callers always pass equal-length bit sets.
    pub fn hamming_distance_fast(a: &[u8], b: &[u8]) -> u32 {
        // Use SIMD-optimized version based on architecture
        #[cfg(target_arch = "aarch64")]
        {
            if a.len() >= 16 {
                return unsafe { hamming_distance_neon(a, b) };
            }
        }

        #[cfg(target_arch = "x86_64")]
        {
            if a.len() >= 8 && is_x86_feature_detected!("popcnt") {
                return unsafe { hamming_distance_simd_x86(a, b) };
            }
        }

        // Scalar fallback using 64-bit operations
        let mut distance = 0u32;

        // Process 8 bytes at a time using u64
        let chunks_a = a.chunks_exact(8);
        let chunks_b = b.chunks_exact(8);
        let remainder_a = chunks_a.remainder();
        let remainder_b = chunks_b.remainder();

        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
            let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
            let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
            // XOR marks differing bits; count_ones lowers to popcnt.
            distance += (a_u64 ^ b_u64).count_ones();
        }

        // Handle remainder bytes
        for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
            distance += (a_byte ^ b_byte).count_ones();
        }

        distance
    }

    /// Compute normalized hamming similarity (0.0 to 1.0)
    ///
    /// 1.0 means identical bit sets; 0.0 means every bit differs.
    pub fn similarity(&self, other: &Self) -> f32 {
        let distance = self.distance(other);
        1.0 - (distance / self.dimensions as f32)
    }

    /// Get compression ratio (32x for binary)
    pub fn compression_ratio() -> f32 {
        32.0 // f32 (4 bytes = 32 bits) -> 1 bit
    }

    /// Convert to bytes for storage
    pub fn to_bytes(&self) -> &[u8] {
        &self.bits
    }

    /// Create from bytes
    ///
    /// `bits` must have been produced by `to_bytes` for `dimensions` dims.
    pub fn from_bytes(bits: Vec<u8>, dimensions: usize) -> Self {
        Self { bits, dimensions }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper functions for scalar quantization distance
|
||||
// ============================================================================
|
||||
|
||||
/// Scalar fallback for scalar quantization distance (sum of squared differences)
///
/// Accumulates in `i64`: the per-element squared diff is at most 255² = 65025,
/// so the previous `i32` accumulator could overflow past roughly 33k
/// dimensions of maximal difference; `i64` cannot overflow for any realistic
/// vector length. Debug builds assert the slices match in length.
fn scalar_distance_scalar(a: &[u8], b: &[u8]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "quantized vectors must match in length");

    let mut sum_sq = 0i64;

    // chunks_exact(4) keeps the 4x unrolling for instruction-level
    // parallelism while letting LLVM drop per-element bounds checks.
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    let a_rem = a_chunks.remainder();
    let b_rem = b_chunks.remainder();

    for (ca, cb) in a_chunks.zip(b_chunks) {
        let d0 = (ca[0] as i64) - (cb[0] as i64);
        let d1 = (ca[1] as i64) - (cb[1] as i64);
        let d2 = (ca[2] as i64) - (cb[2] as i64);
        let d3 = (ca[3] as i64) - (cb[3] as i64);
        sum_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
    }

    // Handle the 0-3 trailing elements.
    for (&x, &y) in a_rem.iter().zip(b_rem) {
        let d = (x as i64) - (y as i64);
        sum_sq += d * d;
    }

    sum_sq as f32
}
|
||||
|
||||
/// NEON SIMD distance for scalar quantization
///
/// Returns the sum of squared differences (not the square root) over the
/// quantized codes; 8 bytes are processed per iteration.
///
/// # Safety
/// Caller must ensure a.len() == b.len()
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn scalar_distance_neon(a: &[u8], b: &[u8]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    // Four-lane i32 accumulator for the squared differences.
    let mut sum = vdupq_n_s32(0);

    // Process 8 bytes at a time
    let chunks = len / 8;
    let mut idx = 0usize;

    for _ in 0..chunks {
        // Load 8 u8 values
        let va = vld1_u8(a_ptr.add(idx));
        let vb = vld1_u8(b_ptr.add(idx));

        // Zero-extend u8 to u16
        let va_u16 = vmovl_u8(va);
        let vb_u16 = vmovl_u8(vb);

        // Convert to signed for subtraction
        // (values are 0..=255, so the reinterpret cannot change magnitude)
        let va_s16 = vreinterpretq_s16_u16(va_u16);
        let vb_s16 = vreinterpretq_s16_u16(vb_u16);

        // Compute difference
        let diff = vsubq_s16(va_s16, vb_s16);

        // Square and accumulate
        // (vmull_s16 widens i16*i16 -> i32 per lane, so no overflow per step)
        let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
        let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));

        sum = vaddq_s32(sum, prod_lo);
        sum = vaddq_s32(sum, prod_hi);

        idx += 8;
    }

    // Horizontal add across the four accumulator lanes.
    let mut total = vaddvq_s32(sum);

    // Handle remainder with bounds-check elimination
    for i in (chunks * 8)..len {
        // SAFETY: i < len == a.len() == b.len() per the caller contract.
        let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
        total += diff * diff;
    }

    total as f32
}
|
||||
|
||||
/// AVX2 SIMD distance for scalar quantization
///
/// Returns the sum of squared differences over the quantized codes. Each
/// iteration consumes 16 input bytes: bytes are zero-extended to i16,
/// differenced, and `madd` squares and pairwise-adds into eight i32 sums.
///
/// # Safety
/// Caller must ensure AVX2 is available (this fn is `target_feature`-gated)
/// and — NOTE(review) — that a.len() == b.len(); confirm callers guarantee it.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scalar_distance_avx2(a: &[u8], b: &[u8]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_si256();

    // Process 16 bytes at a time
    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;

        // Load 16 u8 values
        let va = _mm_loadu_si128(a.as_ptr().add(idx) as *const __m128i);
        let vb = _mm_loadu_si128(b.as_ptr().add(idx) as *const __m128i);

        // Zero-extend u8 to i16 (low and high halves)
        let va_lo = _mm256_cvtepu8_epi16(va);
        let vb_lo = _mm256_cvtepu8_epi16(vb);

        // Compute difference
        let diff = _mm256_sub_epi16(va_lo, vb_lo);

        // Square (multiply i16 * i16 -> i32)
        let prod = _mm256_madd_epi16(diff, diff);

        // Accumulate
        // NOTE(review): i32 lanes can overflow for extremely long vectors of
        // maximal diffs (~260k+ dims); confirm against supported dimensions.
        sum = _mm256_add_epi32(sum, prod);
    }

    // Horizontal sum
    let sum_lo = _mm256_castsi256_si128(sum);
    let sum_hi = _mm256_extracti128_si256(sum, 1);
    let sum_128 = _mm_add_epi32(sum_lo, sum_hi);

    let shuffle = _mm_shuffle_epi32(sum_128, 0b10_11_00_01);
    let sum_64 = _mm_add_epi32(sum_128, shuffle);

    let shuffle2 = _mm_shuffle_epi32(sum_64, 0b00_00_10_10);
    let final_sum = _mm_add_epi32(sum_64, shuffle2);

    let mut total = _mm_cvtsi128_si32(final_sum);

    // Handle remainder
    for i in (chunks * 16)..len {
        let diff = (a[i] as i32) - (b[i] as i32);
        total += diff * diff;
    }

    total as f32
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
/// Squared Euclidean distance between two f32 slices.
///
/// Pairs elements positionally; if the slices have different lengths the
/// extra elements of the longer one are ignored (zip semantics).
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        acc += d * d;
    }
    acc
}
|
||||
|
||||
fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
// Initialize centroids randomly
|
||||
let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
|
||||
|
||||
for _ in 0..iterations {
|
||||
// Assign vectors to nearest centroid
|
||||
let mut assignments = vec![Vec::new(); k];
|
||||
|
||||
for vector in vectors {
|
||||
let nearest = centroids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.min_by(|(_, a), (_, b)| {
|
||||
let dist_a = euclidean_squared(vector, a);
|
||||
let dist_b = euclidean_squared(vector, b);
|
||||
dist_a.partial_cmp(&dist_b).unwrap()
|
||||
})
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(0);
|
||||
|
||||
assignments[nearest].push(vector.clone());
|
||||
}
|
||||
|
||||
// Update centroids
|
||||
for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
|
||||
if !assigned.is_empty() {
|
||||
let dim = centroid.len();
|
||||
*centroid = vec![0.0; dim];
|
||||
|
||||
for vector in assigned {
|
||||
for (i, &v) in vector.iter().enumerate() {
|
||||
centroid[i] += v;
|
||||
}
|
||||
}
|
||||
|
||||
let count = assigned.len() as f32;
|
||||
for v in centroid.iter_mut() {
|
||||
*v /= count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
centroids
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SIMD-Optimized Distance Calculations for Quantized Vectors
|
||||
// =============================================================================
|
||||
|
||||
// NOTE: scalar_distance_scalar, scalar_distance_neon, and scalar_distance_avx2
// are defined earlier in this file; this section reuses those existing
// implementations for consistency.
|
||||
|
||||
/// Hamming distance over byte slices using the hardware `popcnt` instruction.
///
/// Counts the number of differing bits between `a` and `b`. Eight bytes are
/// folded into a `u64` per step so a single `_popcnt64` covers 64 bits.
///
/// # Safety
/// - Caller must ensure the CPU supports `popcnt` (the `#[target_feature]`
///   attribute makes this the caller's obligation).
/// - If `a.len() != b.len()`, the extra bytes of the longer slice are silently
///   ignored (both loops use `zip`); callers should pass equal-length slices.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[inline]
unsafe fn hamming_distance_simd_x86(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::x86_64::*;

    // u64 accumulator cannot overflow for any realistic input size
    let mut distance = 0u64;

    // Process 8 bytes at a time using u64 with hardware popcnt
    let chunks_a = a.chunks_exact(8);
    let chunks_b = b.chunks_exact(8);
    // remainder() is known up front for chunks_exact, before iteration
    let remainder_a = chunks_a.remainder();
    let remainder_b = chunks_b.remainder();

    for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
        // try_into() cannot fail: chunks_exact(8) guarantees 8-byte chunks
        let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
        let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
        distance += _popcnt64((a_u64 ^ b_u64) as i64) as u64;
    }

    // Handle remainder (< 8 trailing bytes) with scalar popcount
    for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
        distance += (a_byte ^ b_byte).count_ones() as u64;
    }

    distance as u32
}
|
||||
|
||||
/// NEON-optimized hamming distance for ARM64.
///
/// Counts differing bits between `a` and `b` 16 bytes at a time using
/// `vcntq_u8` (per-byte population count).
///
/// Fix vs. the previous version: the accumulator used to be a `u8x16` vector
/// updated with `vaddq_u8`. Each lane gains up to 8 per chunk, so lanes
/// wrapped after 32 chunks — inputs of ~512 bytes or more produced wrong
/// results. The per-chunk counts are now widened (u8x16 -> u16x8 -> u32x4)
/// before accumulating, which cannot overflow for any realistic length.
///
/// # Safety
/// Caller must ensure a.len() == b.len()
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let chunks = len / 16;
    let mut idx = 0usize;

    // Four u32 lanes: overflow-free accumulation regardless of input size
    let mut sum = vdupq_n_u32(0);

    for _ in 0..chunks {
        // Load 16 bytes from each input
        let a_vec = vld1q_u8(a_ptr.add(idx));
        let b_vec = vld1q_u8(b_ptr.add(idx));

        // XOR, then per-byte popcount (each u8 lane is <= 8)
        let xor_result = veorq_u8(a_vec, b_vec);
        let bits = vcntq_u8(xor_result);

        // Widen pairwise: u8x16 -> u16x8 -> u32x4, then accumulate
        let wide = vpaddlq_u16(vpaddlq_u8(bits));
        sum = vaddq_u32(sum, wide);

        idx += 16;
    }

    // Horizontal sum of the four u32 lanes
    let mut total = vaddvq_u32(sum);

    // Handle remainder with bounds-check elimination.
    // SAFETY: i < len == a.len() == b.len() (caller contract).
    for i in (chunks * 16)..len {
        total += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones();
    }

    total
}
|
||||
|
||||
#[cfg(test)]
/// Unit tests for the quantization types (`ScalarQuantized`, `BinaryQuantized`,
/// `Int4Quantized`): roundtrip accuracy, distance symmetry, edge cases, and
/// compression-ratio constants.
mod tests {
    use super::*;

    #[test]
    fn test_scalar_quantization() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = ScalarQuantized::quantize(&vector);
        let reconstructed = quantized.reconstruct();

        // Check approximate reconstruction (8-bit codes over range 4.0)
        for (orig, recon) in vector.iter().zip(&reconstructed) {
            assert!((orig - recon).abs() < 0.1);
        }
    }

    #[test]
    fn test_binary_quantization() {
        let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5];
        let quantized = BinaryQuantized::quantize(&vector);

        assert_eq!(quantized.dimensions, 5);
        assert_eq!(quantized.bits.len(), 1); // 5 bits fit in 1 byte
    }

    #[test]
    fn test_binary_distance() {
        let v1 = vec![1.0, 1.0, 1.0, 1.0];
        let v2 = vec![1.0, 1.0, -1.0, -1.0];

        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);

        let dist = q1.distance(&q2);
        assert_eq!(dist, 2.0); // 2 bits differ
    }

    #[test]
    fn test_scalar_quantization_roundtrip() {
        // Test that quantize -> reconstruct produces values close to original
        // across varying magnitudes and signs.
        let test_vectors = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![-10.0, -5.0, 0.0, 5.0, 10.0],
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
            vec![100.0, 200.0, 300.0, 400.0, 500.0],
        ];

        for vector in test_vectors {
            let quantized = ScalarQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();

            assert_eq!(vector.len(), reconstructed.len());

            for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
                // With 8-bit quantization, max error is roughly (max-min)/255
                let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
                let max_error = (max - min) / 255.0 * 2.0; // Allow 2x for rounding

                assert!(
                    (orig - recon).abs() < max_error,
                    "Roundtrip error too large: orig={}, recon={}, error={}",
                    orig,
                    recon,
                    (orig - recon).abs()
                );
            }
        }
    }

    #[test]
    fn test_scalar_distance_symmetry() {
        // Test that distance(a, b) == distance(b, a)
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];

        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);

        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);

        // Distance should be symmetric (within floating point precision)
        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Distance is not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }

    #[test]
    fn test_scalar_distance_different_scales() {
        // Test distance calculation with vectors that have different scales:
        // symmetry must hold even when the two quantizers use different ranges.
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // range: 4.0
        let v2 = vec![10.0, 20.0, 30.0, 40.0, 50.0]; // range: 40.0

        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);

        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);

        // With average scaling, symmetry should be maintained
        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Distance with different scales not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }

    #[test]
    fn test_scalar_quantization_edge_cases() {
        // Test with all same values (zero range must not divide by zero)
        let same_values = vec![5.0, 5.0, 5.0, 5.0];
        let quantized = ScalarQuantized::quantize(&same_values);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in same_values.iter().zip(reconstructed.iter()) {
            assert!((orig - recon).abs() < 0.01);
        }

        // Test with extreme ranges (only checks the roundtrip keeps the length)
        let extreme = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
        let quantized = ScalarQuantized::quantize(&extreme);
        let reconstructed = quantized.reconstruct();

        assert_eq!(extreme.len(), reconstructed.len());
    }

    #[test]
    fn test_binary_distance_symmetry() {
        // Test that binary distance is symmetric (hamming is inherently so)
        let v1 = vec![1.0, -1.0, 1.0, -1.0];
        let v2 = vec![1.0, 1.0, -1.0, -1.0];

        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);

        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);

        assert_eq!(
            dist_ab, dist_ba,
            "Binary distance not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab, dist_ba
        );
    }

    #[test]
    fn test_int4_quantization() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = Int4Quantized::quantize(&vector);
        let reconstructed = quantized.reconstruct();

        assert_eq!(quantized.dimensions, 5);
        // 5 dimensions = 3 bytes (2 per byte, last byte has 1)
        assert_eq!(quantized.data.len(), 3);

        // Check approximate reconstruction
        for (orig, recon) in vector.iter().zip(&reconstructed) {
            // With 4-bit quantization, max error is roughly (max-min)/15
            let max_error = (5.0 - 1.0) / 15.0 * 2.0;
            assert!(
                (orig - recon).abs() < max_error,
                "Int4 roundtrip error too large: orig={}, recon={}",
                orig,
                recon
            );
        }
    }

    #[test]
    fn test_int4_distance() {
        // Use vectors with different quantized patterns
        // v1 spans [0.0, 15.0] -> quantizes to [0, 1, 2, ..., 15] (linear mapping)
        // v2 spans [0.0, 15.0] but with different distribution
        let v1 = vec![0.0, 5.0, 10.0, 15.0];
        let v2 = vec![0.0, 3.0, 12.0, 15.0]; // Different middle values

        let q1 = Int4Quantized::quantize(&v1);
        let q2 = Int4Quantized::quantize(&v2);

        let dist = q1.distance(&q2);
        // The quantized values differ in the middle, so distance should be positive
        assert!(
            dist > 0.0,
            "Distance should be positive, got {}. q1.data={:?}, q2.data={:?}",
            dist,
            q1.data,
            q2.data
        );
    }

    #[test]
    fn test_int4_distance_symmetry() {
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];

        let q1 = Int4Quantized::quantize(&v1);
        let q2 = Int4Quantized::quantize(&v2);

        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);

        assert!(
            (dist_ab - dist_ba).abs() < 0.01,
            "Int4 distance not symmetric: d(a,b)={}, d(b,a)={}",
            dist_ab,
            dist_ba
        );
    }

    #[test]
    fn test_int4_compression_ratio() {
        // f32 (32 bits) -> 4 bits per dimension
        assert_eq!(Int4Quantized::compression_ratio(), 8.0);
    }

    #[test]
    fn test_binary_fast_hamming() {
        // Test fast hamming distance with a length that exercises both the
        // 8-byte fast path and the 1-byte remainder path (9 bytes).
        let a = vec![0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xAA];
        let b = vec![0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x55];

        let distance = BinaryQuantized::hamming_distance_fast(&a, &b);
        // All bits differ: 9 bytes * 8 bits = 72 bits
        assert_eq!(distance, 72);
    }

    #[test]
    fn test_binary_similarity() {
        let v1 = vec![1.0; 8]; // All positive
        let v2 = vec![1.0; 8]; // Same

        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);

        let sim = q1.similarity(&q2);
        assert!(
            (sim - 1.0).abs() < 0.001,
            "Same vectors should have similarity 1.0"
        );
    }

    #[test]
    fn test_binary_compression_ratio() {
        // f32 (32 bits) -> 1 bit per dimension
        assert_eq!(BinaryQuantized::compression_ratio(), 32.0);
    }
}
|
||||
1604
crates/ruvector-core/src/simd_intrinsics.rs
Normal file
1604
crates/ruvector-core/src/simd_intrinsics.rs
Normal file
File diff suppressed because it is too large
Load Diff
446
crates/ruvector-core/src/storage.rs
Normal file
446
crates/ruvector-core/src/storage.rs
Normal file
@@ -0,0 +1,446 @@
|
||||
//! Storage layer with redb for metadata and memory-mapped vectors
|
||||
//!
|
||||
//! This module is only available when the "storage" feature is enabled.
|
||||
//! For WASM builds, use the in-memory storage backend instead.
|
||||
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::error::{Result, RuvectorError};
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::types::{DbOptions, VectorEntry, VectorId};
|
||||
#[cfg(feature = "storage")]
|
||||
use bincode::config;
|
||||
#[cfg(feature = "storage")]
|
||||
use once_cell::sync::Lazy;
|
||||
#[cfg(feature = "storage")]
|
||||
use parking_lot::Mutex;
|
||||
#[cfg(feature = "storage")]
|
||||
use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
|
||||
#[cfg(feature = "storage")]
|
||||
use serde_json;
|
||||
#[cfg(feature = "storage")]
|
||||
use std::collections::HashMap;
|
||||
#[cfg(feature = "storage")]
|
||||
use std::path::{Path, PathBuf};
|
||||
#[cfg(feature = "storage")]
|
||||
use std::sync::Arc;
|
||||
|
||||
// Table definitions for the redb backend. Every item referencing redb types
// must be gated on the "storage" feature, matching the feature-gated `use`
// statements above; previously only VECTORS_TABLE carried the attribute,
// which breaks non-"storage" builds if this module is compiled.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
#[cfg(feature = "storage")]
const CONFIG_TABLE: TableDefinition<&str, &str> = TableDefinition::new("config");

/// Key used to store database configuration in CONFIG_TABLE
#[cfg(feature = "storage")]
const DB_CONFIG_KEY: &str = "__ruvector_db_config__";

// Global database connection pool to allow multiple VectorDB instances
// to share the same underlying database file
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
|
||||
|
||||
/// Storage backend for vector database.
///
/// Wraps a shared redb `Database` handle (pooled per path in `DB_POOL`)
/// together with the fixed dimensionality every stored vector must have.
pub struct VectorStorage {
    /// Shared handle into the global per-path database pool.
    db: Arc<Database>,
    /// Required length of every stored vector; enforced on insert.
    dimensions: usize,
}
|
||||
|
||||
impl VectorStorage {
    /// Create or open a vector storage at the given path.
    ///
    /// This method uses a global connection pool to allow multiple VectorDB
    /// instances to share the same underlying database file, fixing the
    /// "Database already open. Cannot acquire lock" error.
    ///
    /// NOTE(review): the pool key is the lexically-joined absolute path, not a
    /// canonicalized one — two spellings of the same file (e.g. "a/../a/x.db"
    /// vs "a/x.db", or paths through symlinks) would miss the pool and attempt
    /// a second `Database::create` on the same file. TODO confirm whether
    /// callers can pass such paths; canonicalizing before the pool lookup
    /// would close the gap.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        // SECURITY: Validate path to prevent directory traversal attacks
        let path_ref = path.as_ref();

        // Create parent directories if they don't exist (needed for canonicalize)
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }

        // Convert to absolute path first, then validate
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };

        // SECURITY: Check for path traversal attempts (e.g., "../../../etc/passwd")
        // Only reject paths that contain ".." components trying to escape
        let path_str = path_ref.to_string_lossy();
        if path_str.contains("..") {
            // Verify the resolved path doesn't escape intended boundaries
            // For absolute paths, we allow them as-is (user explicitly specified)
            // For relative paths with "..", check they don't escape cwd
            if !path_ref.is_absolute() {
                if let Ok(cwd) = std::env::current_dir() {
                    // Normalize the path by resolving .. components; error out
                    // the moment a ".." would step above the cwd.
                    let mut normalized = cwd.clone();
                    for component in path_ref.components() {
                        match component {
                            std::path::Component::ParentDir => {
                                if !normalized.pop() || !normalized.starts_with(&cwd) {
                                    return Err(RuvectorError::InvalidPath(
                                        "Path traversal attempt detected".to_string(),
                                    ));
                                }
                            }
                            std::path::Component::Normal(c) => normalized.push(c),
                            _ => {}
                        }
                    }
                }
            }
        }

        // Check if we already have a Database instance for this path
        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                // Reuse existing database connection
                Arc::clone(existing_db)
            } else {
                // Create new database and add to pool
                let new_db = Arc::new(Database::create(&path_buf)?);

                // Initialize tables (scope block ends the table borrows before commit)
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                    let _ = write_txn.open_table(CONFIG_TABLE)?;
                }
                write_txn.commit()?;

                // NOTE(review): entries are never evicted, so pooled databases
                // stay open for the process lifetime — presumably intentional.
                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db, dimensions })
    }

    /// Insert a vector entry.
    ///
    /// Generates a UUIDv4 id when `entry.id` is `None`. Returns the id used.
    ///
    /// # Errors
    /// `DimensionMismatch` if the vector length differs from the storage's
    /// configured dimensions; `SerializationError` on encode failure.
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        if entry.vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: entry.vector.len(),
            });
        }

        let id = entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;

            // Serialize vector data with bincode's standard config
            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

            table.insert(id.as_str(), vector_data.as_slice())?;

            // Store metadata if present (as JSON in a separate table)
            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;

        Ok(id)
    }

    /// Insert multiple vectors in a batch.
    ///
    /// All entries share one write transaction: on any error (e.g. a
    /// dimension mismatch mid-batch) the transaction is dropped without
    /// commit, so the batch is all-or-nothing.
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;

            for entry in entries {
                if entry.vector.len() != self.dimensions {
                    return Err(RuvectorError::DimensionMismatch {
                        expected: self.dimensions,
                        actual: entry.vector.len(),
                    });
                }

                let id = entry
                    .id
                    .clone()
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

                // Serialize and insert vector
                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;

                // Insert metadata if present
                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }

                ids.push(id);
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Get a vector by ID.
    ///
    /// Returns `Ok(None)` when the id is absent. Metadata is loaded from the
    /// metadata table when present.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };

        // decode_from_slice returns (value, bytes_consumed); length is unused
        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        // Try to get metadata
        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };

        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }

    /// Delete a vector by ID.
    ///
    /// Removes both the vector and any metadata. Returns `true` if a vector
    /// was actually removed.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;
        let deleted;

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            deleted = table.remove(id)?.is_some();

            // Metadata may legitimately be absent; ignore the removed value
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;
        }

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Get the number of vectors stored.
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        Ok(table.len()? as usize)
    }

    /// Check if storage is empty.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// Get all vector IDs (in redb key order).
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let mut ids = Vec::new();
        let iter = table.iter()?;
        for item in iter {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }

    /// Save database configuration to persistent storage (overwrites any
    /// previously stored config under `DB_CONFIG_KEY`).
    pub fn save_config(&self, options: &DbOptions) -> Result<()> {
        let config_json = serde_json::to_string(options)
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(CONFIG_TABLE)?;
            table.insert(DB_CONFIG_KEY, config_json.as_str())?;
        }
        write_txn.commit()?;

        Ok(())
    }

    /// Load database configuration from persistent storage.
    ///
    /// Returns `Ok(None)` when the config table or key is absent (e.g. a
    /// database created by an older version).
    pub fn load_config(&self) -> Result<Option<DbOptions>> {
        let read_txn = self.db.begin_read()?;

        // Try to open config table - may not exist in older databases
        let table = match read_txn.open_table(CONFIG_TABLE) {
            Ok(t) => t,
            Err(_) => return Ok(None),
        };

        let Some(config_data) = table.get(DB_CONFIG_KEY)? else {
            return Ok(None);
        };

        let config: DbOptions = serde_json::from_str(config_data.value())
            .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        Ok(Some(config))
    }

    /// Get the stored dimensions.
    pub fn dimensions(&self) -> usize {
        self.dimensions
    }
}
|
||||
|
||||
// Bring the `uuid` crate into scope (the dependency itself is declared in Cargo.toml)
|
||||
use uuid;
|
||||
|
||||
#[cfg(test)]
/// Integration-style tests for `VectorStorage` against a temporary redb file.
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        // Insert echoes back the caller-supplied id
        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        // Entries without ids get generated UUIDs
        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        // This test verifies the fix for the database locking bug
        // Multiple VectorStorage instances should be able to share the same database file
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        // Create first instance
        let storage1 = VectorStorage::new(&db_path, 3)?;

        // Insert data with first instance
        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        // Create second instance with SAME path - this should NOT fail
        let storage2 = VectorStorage::new(&db_path, 3)?;

        // Both instances should see the same data
        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        // Insert with second instance
        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        // Both instances should see both records
        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        // Verify data integrity
        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());

        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());

        Ok(())
    }
}
|
||||
79
crates/ruvector-core/src/storage_compat.rs
Normal file
79
crates/ruvector-core/src/storage_compat.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
//! Storage compatibility layer
|
||||
//!
|
||||
//! This module provides a unified interface that works with both
|
||||
//! file-based (redb) and in-memory storage backends.
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::types::{VectorEntry, VectorId};
|
||||
|
||||
#[cfg(feature = "storage")]
|
||||
pub use crate::storage::VectorStorage;
|
||||
|
||||
#[cfg(not(feature = "storage"))]
|
||||
pub use crate::storage_memory::MemoryStorage as VectorStorage;
|
||||
|
||||
/// Unified storage trait.
///
/// Abstracts over the file-based (`storage::VectorStorage`) and in-memory
/// (`storage_memory::MemoryStorage`) backends so callers can be written
/// against one interface regardless of the enabled features.
pub trait StorageBackend {
    /// Insert one entry; returns the (possibly generated) id.
    fn insert(&self, entry: &VectorEntry) -> Result<VectorId>;
    /// Insert many entries; returns their ids in order.
    fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>>;
    /// Fetch an entry by id; `Ok(None)` when absent.
    fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
    /// Remove an entry by id; returns whether anything was removed.
    fn delete(&self, id: &str) -> Result<bool>;
    /// Number of stored vectors.
    fn len(&self) -> Result<usize>;
    /// Whether the store holds no vectors.
    fn is_empty(&self) -> Result<bool>;
}
|
||||
|
||||
// Implement trait for redb-based storage
|
||||
#[cfg(feature = "storage")]
impl StorageBackend for crate::storage::VectorStorage {
    // Every method delegates to the inherent method of the same name
    // (inherent methods take precedence over trait methods, so this
    // does not recurse).
    fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        self.insert(entry)
    }

    fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        self.insert_batch(entries)
    }

    fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.get(id)
    }

    fn delete(&self, id: &str) -> Result<bool> {
        self.delete(id)
    }

    fn len(&self) -> Result<usize> {
        self.len()
    }

    fn is_empty(&self) -> Result<bool> {
        self.is_empty()
    }
}
|
||||
|
||||
// Implement trait for memory storage
|
||||
#[cfg(not(feature = "storage"))]
impl StorageBackend for crate::storage_memory::MemoryStorage {
    // Every method delegates to the inherent method of the same name
    // (inherent methods take precedence over trait methods, so this
    // does not recurse).
    fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        self.insert(entry)
    }

    fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        self.insert_batch(entries)
    }

    fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.get(id)
    }

    fn delete(&self, id: &str) -> Result<bool> {
        self.delete(id)
    }

    fn len(&self) -> Result<usize> {
        self.len()
    }

    fn is_empty(&self) -> Result<bool> {
        self.is_empty()
    }
}
|
||||
257
crates/ruvector-core/src/storage_memory.rs
Normal file
257
crates/ruvector-core/src/storage_memory.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
//! In-memory storage backend for WASM and testing
|
||||
//!
|
||||
//! This storage implementation doesn't require file system access,
|
||||
//! making it suitable for WebAssembly environments.
|
||||
|
||||
use crate::error::{Result, RuvectorError};
|
||||
use crate::types::{VectorEntry, VectorId};
|
||||
use dashmap::DashMap;
|
||||
use serde_json::Value as JsonValue;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
/// In-memory storage backend using DashMap for thread-safe concurrent access.
pub struct MemoryStorage {
    /// Vector data keyed by id.
    vectors: DashMap<String, Vec<f32>>,
    /// Optional per-vector metadata, stored as a JSON object keyed by id.
    metadata: DashMap<String, JsonValue>,
    /// Required length of every stored vector; enforced on insert.
    dimensions: usize,
    /// Monotonic counter used to generate ids when an entry has none.
    counter: AtomicU64,
}
|
||||
|
||||
impl MemoryStorage {
|
||||
/// Create a new in-memory storage
|
||||
pub fn new(dimensions: usize) -> Result<Self> {
|
||||
Ok(Self {
|
||||
vectors: DashMap::new(),
|
||||
metadata: DashMap::new(),
|
||||
dimensions,
|
||||
counter: AtomicU64::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a new unique ID
|
||||
fn generate_id(&self) -> String {
|
||||
let id = self.counter.fetch_add(1, Ordering::SeqCst);
|
||||
format!("vec_{}", id)
|
||||
}
|
||||
|
||||
/// Insert a vector entry
|
||||
pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
|
||||
if entry.vector.len() != self.dimensions {
|
||||
return Err(RuvectorError::DimensionMismatch {
|
||||
expected: self.dimensions,
|
||||
actual: entry.vector.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let id = entry.id.clone().unwrap_or_else(|| self.generate_id());
|
||||
|
||||
// Insert vector
|
||||
self.vectors.insert(id.clone(), entry.vector.clone());
|
||||
|
||||
// Insert metadata if present
|
||||
if let Some(metadata) = &entry.metadata {
|
||||
self.metadata.insert(
|
||||
id.clone(),
|
||||
serde_json::Value::Object(
|
||||
metadata
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Insert multiple vectors in a batch
|
||||
pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
|
||||
let mut ids = Vec::with_capacity(entries.len());
|
||||
|
||||
for entry in entries {
|
||||
if entry.vector.len() != self.dimensions {
|
||||
return Err(RuvectorError::DimensionMismatch {
|
||||
expected: self.dimensions,
|
||||
actual: entry.vector.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let id = entry.id.clone().unwrap_or_else(|| self.generate_id());
|
||||
|
||||
self.vectors.insert(id.clone(), entry.vector.clone());
|
||||
|
||||
if let Some(metadata) = &entry.metadata {
|
||||
self.metadata.insert(
|
||||
id.clone(),
|
||||
serde_json::Value::Object(
|
||||
metadata
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
ids.push(id);
|
||||
}
|
||||
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
/// Get a vector by ID
|
||||
pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
|
||||
if let Some(vector_ref) = self.vectors.get(id) {
|
||||
let vector = vector_ref.clone();
|
||||
let metadata = self.metadata.get(id).and_then(|m| {
|
||||
if let serde_json::Value::Object(map) = m.value() {
|
||||
Some(map.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Some(VectorEntry {
|
||||
id: Some(id.to_string()),
|
||||
vector,
|
||||
metadata,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete a vector by ID
|
||||
pub fn delete(&self, id: &str) -> Result<bool> {
|
||||
let vector_removed = self.vectors.remove(id).is_some();
|
||||
self.metadata.remove(id);
|
||||
Ok(vector_removed)
|
||||
}
|
||||
|
||||
/// Get the number of vectors stored
|
||||
pub fn len(&self) -> Result<usize> {
|
||||
Ok(self.vectors.len())
|
||||
}
|
||||
|
||||
/// Check if the storage is empty
|
||||
pub fn is_empty(&self) -> Result<bool> {
|
||||
Ok(self.vectors.is_empty())
|
||||
}
|
||||
|
||||
/// Get all vector IDs (for iteration)
|
||||
pub fn keys(&self) -> Vec<String> {
|
||||
self.vectors
|
||||
.iter()
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all vector IDs (alias for keys, for API compatibility with VectorStorage)
|
||||
pub fn all_ids(&self) -> Result<Vec<String>> {
|
||||
Ok(self.keys())
|
||||
}
|
||||
|
||||
/// Clear all data
|
||||
pub fn clear(&self) -> Result<()> {
|
||||
self.vectors.clear();
|
||||
self.metadata.clear();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Insert with an explicit ID and metadata, then read both back.
    #[test]
    fn test_insert_and_get() {
        let storage = MemoryStorage::new(128).unwrap();

        // `VectorEntry.metadata` is Option<HashMap<String, Value>>, so the
        // metadata must be a real map — `json!({...})` (a bare `Value`)
        // would not type-check here.
        let entry = VectorEntry {
            id: Some("test_1".to_string()),
            vector: vec![0.1; 128],
            metadata: Some(HashMap::from([("key".to_string(), json!("value"))])),
        };

        let id = storage.insert(&entry).unwrap();
        assert_eq!(id, "test_1");

        let retrieved = storage.get("test_1").unwrap().unwrap();
        assert_eq!(retrieved.vector.len(), 128);
        assert!(retrieved.metadata.is_some());
        // The stored value should round-trip through the JSON object form.
        assert_eq!(
            retrieved.metadata.unwrap().get("key"),
            Some(&json!("value"))
        );
    }

    /// Batch insert preserves count and returns one ID per entry.
    #[test]
    fn test_batch_insert() {
        let storage = MemoryStorage::new(64).unwrap();

        let entries: Vec<_> = (0..10)
            .map(|i| VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: vec![i as f32; 64],
                metadata: None,
            })
            .collect();

        let ids = storage.insert_batch(&entries).unwrap();
        assert_eq!(ids.len(), 10);
        assert_eq!(storage.len().unwrap(), 10);
    }

    /// Delete removes the vector and reports whether anything was removed.
    #[test]
    fn test_delete() {
        let storage = MemoryStorage::new(32).unwrap();

        let entry = VectorEntry {
            id: Some("delete_me".to_string()),
            vector: vec![1.0; 32],
            metadata: None,
        };

        storage.insert(&entry).unwrap();
        assert_eq!(storage.len().unwrap(), 1);

        let deleted = storage.delete("delete_me").unwrap();
        assert!(deleted);
        assert_eq!(storage.len().unwrap(), 0);
    }

    /// Entries without an explicit ID get unique auto-generated `vec_*` IDs.
    #[test]
    fn test_auto_id_generation() {
        let storage = MemoryStorage::new(16).unwrap();

        let entry = VectorEntry {
            id: None,
            vector: vec![0.5; 16],
            metadata: None,
        };

        let id1 = storage.insert(&entry).unwrap();
        let id2 = storage.insert(&entry).unwrap();

        assert_ne!(id1, id2);
        assert!(id1.starts_with("vec_"));
        assert!(id2.starts_with("vec_"));
    }

    /// A wrong-length vector is rejected with a DimensionMismatch carrying
    /// both the expected and actual lengths.
    #[test]
    fn test_dimension_mismatch() {
        let storage = MemoryStorage::new(128).unwrap();

        let entry = VectorEntry {
            id: Some("bad".to_string()),
            vector: vec![0.1; 64], // Wrong dimension
            metadata: None,
        };

        let result = storage.insert(&entry);
        assert!(result.is_err());

        if let Err(RuvectorError::DimensionMismatch { expected, actual }) = result {
            assert_eq!(expected, 128);
            assert_eq!(actual, 64);
        } else {
            panic!("Expected DimensionMismatch error");
        }
    }
}
|
||||
126
crates/ruvector-core/src/types.rs
Normal file
126
crates/ruvector-core/src/types.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
//! Core types and data structures
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Unique identifier for vectors
///
/// Owned `String`; IDs may be supplied by the caller on insert or
/// auto-generated by the storage layer when omitted.
pub type VectorId = String;
|
||||
|
||||
/// Distance metric for similarity calculation
///
/// All variants are consumed as distances (lower score = more similar), so
/// similarity-style measures (cosine, dot product) are converted to a
/// distance by the index implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance
    Euclidean,
    /// Cosine similarity (converted to distance)
    Cosine,
    /// Dot product (converted to distance for maximization)
    DotProduct,
    /// Manhattan (L1) distance
    Manhattan,
}
|
||||
|
||||
/// Vector entry with metadata
///
/// The unit of insertion and retrieval: the raw embedding plus an optional
/// map of arbitrary JSON metadata keyed by field name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEntry {
    /// Optional ID (auto-generated by the storage layer if not provided)
    pub id: Option<VectorId>,
    /// Vector data; length must match the database's configured dimensions
    pub vector: Vec<f32>,
    /// Optional metadata as arbitrary JSON values keyed by field name
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
|
||||
|
||||
/// Search query parameters
///
/// Describes one nearest-neighbor lookup: the query embedding, how many
/// results to return, and optional exact-match metadata filters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQuery {
    /// Query vector; length must match the database dimensions
    pub vector: Vec<f32>,
    /// Number of results to return (top-k)
    pub k: usize,
    /// Optional metadata filters; a result must match every key/value pair
    pub filter: Option<HashMap<String, serde_json::Value>>,
    /// Optional ef_search parameter for HNSW (overrides the index default)
    pub ef_search: Option<usize>,
}
|
||||
|
||||
/// Search result with similarity score
///
/// One hit from a search; `vector` and `metadata` may be filled in lazily by
/// the caller (e.g. `VectorDB::search` enriches index hits from storage).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Vector ID
    pub id: VectorId,
    /// Distance/similarity score (lower is better for distance metrics)
    pub score: f32,
    /// Vector data (optional; populated when the result is enriched)
    pub vector: Option<Vec<f32>>,
    /// Metadata (optional; populated when the result is enriched)
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}
|
||||
|
||||
/// Database configuration options
///
/// When opening an existing database, the configuration persisted on disk
/// takes precedence over these values (see `VectorDB::new`); only
/// `storage_path` is always honored as provided.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbOptions {
    /// Vector dimensions
    pub dimensions: usize,
    /// Distance metric
    pub distance_metric: DistanceMetric,
    /// Storage path
    pub storage_path: String,
    /// HNSW configuration; `None` selects the brute-force flat index
    pub hnsw_config: Option<HnswConfig>,
    /// Quantization configuration
    pub quantization: Option<QuantizationConfig>,
}
|
||||
|
||||
/// HNSW index configuration
///
/// Standard HNSW knobs: larger `m`/`ef_construction` improve recall at the
/// cost of build time and memory; `ef_search` trades query latency for recall.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Number of connections per layer (M)
    pub m: usize,
    /// Size of dynamic candidate list during construction (efConstruction)
    pub ef_construction: usize,
    /// Size of dynamic candidate list during search (efSearch)
    pub ef_search: usize,
    /// Maximum number of elements
    pub max_elements: usize,
}
|
||||
|
||||
impl Default for HnswConfig {
    /// Defaults: M=32 links per node, efConstruction=200, efSearch=100,
    /// capacity for 10 million elements.
    fn default() -> Self {
        Self {
            m: 32,
            ef_construction: 200,
            ef_search: 100,
            max_elements: 10_000_000,
        }
    }
}
|
||||
|
||||
/// Quantization configuration
///
/// Selects how stored vectors are compressed. The compression-ratio notes
/// below describe the vector payload only, not codebook overhead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// No quantization (full precision f32)
    None,
    /// Scalar quantization to int8 (4x compression)
    Scalar,
    /// Product quantization: split the vector into subspaces, each encoded
    /// as an index into a learned codebook
    Product {
        /// Number of subspaces
        subspaces: usize,
        /// Codebook size (typically 256, so each code fits one byte)
        k: usize,
    },
    /// Binary quantization (32x compression)
    Binary,
}
|
||||
|
||||
impl Default for DbOptions {
    /// Defaults: 384 dimensions (a common sentence-embedding size), cosine
    /// distance, HNSW indexing with default parameters, scalar quantization,
    /// and a `./ruvector.db` storage path.
    fn default() -> Self {
        Self {
            dimensions: 384,
            distance_metric: DistanceMetric::Cosine,
            storage_path: "./ruvector.db".to_string(),
            hnsw_config: Some(HnswConfig::default()),
            quantization: Some(QuantizationConfig::Scalar),
        }
    }
}
|
||||
391
crates/ruvector-core/src/vector_db.rs
Normal file
391
crates/ruvector-core/src/vector_db.rs
Normal file
@@ -0,0 +1,391 @@
|
||||
//! Main VectorDB interface
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::index::flat::FlatIndex;
|
||||
|
||||
#[cfg(feature = "hnsw")]
|
||||
use crate::index::hnsw::HnswIndex;
|
||||
|
||||
use crate::index::VectorIndex;
|
||||
use crate::types::*;
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Import appropriate storage backend based on features
|
||||
#[cfg(feature = "storage")]
|
||||
use crate::storage::VectorStorage;
|
||||
|
||||
#[cfg(not(feature = "storage"))]
|
||||
use crate::storage_memory::MemoryStorage as VectorStorage;
|
||||
|
||||
/// Main vector database
///
/// Pairs a storage backend (the file-backed `VectorStorage` when the
/// `storage` feature is enabled, otherwise the in-memory backend) with a
/// similarity index kept behind a `RwLock` for concurrent access.
pub struct VectorDB {
    /// Storage backend holding vectors and metadata
    storage: Arc<VectorStorage>,
    /// Similarity index (HNSW or flat), guarded for concurrent reads/writes
    index: Arc<RwLock<Box<dyn VectorIndex>>>,
    /// Effective options (may come from persisted config when reopening)
    options: DbOptions,
}
|
||||
|
||||
impl VectorDB {
    /// Create a new vector database with the given options
    ///
    /// If a storage path is provided and contains persisted vectors,
    /// the HNSW index will be automatically rebuilt from storage.
    /// If opening an existing database, the stored configuration (dimensions,
    /// distance metric, etc.) will be used instead of the provided options —
    /// only `storage_path` from the caller's options is honored in that case.
    #[allow(unused_mut)] // `options` is mutated only when feature = "storage"
    pub fn new(mut options: DbOptions) -> Result<Self> {
        #[cfg(feature = "storage")]
        let storage = {
            // First, try to load existing configuration from the database
            // We create a temporary storage to check for config
            let temp_storage = VectorStorage::new(&options.storage_path, options.dimensions)?;

            let stored_config = temp_storage.load_config()?;

            if let Some(config) = stored_config {
                // Existing database - use stored configuration
                tracing::info!(
                    "Loading existing database with {} dimensions",
                    config.dimensions
                );
                options = DbOptions {
                    // Keep the provided storage path (may have changed)
                    storage_path: options.storage_path.clone(),
                    // Use stored configuration for everything else
                    dimensions: config.dimensions,
                    distance_metric: config.distance_metric,
                    hnsw_config: config.hnsw_config,
                    quantization: config.quantization,
                };
                // Recreate storage with correct dimensions, since the
                // temporary storage above was opened with the caller's
                // (possibly wrong) dimension count.
                Arc::new(VectorStorage::new(
                    &options.storage_path,
                    options.dimensions,
                )?)
            } else {
                // New database - save the configuration
                tracing::info!(
                    "Creating new database with {} dimensions",
                    options.dimensions
                );
                temp_storage.save_config(&options)?;
                Arc::new(temp_storage)
            }
        };

        // Without the "storage" feature, fall back to the in-memory backend
        // (used for WASM builds and testing); nothing is persisted.
        #[cfg(not(feature = "storage"))]
        let storage = Arc::new(VectorStorage::new(options.dimensions)?);

        // Choose index based on configuration and available features
        #[allow(unused_mut)] // `index` is mutated only when feature = "storage"
        let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
            #[cfg(feature = "hnsw")]
            {
                Box::new(HnswIndex::new(
                    options.dimensions,
                    options.distance_metric,
                    hnsw_config.clone(),
                )?)
            }
            #[cfg(not(feature = "hnsw"))]
            {
                // Fall back to flat index if HNSW is not available
                tracing::warn!("HNSW requested but not available (WASM build), using flat index");
                Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
            }
        } else {
            Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
        };

        // Rebuild index from persisted vectors if storage is not empty
        // This fixes the bug where search() returns empty results after restart
        #[cfg(feature = "storage")]
        {
            let stored_ids = storage.all_ids()?;
            if !stored_ids.is_empty() {
                tracing::info!(
                    "Rebuilding index from {} persisted vectors",
                    stored_ids.len()
                );

                // Batch load all vectors for efficient index rebuilding
                let mut entries = Vec::with_capacity(stored_ids.len());
                for id in stored_ids {
                    if let Some(entry) = storage.get(&id)? {
                        entries.push((id, entry.vector));
                    }
                }

                // Add all vectors to index in batch for better performance
                index.add_batch(entries)?;

                tracing::info!("Index rebuilt successfully");
            }
        }

        Ok(Self {
            storage,
            index: Arc::new(RwLock::new(index)),
            options,
        })
    }

    /// Create with default options, overriding only the dimensions.
    pub fn with_dimensions(dimensions: usize) -> Result<Self> {
        let options = DbOptions {
            dimensions,
            ..DbOptions::default()
        };
        Self::new(options)
    }

    /// Insert a vector entry
    ///
    /// Writes to storage first, then mirrors the vector into the index under
    /// the ID that storage assigned.
    pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
        let id = self.storage.insert(&entry)?;

        // Add to index
        let mut index = self.index.write();
        index.add(id.clone(), entry.vector)?;

        Ok(id)
    }

    /// Insert multiple vectors in a batch
    ///
    /// Storage assigns IDs first; those IDs are then zipped with the entries'
    /// vectors and added to the index in one batch call.
    pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
        let ids = self.storage.insert_batch(&entries)?;

        // Add to index
        let mut index = self.index.write();
        let index_entries: Vec<_> = ids
            .iter()
            .zip(entries.iter())
            .map(|(id, entry)| (id.clone(), entry.vector.clone()))
            .collect();

        index.add_batch(index_entries)?;

        Ok(ids)
    }

    /// Search for similar vectors
    ///
    /// The index returns the raw top-k; each hit is then enriched with the
    /// full vector and metadata from storage. Note that metadata filters are
    /// applied AFTER top-k selection, so a filtered search may return fewer
    /// than `k` results.
    pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
        let index = self.index.read();
        let mut results = index.search(&query.vector, query.k)?;

        // Enrich results with full data if needed
        for result in &mut results {
            if let Ok(Some(entry)) = self.storage.get(&result.id) {
                result.vector = Some(entry.vector);
                result.metadata = entry.metadata;
            }
        }

        // Apply metadata filters if specified: a result survives only if it
        // has metadata matching every requested key/value pair exactly.
        if let Some(filter) = &query.filter {
            results.retain(|r| {
                if let Some(metadata) = &r.metadata {
                    filter
                        .iter()
                        .all(|(key, value)| metadata.get(key).is_some_and(|v| v == value))
                } else {
                    false
                }
            });
        }

        Ok(results)
    }

    /// Delete a vector by ID
    ///
    /// Returns whether a vector was removed from storage; the index entry is
    /// only removed when the storage delete actually found something.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let deleted_storage = self.storage.delete(id)?;

        if deleted_storage {
            let mut index = self.index.write();
            let _ = index.remove(&id.to_string())?;
        }

        Ok(deleted_storage)
    }

    /// Get a vector by ID (straight pass-through to storage)
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        self.storage.get(id)
    }

    /// Get the number of vectors
    pub fn len(&self) -> Result<usize> {
        self.storage.len()
    }

    /// Check if database is empty
    pub fn is_empty(&self) -> Result<bool> {
        self.storage.is_empty()
    }

    /// Get database options (the effective, possibly storage-loaded config)
    pub fn options(&self) -> &DbOptions {
        &self.options
    }

    /// Get all vector IDs (for iteration/serialization)
    pub fn keys(&self) -> Result<Vec<String>> {
        self.storage.all_ids()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A freshly created database should start out empty.
    #[test]
    fn test_vector_db_creation() -> Result<()> {
        let dir = tempdir().unwrap();
        let options = DbOptions {
            storage_path: dir.path().join("test.db").to_string_lossy().to_string(),
            dimensions: 3,
            ..DbOptions::default()
        };

        let db = VectorDB::new(options)?;
        assert!(db.is_empty()?);

        Ok(())
    }

    /// Inserted vectors are searchable, with the exact match ranked first.
    #[test]
    fn test_insert_and_search() -> Result<()> {
        let dir = tempdir().unwrap();
        let options = DbOptions {
            storage_path: dir.path().join("test.db").to_string_lossy().to_string(),
            dimensions: 3,
            // Euclidean makes "exact match has ~0 distance" easy to assert.
            distance_metric: DistanceMetric::Euclidean,
            // Flat index keeps the test exact and deterministic.
            hnsw_config: None,
            ..DbOptions::default()
        };

        let db = VectorDB::new(options)?;

        // Insert three orthogonal unit vectors.
        db.insert(VectorEntry {
            id: Some("v1".to_string()),
            vector: vec![1.0, 0.0, 0.0],
            metadata: None,
        })?;

        db.insert(VectorEntry {
            id: Some("v2".to_string()),
            vector: vec![0.0, 1.0, 0.0],
            metadata: None,
        })?;

        db.insert(VectorEntry {
            id: Some("v3".to_string()),
            vector: vec![0.0, 0.0, 1.0],
            metadata: None,
        })?;

        // Search with v1 itself as the query.
        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 2,
            filter: None,
            ef_search: None,
        })?;

        assert!(results.len() >= 1);
        assert_eq!(results[0].id, "v1", "First result should be exact match");
        assert!(
            results[0].score < 0.01,
            "Exact match should have ~0 distance"
        );

        Ok(())
    }

    /// Test that search works after simulated restart (new VectorDB instance)
    /// This verifies the fix for issue #30: HNSW index not rebuilt from storage
    #[test]
    #[cfg(feature = "storage")]
    fn test_search_after_restart() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("persist.db").to_string_lossy().to_string();

        // Phase 1: Create database and insert vectors
        {
            let options = DbOptions {
                storage_path: db_path.clone(),
                dimensions: 3,
                distance_metric: DistanceMetric::Euclidean,
                hnsw_config: None,
                ..DbOptions::default()
            };

            let db = VectorDB::new(options)?;

            db.insert(VectorEntry {
                id: Some("v1".to_string()),
                vector: vec![1.0, 0.0, 0.0],
                metadata: None,
            })?;

            db.insert(VectorEntry {
                id: Some("v2".to_string()),
                vector: vec![0.0, 1.0, 0.0],
                metadata: None,
            })?;

            db.insert(VectorEntry {
                id: Some("v3".to_string()),
                vector: vec![0.7, 0.7, 0.0],
                metadata: None,
            })?;

            // Verify search works before "restart"
            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;
            assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
        }
        // db is dropped here, simulating application shutdown

        // Phase 2: Create new database instance (simulates restart)
        {
            let options = DbOptions {
                storage_path: db_path.clone(),
                dimensions: 3,
                distance_metric: DistanceMetric::Euclidean,
                hnsw_config: None,
                ..DbOptions::default()
            };

            let db = VectorDB::new(options)?;

            // Verify vectors are still accessible
            assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");

            // Verify get() works
            let v1 = db.get("v1")?;
            assert!(v1.is_some(), "get() should work after restart");

            // Verify search() works - THIS WAS THE BUG
            let results = db.search(SearchQuery {
                vector: vec![0.8, 0.6, 0.0],
                k: 3,
                filter: None,
                ef_search: None,
            })?;

            assert_eq!(
                results.len(),
                3,
                "search() should return results after restart (was returning 0 before fix)"
            );

            // v3 should be closest to query [0.8, 0.6, 0.0]
            assert_eq!(
                results[0].id, "v3",
                "v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
            );
        }

        Ok(())
    }
}
|
||||
251
crates/ruvector-core/tests/README.md
Normal file
251
crates/ruvector-core/tests/README.md
Normal file
@@ -0,0 +1,251 @@
|
||||
# Ruvector Core Test Suite
|
||||
|
||||
## Overview
|
||||
|
||||
This directory contains a comprehensive Test-Driven Development (TDD) test suite following the London School approach. The test suite covers unit tests, integration tests, property-based tests, stress tests, and concurrent access tests.
|
||||
|
||||
## Test Files
|
||||
|
||||
### 1. `unit_tests.rs` - Unit Tests with Mocking (London School)
|
||||
Comprehensive unit tests using `mockall` for mocking dependencies:
|
||||
|
||||
- **Distance Metric Tests**: Tests for all 4 distance metrics (Euclidean, Cosine, Dot Product, Manhattan)
|
||||
- Self-distance verification
|
||||
- Symmetry properties
|
||||
- Orthogonal and parallel vector cases
|
||||
- Dimension mismatch error handling
|
||||
|
||||
- **Quantization Tests**: Tests for scalar and binary quantization
|
||||
- Round-trip reconstruction accuracy
|
||||
- Distance calculation correctness
|
||||
- Sign preservation (binary quantization)
|
||||
- Hamming distance validation
|
||||
|
||||
- **Storage Layer Tests**: Tests for VectorStorage
|
||||
- Insert with explicit and auto-generated IDs
|
||||
- Metadata handling
|
||||
- Dimension validation
|
||||
- Batch operations
|
||||
- Delete operations
|
||||
- Error cases (non-existent vectors, dimension mismatches)
|
||||
|
||||
- **VectorDB Tests**: High-level API tests
|
||||
- Empty database operations
|
||||
- Insert/delete with len() tracking
|
||||
- Search functionality
|
||||
- Metadata filtering
|
||||
- Batch insert operations
|
||||
|
||||
### 2. `integration_tests.rs` - End-to-End Integration Tests
|
||||
Full workflow tests that verify all components work together:
|
||||
|
||||
- **Complete Workflows**: Insert + search + retrieve with metadata
|
||||
- **Large Batch Operations**: 10K+ vector batch insertions
|
||||
- **Persistence**: Database save and reload verification
|
||||
- **Mixed Operations**: Combined insert, delete, and search operations
|
||||
- **Distance Metrics**: Tests for all 4 metrics end-to-end
|
||||
- **HNSW Configurations**: Different HNSW parameter combinations
|
||||
- **Metadata Filtering**: Complex filtering scenarios
|
||||
- **Error Handling**: Dimension validation, wrong query dimensions
|
||||
|
||||
### 3. `property_tests.rs` - Property-Based Tests (proptest)
|
||||
Mathematical property verification using proptest:
|
||||
|
||||
- **Distance Metric Properties**:
|
||||
- Self-distance is zero
|
||||
- Symmetry: d(a,b) = d(b,a)
|
||||
- Triangle inequality: d(a,c) ≤ d(a,b) + d(b,c)
|
||||
- Non-negativity
|
||||
- Scale invariance (cosine)
|
||||
- Translation invariance (Euclidean)
|
||||
|
||||
- **Quantization Properties**:
|
||||
- Round-trip reconstruction bounds
|
||||
- Sign preservation (binary)
|
||||
- Self-distance is zero
|
||||
- Symmetry
|
||||
- Distance bounds
|
||||
|
||||
- **Batch Operations**:
|
||||
- Consistency between batch and individual operations
|
||||
|
||||
- **Dimension Handling**:
|
||||
- Mismatch error detection
|
||||
- Success on matching dimensions
|
||||
|
||||
### 4. `stress_tests.rs` - Scalability and Performance Stress Tests
|
||||
Tests that push the system to its limits:
|
||||
|
||||
- **Million Vector Insertion** (ignored by default): Insert 1M vectors in batches
|
||||
- **Concurrent Queries**: 10 threads × 100 queries each
|
||||
- **Concurrent Mixed Operations**: Simultaneous readers and writers
|
||||
- **Memory Pressure**: Large 2048-dimensional vectors
|
||||
- **Error Recovery**: Invalid operations don't crash the system
|
||||
- **Repeated Operations**: Same operation executed many times
|
||||
- **Extreme Parameters**: k values larger than database size
|
||||
|
||||
### 5. `concurrent_tests.rs` - Thread-Safety Tests
|
||||
Multi-threaded access patterns:
|
||||
|
||||
- **Concurrent Reads**: Multiple threads reading simultaneously
|
||||
- **Concurrent Writes**: Non-overlapping writes from multiple threads
|
||||
- **Mixed Read/Write**: Concurrent reads and writes
|
||||
- **Delete and Insert**: Simultaneous deletes and inserts
|
||||
- **Search and Insert**: Searching while inserting
|
||||
- **Batch Atomicity**: Verifying batch operations are atomic
|
||||
- **Read-Write Consistency**: Ensuring no data corruption
|
||||
- **Metadata Updates**: Concurrent metadata modifications
|
||||
|
||||
## Benchmarks
|
||||
|
||||
### 6. `benches/quantization_bench.rs` - Quantization Performance
|
||||
Criterion benchmarks for quantization operations:
|
||||
|
||||
- Scalar quantization encode/decode/distance
|
||||
- Binary quantization encode/decode/distance
|
||||
- Compression ratio comparisons
|
||||
|
||||
### 7. `benches/batch_operations.rs` - Batch Operation Performance
|
||||
Criterion benchmarks for batch operations:
|
||||
|
||||
- Batch insert at various scales (100, 1K, 10K)
|
||||
- Individual vs batch insert comparison
|
||||
- Parallel search performance
|
||||
- Batch delete operations
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Run All Tests
|
||||
```bash
|
||||
cargo test --package ruvector-core
|
||||
```
|
||||
|
||||
### Run Specific Test Suites
|
||||
```bash
|
||||
# Unit tests only
|
||||
cargo test --test unit_tests
|
||||
|
||||
# Integration tests only
|
||||
cargo test --test integration_tests
|
||||
|
||||
# Property tests only
|
||||
cargo test --test property_tests
|
||||
|
||||
# Concurrent tests only
|
||||
cargo test --test concurrent_tests
|
||||
|
||||
# Stress tests (including ignored tests)
|
||||
cargo test --test stress_tests -- --ignored --test-threads=1
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
```bash
|
||||
# Distance metrics (existing)
|
||||
cargo bench --bench distance_metrics
|
||||
|
||||
# HNSW search (existing)
|
||||
cargo bench --bench hnsw_search
|
||||
|
||||
# Quantization (new)
|
||||
cargo bench --bench quantization_bench
|
||||
|
||||
# Batch operations (new)
|
||||
cargo bench --bench batch_operations
|
||||
```
|
||||
|
||||
### Generate Coverage Report
|
||||
```bash
|
||||
# Install tarpaulin if not already installed
|
||||
cargo install cargo-tarpaulin
|
||||
|
||||
# Generate HTML coverage report
|
||||
cargo tarpaulin --out Html --output-dir target/coverage
|
||||
|
||||
# Open coverage report
|
||||
open target/coverage/index.html
|
||||
```
|
||||
|
||||
## Test Coverage Goals
|
||||
|
||||
- **Target**: 90%+ code coverage
|
||||
- **Focus Areas**:
|
||||
- Distance calculations
|
||||
- Index operations
|
||||
- Storage layer
|
||||
- Error handling paths
|
||||
- Edge cases
|
||||
|
||||
## Known Issues
|
||||
|
||||
As of the current implementation, there are pre-existing compilation errors in the codebase that prevent some tests from running:
|
||||
|
||||
1. **HNSW Index**: `DataId::new` construction issues in `src/index/hnsw.rs`
|
||||
2. **AgenticDB**: Missing `Encode`/`Decode` trait implementations for `ReflexionEpisode`
|
||||
|
||||
These issues exist in the main codebase and need to be fixed before the full test suite can execute.
|
||||
|
||||
## Test Organization
|
||||
|
||||
Tests are organized by purpose and scope:
|
||||
|
||||
1. **Unit Tests**: Test individual components in isolation with mocking
|
||||
2. **Integration Tests**: Test component interactions and workflows
|
||||
3. **Property Tests**: Test mathematical properties and invariants
|
||||
4. **Stress Tests**: Test performance limits and edge cases
|
||||
5. **Concurrent Tests**: Test thread-safety and concurrent access patterns
|
||||
|
||||
## Dependencies
|
||||
|
||||
Test dependencies (in `Cargo.toml`):
|
||||
|
||||
```toml
|
||||
[dev-dependencies]
|
||||
criterion = { workspace = true }
|
||||
proptest = { workspace = true }
|
||||
mockall = { workspace = true }
|
||||
tempfile = "3.13"
|
||||
tracing-subscriber = { workspace = true }
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
When adding new tests:
|
||||
|
||||
1. Follow the existing structure and naming conventions
|
||||
2. Add tests to the appropriate file (unit, integration, property, etc.)
|
||||
3. Document the test purpose clearly
|
||||
4. Ensure tests are deterministic and don't depend on timing
|
||||
5. Use `tempdir()` for database paths in tests
|
||||
6. Clean up resources properly
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
Recommended CI pipeline:
|
||||
|
||||
```yaml
|
||||
test:
|
||||
script:
|
||||
- cargo test --all-features
|
||||
- cargo tarpaulin --out Xml
|
||||
coverage: '/\d+\.\d+% coverage/'
|
||||
|
||||
bench:
|
||||
script:
|
||||
- cargo bench --no-run
|
||||
```
|
||||
|
||||
## Performance Expectations
|
||||
|
||||
Based on stress tests:
|
||||
|
||||
- **Insert**: ~10K vectors/second (batch mode)
|
||||
- **Search**: ~1K queries/second (k=10, HNSW)
|
||||
- **Concurrent**: 10+ threads without performance degradation
|
||||
- **Memory**: ~1.5KB of raw f32 data per 384-dim vector (384 × 4 bytes, uncompressed), plus index/metadata overhead
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Mockall Documentation](https://docs.rs/mockall/)
|
||||
- [Proptest Guide](https://altsysrq.github.io/proptest-book/)
|
||||
- [Criterion.rs Guide](https://bheisler.github.io/criterion.rs/book/)
|
||||
- [Cargo Tarpaulin](https://github.com/xd009642/tarpaulin)
|
||||
550
crates/ruvector-core/tests/advanced_features_integration.rs
Normal file
550
crates/ruvector-core/tests/advanced_features_integration.rs
Normal file
@@ -0,0 +1,550 @@
|
||||
//! Integration tests for advanced features
|
||||
//!
|
||||
//! Tests Enhanced PQ, Filtered Search, MMR, Hybrid Search, and Conformal Prediction
|
||||
//! across multiple vector dimensions (128D, 384D, 768D)
|
||||
|
||||
use ruvector_core::advanced_features::*;
|
||||
use ruvector_core::types::{DistanceMetric, SearchResult};
|
||||
use ruvector_core::{Result, RuvectorError};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Helper function to generate pseudo-random test vectors.
//
// Replaces `rand::thread_rng()` with a fixed-seed xorshift64* generator so
// every run produces the same inputs — the testing guide for this crate
// requires tests to be deterministic. Values stay in [0, 1), matching the
// range of the previous `rng.gen::<f32>()` calls.
fn generate_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
    // Non-zero seed; xorshift64 never reaches the all-zero state.
    let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
    let mut next_f32 = move || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        // Use the top 24 bits so the f32 mantissa is filled exactly once.
        ((state >> 40) as f32) / (1u32 << 24) as f32
    };

    (0..count)
        .map(|_| (0..dimensions).map(|_| next_f32()).collect())
        .collect()
}
|
||||
|
||||
// Helper: scale `v` to unit Euclidean length; an all-zero vector is
// returned unchanged to avoid dividing by zero.
fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let squared_sum: f32 = v.iter().map(|component| component * component).sum();
    let length = squared_sum.sqrt();
    match length > 0.0 {
        true => v.iter().map(|component| component / length).collect(),
        false => v.to_vec(),
    }
}
|
||||
|
||||
#[test]
fn test_enhanced_pq_128d() {
    // Enhanced product quantization on 128-dim vectors: train codebooks,
    // add quantized entries, search, and check the compression ratio.
    let dimensions = 128;
    let num_vectors = 1000;

    // 8 subspaces over 128 dims => 16 dims per subspace, one code each.
    let config = PQConfig {
        num_subspaces: 8,
        codebook_size: 256,
        num_iterations: 10,
        metric: DistanceMetric::Euclidean,
    };

    let mut pq = EnhancedPQ::new(dimensions, config).unwrap();

    // Generate training data
    let training_data = generate_vectors(num_vectors, dimensions);
    pq.train(&training_data).unwrap();

    // Test encoding and search
    let query = normalize_vector(&generate_vectors(1, dimensions)[0]);

    // Add quantized vectors
    for (i, vector) in training_data.iter().enumerate() {
        pq.add_quantized(format!("vec_{}", i), vector).unwrap();
    }

    // Perform search
    let results = pq.search(&query, 10).unwrap();
    assert_eq!(results.len(), 10);

    // Test compression ratio; only a loose >= 8x lower bound is asserted.
    // NOTE(review): the exact expected ratio depends on EnhancedPQ's
    // accounting — the inline comment below claims 16x; confirm against impl.
    let compression_ratio = pq.compression_ratio();
    assert!(compression_ratio >= 8.0); // Should be 16x for 128D with 8 subspaces

    println!(
        "✓ Enhanced PQ 128D: compression ratio = {:.1}x",
        compression_ratio
    );
}
|
||||
|
||||
#[test]
fn test_enhanced_pq_384d() {
    // PQ round-trip on 384-dim unit vectors: encode then reconstruct and
    // bound the Euclidean reconstruction error.
    let dimensions = 384;
    let num_vectors = 500;

    let config = PQConfig {
        num_subspaces: 8,
        codebook_size: 256,
        num_iterations: 10,
        metric: DistanceMetric::Cosine,
    };

    let mut pq = EnhancedPQ::new(dimensions, config).unwrap();

    // Generate training data (normalized, to match the cosine metric)
    let training_data: Vec<Vec<f32>> = generate_vectors(num_vectors, dimensions)
        .into_iter()
        .map(|v| normalize_vector(&v))
        .collect();

    pq.train(&training_data).unwrap();

    // Test reconstruction
    let test_vector = &training_data[0];
    let codes = pq.encode(test_vector).unwrap();
    let reconstructed = pq.reconstruct(&codes).unwrap();

    assert_eq!(reconstructed.len(), dimensions);

    // Calculate reconstruction error (L2 distance original vs decoded)
    let error: f32 = test_vector
        .iter()
        .zip(&reconstructed)
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        .sqrt();

    println!("✓ Enhanced PQ 384D: reconstruction error = {:.4}", error);
    assert!(error < 5.0); // Reasonable reconstruction error
}
|
||||
|
||||
#[test]
fn test_enhanced_pq_768d() {
    // Train PQ on 768-dim data and verify the distance lookup table has one
    // sub-table per subspace and one entry per codebook centroid.
    let dims = 768;
    let vector_count = 300; // Increased to ensure we have enough vectors for search

    let cfg = PQConfig {
        num_subspaces: 16,
        codebook_size: 256,
        num_iterations: 10,
        metric: DistanceMetric::Euclidean,
    };

    let mut quantizer = EnhancedPQ::new(dims, cfg).unwrap();

    let samples = generate_vectors(vector_count, dims);
    quantizer.train(&samples).unwrap();

    // Build a lookup table for a fresh query and inspect its shape.
    let mut query_batch = generate_vectors(1, dims);
    let query = query_batch.remove(0);
    let lookup_table = quantizer.create_lookup_table(&query).unwrap();

    assert_eq!(lookup_table.tables.len(), 16);
    assert_eq!(lookup_table.tables[0].len(), 256);

    println!("✓ Enhanced PQ 768D: lookup table created successfully");
}
|
||||
|
||||
#[test]
fn test_filtered_search_pre_filter() {
    use serde_json::json;

    // Create metadata store: every third doc is category "A"; price = i * 10.
    let mut metadata_store = HashMap::new();
    for i in 0..100 {
        let mut metadata = HashMap::new();
        metadata.insert(
            "category".to_string(),
            json!(if i % 3 == 0 { "A" } else { "B" }),
        );
        metadata.insert("price".to_string(), json!(i as f32 * 10.0));
        metadata_store.insert(format!("vec_{}", i), metadata);
    }

    // Create filter: category == "A" AND price < 500
    let filter = FilterExpression::And(vec![
        FilterExpression::Eq("category".to_string(), json!("A")),
        FilterExpression::Lt("price".to_string(), json!(500.0)),
    ]);

    let search = FilteredSearch::new(filter, FilterStrategy::PreFilter, metadata_store);

    // Test pre-filtering: i in {0, 3, ..., 48} satisfies both clauses.
    let filtered_ids = search.get_filtered_ids();
    assert!(!filtered_ids.is_empty());
    assert!(filtered_ids.len() < 50); // Should be selective

    println!(
        "✓ Filtered Search (Pre-filter): {} matching documents",
        filtered_ids.len()
    );
}
|
||||
|
||||
#[test]
fn test_filtered_search_auto_strategy() {
    use serde_json::json;

    // Build a 1000-document metadata store with a single "id" field each.
    let mut metadata_store = HashMap::new();
    for doc_idx in 0..1000 {
        let mut fields = HashMap::new();
        fields.insert("id".to_string(), json!(doc_idx));
        metadata_store.insert(format!("vec_{}", doc_idx), fields);
    }

    // Highly selective filter (should choose pre-filter)
    let selective_filter = FilterExpression::Eq("id".to_string(), json!(42));
    let selective_search = FilteredSearch::new(
        selective_filter,
        FilterStrategy::Auto,
        metadata_store.clone(),
    );
    assert_eq!(
        selective_search.auto_select_strategy(),
        FilterStrategy::PreFilter
    );

    // Less selective filter (should choose post-filter)
    let broad_filter = FilterExpression::Gte("id".to_string(), json!(0));
    let broad_search = FilteredSearch::new(broad_filter, FilterStrategy::Auto, metadata_store);
    assert_eq!(broad_search.auto_select_strategy(), FilterStrategy::PostFilter);

    println!("✓ Filtered Search: automatic strategy selection working");
}
|
||||
|
||||
#[test]
fn test_mmr_diversity_128d() {
    // MMR reranking of 20 random candidates down to 10; lambda = 0.5 weighs
    // relevance and diversity equally.
    let dimensions = 128;

    let config = MMRConfig {
        lambda: 0.5, // Balance relevance and diversity
        metric: DistanceMetric::Cosine,
        fetch_multiplier: 2.0,
    };

    let mmr = MMRSearch::new(config).unwrap();

    // Create query and candidates (normalized for the cosine metric)
    let query = normalize_vector(&generate_vectors(1, dimensions)[0]);
    let candidates: Vec<SearchResult> = (0..20)
        .map(|i| SearchResult {
            id: format!("doc_{}", i),
            score: i as f32 * 0.05,
            vector: Some(normalize_vector(&generate_vectors(1, dimensions)[0])),
            metadata: None,
        })
        .collect();

    // Rerank using MMR
    let results = mmr.rerank(&query, candidates, 10).unwrap();

    assert_eq!(results.len(), 10);
    // Results should be diverse (not just top-10 by relevance)

    println!("✓ MMR 128D: diversified {} results", results.len());
}
|
||||
|
||||
#[test]
fn test_mmr_lambda_variations() {
    // Exercise both extremes of the MMR trade-off parameter:
    // lambda = 1.0 (pure relevance) and lambda = 0.0 (pure diversity).
    let dimensions = 64;

    // Test with pure relevance (lambda = 1.0)
    let config_relevance = MMRConfig {
        lambda: 1.0,
        metric: DistanceMetric::Euclidean,
        fetch_multiplier: 2.0,
    };
    let mmr_relevance = MMRSearch::new(config_relevance).unwrap();

    // Test with pure diversity (lambda = 0.0)
    let config_diversity = MMRConfig {
        lambda: 0.0,
        metric: DistanceMetric::Euclidean,
        fetch_multiplier: 2.0,
    };
    let mmr_diversity = MMRSearch::new(config_diversity).unwrap();

    let query = generate_vectors(1, dimensions)[0].clone();
    let candidates: Vec<SearchResult> = (0..10)
        .map(|i| SearchResult {
            id: format!("doc_{}", i),
            score: i as f32 * 0.1,
            vector: Some(generate_vectors(1, dimensions)[0].clone()),
            metadata: None,
        })
        .collect();

    // Both settings must still return exactly the requested k results.
    let results_relevance = mmr_relevance.rerank(&query, candidates.clone(), 5).unwrap();
    let results_diversity = mmr_diversity.rerank(&query, candidates, 5).unwrap();

    assert_eq!(results_relevance.len(), 5);
    assert_eq!(results_diversity.len(), 5);

    println!("✓ MMR: lambda variations tested successfully");
}
|
||||
|
||||
#[test]
fn test_hybrid_search_basic() {
    // Hybrid (vector + keyword) search: index three tiny documents and check
    // the embedded BM25 component yields a positive keyword score.
    let config = HybridConfig {
        vector_weight: 0.7,
        keyword_weight: 0.3,
        normalization: NormalizationStrategy::MinMax,
    };

    let mut hybrid = HybridSearch::new(config);

    // Index documents
    hybrid.index_document("doc1".to_string(), "rust programming language".to_string());
    hybrid.index_document("doc2".to_string(), "python machine learning".to_string());
    hybrid.index_document("doc3".to_string(), "rust systems programming".to_string());
    hybrid.finalize_indexing();

    // Test BM25 scoring for a document containing both query terms.
    let score = hybrid.bm25.score(
        "rust programming",
        &"doc1".to_string(),
        "rust programming language",
    );
    assert!(score > 0.0);

    println!(
        "✓ Hybrid Search: indexed {} documents",
        hybrid.doc_texts.len()
    );
}
|
||||
|
||||
#[test]
fn test_hybrid_search_keyword_matching() {
    // BM25 in isolation: documents sharing query terms must appear as
    // candidates, and a closer lexical match must score higher.
    let mut bm25 = BM25::new(1.5, 0.75);

    bm25.index_document("doc1".to_string(), "vector database with HNSW indexing");
    bm25.index_document("doc2".to_string(), "relational database management system");
    bm25.index_document("doc3".to_string(), "vector search and similarity matching");
    bm25.build_idf();

    // Candidate retrieval: both docs containing "vector" must be returned.
    let candidates = bm25.get_candidate_docs("vector database");
    for expected in ["doc1", "doc3"] {
        assert!(candidates.contains(&expected.to_string()));
    }

    // Scoring: doc1 contains both query terms, doc2 only "database".
    let score1 = bm25.score(
        "vector database",
        &"doc1".to_string(),
        "vector database with HNSW indexing",
    );
    let score2 = bm25.score(
        "vector database",
        &"doc2".to_string(),
        "relational database management system",
    );

    assert!(score1 > score2); // doc1 matches better

    println!(
        "✓ Hybrid Search (BM25): {} candidate documents",
        candidates.len()
    );
}
|
||||
|
||||
#[test]
fn test_conformal_prediction_128d() {
    // Split-conformal calibration with a stubbed search function; the
    // resulting prediction set must report confidence 1 - alpha = 0.9.
    let dimensions = 128;

    let config = ConformalConfig {
        alpha: 0.1, // 90% coverage
        calibration_fraction: 0.2,
        nonconformity_measure: NonconformityMeasure::Distance,
    };

    let mut predictor = ConformalPredictor::new(config).unwrap();

    // Create calibration data: query i's true neighbors are vec_i, vec_{i+1}.
    let calibration_queries = generate_vectors(10, dimensions);
    let true_neighbors: Vec<Vec<String>> = (0..10)
        .map(|i| vec![format!("vec_{}", i), format!("vec_{}", i + 1)])
        .collect();

    // Mock search function: deterministic ids/scores, ignores the query.
    let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
        Ok((0..k)
            .map(|i| SearchResult {
                id: format!("vec_{}", i),
                score: i as f32 * 0.1,
                vector: Some(vec![0.0; dimensions]),
                metadata: None,
            })
            .collect())
    };

    // Calibrate
    predictor
        .calibrate(&calibration_queries, &true_neighbors, search_fn)
        .unwrap();

    // Calibration must have produced a threshold and nonconformity scores.
    assert!(predictor.threshold.is_some());
    assert!(!predictor.calibration_scores.is_empty());

    // Make prediction
    let query = generate_vectors(1, dimensions)[0].clone();
    let prediction_set = predictor.predict(&query, search_fn).unwrap();

    assert_eq!(prediction_set.confidence, 0.9);
    assert!(!prediction_set.results.is_empty());

    println!(
        "✓ Conformal Prediction 128D: prediction set size = {}",
        prediction_set.results.len()
    );
}
|
||||
|
||||
#[test]
fn test_conformal_prediction_384d() {
    // Calibrate with normalized-distance nonconformity at alpha = 0.05 and
    // verify the statistics reflect all 5 calibration queries.
    let dimensions = 384;

    let config = ConformalConfig {
        alpha: 0.05, // 95% coverage
        calibration_fraction: 0.2,
        nonconformity_measure: NonconformityMeasure::NormalizedDistance,
    };

    let mut predictor = ConformalPredictor::new(config).unwrap();

    let calibration_queries = generate_vectors(5, dimensions);
    let true_neighbors: Vec<Vec<String>> = (0..5).map(|i| vec![format!("vec_{}", i)]).collect();

    // Stub search: deterministic ids with strictly increasing scores.
    let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
        Ok((0..k)
            .map(|i| SearchResult {
                id: format!("vec_{}", i),
                score: 0.1 + (i as f32 * 0.05),
                vector: Some(vec![0.0; dimensions]),
                metadata: None,
            })
            .collect())
    };

    predictor
        .calibrate(&calibration_queries, &true_neighbors, search_fn)
        .unwrap();

    // Test calibration statistics
    let stats = predictor.get_statistics().unwrap();
    assert_eq!(stats.num_samples, 5);
    assert!(stats.mean > 0.0);

    println!("✓ Conformal Prediction 384D: calibration stats computed");
}
|
||||
|
||||
#[test]
fn test_conformal_prediction_adaptive_k() {
    // Inverse-rank nonconformity; after calibration, adaptive_top_k must
    // return a positive k for the requested coverage.
    let dimensions = 256;

    let config = ConformalConfig {
        alpha: 0.1,
        calibration_fraction: 0.2,
        nonconformity_measure: NonconformityMeasure::InverseRank,
    };

    let mut predictor = ConformalPredictor::new(config).unwrap();

    let calibration_queries = generate_vectors(8, dimensions);
    let true_neighbors: Vec<Vec<String>> = (0..8).map(|i| vec![format!("vec_{}", i)]).collect();

    // Stub search: deterministic ids/scores, ignores the query.
    let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
        Ok((0..k)
            .map(|i| SearchResult {
                id: format!("vec_{}", i),
                score: i as f32 * 0.08,
                vector: Some(vec![0.0; dimensions]),
                metadata: None,
            })
            .collect())
    };

    predictor
        .calibrate(&calibration_queries, &true_neighbors, search_fn)
        .unwrap();

    // Test adaptive top-k
    let query = generate_vectors(1, dimensions)[0].clone();
    let adaptive_k = predictor.adaptive_top_k(&query, search_fn).unwrap();

    assert!(adaptive_k > 0);
    println!("✓ Conformal Prediction: adaptive k = {}", adaptive_k);
}
|
||||
|
||||
#[test]
fn test_all_features_integration() {
    // Smoke test: construct every advanced-features component in one place to
    // ensure their configs are mutually usable. Components that are only
    // constructed (not exercised) are underscore-prefixed so the test compiles
    // without unused-variable warnings.
    let dimensions = 128;

    // 1. Enhanced PQ — small codebook so training 50 vectors stays fast.
    let pq_config = PQConfig {
        num_subspaces: 4,
        codebook_size: 16,
        num_iterations: 5,
        metric: DistanceMetric::Euclidean,
    };
    let mut pq = EnhancedPQ::new(dimensions, pq_config).unwrap();
    let training_data = generate_vectors(50, dimensions);
    pq.train(&training_data).unwrap();

    // 2. MMR (construction only)
    let mmr_config = MMRConfig::default();
    let _mmr = MMRSearch::new(mmr_config).unwrap();

    // 3. Hybrid Search
    let hybrid_config = HybridConfig::default();
    let mut hybrid = HybridSearch::new(hybrid_config);
    hybrid.index_document("doc1".to_string(), "test document".to_string());
    hybrid.finalize_indexing();

    // 4. Conformal Prediction (construction only)
    let cp_config = ConformalConfig::default();
    let _predictor = ConformalPredictor::new(cp_config).unwrap();

    println!("✓ All features integrated successfully");
}
|
||||
|
||||
#[test]
fn test_pq_recall_384d() {
    // Search using a database vector as its own query. PQ is lossy, so the
    // exact match may not rank first; we only require k results and a small
    // minimum distance among them.
    let dimensions = 384;
    let num_vectors = 500;
    let k = 10;

    let config = PQConfig {
        num_subspaces: 8,
        codebook_size: 256,
        num_iterations: 15,
        metric: DistanceMetric::Euclidean,
    };

    let mut pq = EnhancedPQ::new(dimensions, config).unwrap();

    // Generate and train
    let vectors = generate_vectors(num_vectors, dimensions);
    pq.train(&vectors).unwrap();

    // Add vectors
    for (i, vector) in vectors.iter().enumerate() {
        pq.add_quantized(format!("vec_{}", i), vector).unwrap();
    }

    // Test search
    let query = &vectors[0]; // Use first vector as query
    let results = pq.search(query, k).unwrap();

    // Verify we got results
    assert!(!results.is_empty(), "Search should return results");
    assert_eq!(results.len(), k, "Should return k results");

    // First result should be among the top candidates (PQ is approximate)
    // Due to quantization, the exact match might not be at position 0
    // but the distance should be reasonably small relative to random vectors
    let min_distance = results
        .iter()
        .map(|(_, d)| *d)
        .fold(f32::INFINITY, f32::min);

    // In high dimensions, PQ distances vary based on quantization quality
    // Check that we get reasonable results (top result should be closer than random)
    assert!(
        min_distance < 50.0,
        "Minimum distance {} should be reasonable for quantized search",
        min_distance
    );

    println!(
        "✓ PQ 384D Recall Test: top-{} results retrieved, min distance = {:.4}",
        results.len(),
        min_distance
    );
}
|
||||
440
crates/ruvector-core/tests/concurrent_tests.rs
Normal file
440
crates/ruvector-core/tests/concurrent_tests.rs
Normal file
@@ -0,0 +1,440 @@
|
||||
//! Concurrent access tests with multiple threads
|
||||
//!
|
||||
//! These tests verify thread-safety and correct behavior under concurrent access.
|
||||
|
||||
use ruvector_core::types::{DbOptions, HnswConfig, SearchQuery};
|
||||
use ruvector_core::{VectorDB, VectorEntry};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
fn test_concurrent_reads() {
    // 10 reader threads issue overlapping `get`s against a shared VectorDB;
    // every lookup must succeed (no poisoned locks, no lost entries).
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_reads.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial data
    for i in 0..100 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .unwrap();
    }

    // Spawn multiple reader threads
    let num_threads = 10;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..50 {
                // Each thread walks a shifted window over the 100 ids.
                let id = format!("vec_{}", (thread_id * 10 + i) % 100);
                let result = db_clone.get(&id).unwrap();
                assert!(result.is_some(), "Failed to get {}", id);
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }
}
|
||||
|
||||
#[test]
fn test_concurrent_writes_no_collision() {
    // 10 writer threads insert disjoint id ranges; afterwards the database
    // must contain exactly num_threads * vectors_per_thread entries.
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_writes.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Spawn multiple writer threads with non-overlapping IDs
    let num_threads = 10;
    let vectors_per_thread = 20;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..vectors_per_thread {
                let id = format!("thread_{}_{}", thread_id, i);
                db_clone
                    .insert(VectorEntry {
                        id: Some(id.clone()),
                        vector: vec![thread_id as f32; 32],
                        metadata: None,
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify all vectors were inserted
    assert_eq!(db.len().unwrap(), num_threads * vectors_per_thread);
}
|
||||
|
||||
#[test]
fn test_concurrent_delete_and_insert() {
    // 5 deleter threads remove vec_0..vec_49 (disjoint ranges) while 5
    // inserter threads add 50 fresh ids; final count must be 100 - 50 + 50.
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_delete_insert.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 16;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial data
    for i in 0..100 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: vec![i as f32; 16],
            metadata: None,
        })
        .unwrap();
    }

    let num_threads = 5;
    let mut handles = vec![];

    // Deleter threads: thread t deletes vec_{10t}..vec_{10t+9}.
    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..10 {
                let id = format!("vec_{}", thread_id * 10 + i);
                db_clone.delete(&id).unwrap();
            }
        });

        handles.push(handle);
    }

    // Inserter threads
    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..10 {
                let id = format!("new_{}_{}", thread_id, i);
                db_clone
                    .insert(VectorEntry {
                        id: Some(id),
                        vector: vec![(thread_id * 100 + i) as f32; 16],
                        metadata: None,
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify database is in consistent state
    let final_len = db.len().unwrap();
    assert_eq!(final_len, 100); // 100 original - 50 deleted + 50 inserted
}
|
||||
|
||||
#[test]
fn test_concurrent_search_and_insert() {
    // HNSW-enabled database: 5 threads search while 2 threads insert; every
    // search must keep returning results, and the final count must add up.
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("concurrent_search_insert.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 64;
    options.hnsw_config = Some(HnswConfig::default());

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial data
    for i in 0..100 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
            metadata: None,
        })
        .unwrap();
    }

    let num_search_threads = 5;
    let num_insert_threads = 2;
    let mut handles = vec![];

    // Search threads
    for search_id in 0..num_search_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..20 {
                let query: Vec<f32> = (0..64)
                    .map(|j| ((search_id * 10 + i + j) as f32) * 0.01)
                    .collect();
                let results = db_clone
                    .search(SearchQuery {
                        vector: query,
                        k: 5,
                        filter: None,
                        ef_search: None,
                    })
                    .unwrap();

                // Should always return some results (at least from initial data)
                assert!(results.len() > 0);
            }
        });

        handles.push(handle);
    }

    // Insert threads
    for insert_id in 0..num_insert_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..50 {
                db_clone
                    .insert(VectorEntry {
                        id: Some(format!("new_{}_{}", insert_id, i)),
                        vector: (0..64)
                            .map(|j| ((insert_id * 1000 + i + j) as f32) * 0.01)
                            .collect(),
                        metadata: None,
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify final state
    assert_eq!(db.len().unwrap(), 200); // 100 initial + 100 new
}
|
||||
|
||||
#[test]
fn test_atomicity_of_batch_insert() {
    // 5 threads each perform 10 batch inserts of 10 vectors; all ids returned
    // by insert_batch are collected, and the final db length must equal the
    // number of distinct ids recorded.
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("atomic_batch.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 16;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Track successful insertions
    let inserted_ids = Arc::new(Mutex::new(HashSet::new()));

    let num_threads = 5;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);
        let ids_clone = Arc::clone(&inserted_ids);

        let handle = thread::spawn(move || {
            for batch_idx in 0..10 {
                let vectors: Vec<VectorEntry> = (0..10)
                    .map(|i| {
                        let id = format!("t{}_b{}_v{}", thread_id, batch_idx, i);
                        VectorEntry {
                            id: Some(id.clone()),
                            vector: vec![(thread_id * 100 + batch_idx * 10 + i) as f32; 16],
                            metadata: None,
                        }
                    })
                    .collect();

                let ids = db_clone.insert_batch(vectors).unwrap();

                // Record the ids reported back by the batch insert.
                let mut lock = ids_clone.lock().unwrap();
                for id in ids {
                    lock.insert(id);
                }
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify all insertions were recorded
    let total_inserted = inserted_ids.lock().unwrap().len();
    assert_eq!(total_inserted, num_threads * 10 * 10); // threads * batches * vectors_per_batch
    assert_eq!(db.len().unwrap(), total_inserted);
}
|
||||
|
||||
#[test]
fn test_read_write_consistency() {
    // Readers repeatedly fetch a single entry while half the threads keep
    // overwriting it. Writers only ever store constant vectors ([1.0; 32]
    // initially, [thread_id; 32] afterwards), so every observed vector must
    // be homogeneous; mixed values would indicate a torn read.
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir
        .path()
        .join("consistency.db")
        .to_string_lossy()
        .to_string();
    options.dimensions = 32;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial vector
    db.insert(VectorEntry {
        id: Some("test".to_string()),
        vector: vec![1.0; 32],
        metadata: None,
    })
    .unwrap();

    let num_threads = 10;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for _ in 0..100 {
                // Read
                let entry = db_clone.get("test").unwrap();
                assert!(entry.is_some());

                // Verify vector is consistent
                let vector = entry.unwrap().vector;
                assert_eq!(vector.len(), 32);

                // All values must be identical. (The previous assertion
                // `v == first_val || (first_val == 1.0 || v == thread_id as f32)`
                // was vacuously true whenever first_val == 1.0, so it could
                // never detect a torn read.)
                let first_val = vector[0];
                assert!(
                    vector.iter().all(|&v| v == first_val),
                    "torn read: mixed values in vector"
                );

                // Write (update) - this creates a race condition intentionally
                if thread_id % 2 == 0 {
                    let _ = db_clone.insert(VectorEntry {
                        id: Some("test".to_string()),
                        vector: vec![thread_id as f32; 32],
                        metadata: None,
                    });
                }
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify database is still consistent
    let final_entry = db.get("test").unwrap();
    assert!(final_entry.is_some());

    let vector = final_entry.unwrap().vector;
    assert_eq!(vector.len(), 32);

    // Check no corruption (all values should be the same)
    let first_val = vector[0];
    assert!(vector.iter().all(|&v| v == first_val));
}
|
||||
|
||||
#[test]
fn test_concurrent_metadata_updates() {
    use std::collections::HashMap;

    // 5 threads re-insert disjoint ids with metadata attached; the mapping
    // i * 5 + thread_id (i in 0..10, thread_id in 0..5) covers vec_0..vec_49
    // exactly once, so no two threads touch the same entry.
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("metadata.db").to_string_lossy().to_string();
    options.dimensions = 16;

    let db = Arc::new(VectorDB::new(options).unwrap());

    // Insert initial vectors
    for i in 0..50 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: vec![i as f32; 16],
            metadata: None,
        })
        .unwrap();
    }

    let num_threads = 5;
    let mut handles = vec![];

    for thread_id in 0..num_threads {
        let db_clone = Arc::clone(&db);

        let handle = thread::spawn(move || {
            for i in 0..10 {
                let mut metadata = HashMap::new();
                metadata.insert(format!("thread_{}", thread_id), serde_json::json!(i));

                // Update vector with metadata
                let id = format!("vec_{}", i * 5 + thread_id);
                db_clone
                    .insert(VectorEntry {
                        id: Some(id.clone()),
                        vector: vec![thread_id as f32; 16],
                        metadata: Some(metadata),
                    })
                    .unwrap();
            }
        });

        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify some vectors have metadata
    let entry = db.get("vec_0").unwrap();
    assert!(entry.is_some());
}
|
||||
307
crates/ruvector-core/tests/embeddings_test.rs
Normal file
307
crates/ruvector-core/tests/embeddings_test.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
//! Integration tests for embedding providers
|
||||
|
||||
use ruvector_core::embeddings::{ApiEmbedding, EmbeddingProvider, HashEmbedding};
|
||||
use ruvector_core::{types::DbOptions, AgenticDB};
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
fn test_hash_embedding_provider() {
    // Hash-based embeddings: fixed dimensionality, stable output per input,
    // distinct output across inputs, and unit L2 norm.
    let provider = HashEmbedding::new(128);

    let first = provider.embed("hello world").unwrap();
    assert_eq!(first.len(), 128);

    // Re-embedding identical text must be bit-for-bit identical.
    let repeat = provider.embed("hello world").unwrap();
    assert_eq!(first, repeat, "Same text should produce same embedding");

    // A different input must map to a different vector.
    let other = provider.embed("goodbye world").unwrap();
    assert_ne!(
        first, other,
        "Different text should produce different embeddings"
    );

    // Embeddings are normalized to (approximately) unit length.
    let squared_sum: f32 = first.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    assert!(
        (norm - 1.0).abs() < 1e-5,
        "Embedding should be normalized to unit length"
    );

    // Provider metadata.
    assert_eq!(provider.dimensions(), 128);
    assert!(provider.name().contains("Hash"));
}
|
||||
|
||||
#[test]
fn test_agenticdb_with_hash_embeddings() {
    // AgenticDB with the default hash-embedding provider: store a reflexion
    // episode, then retrieve it back by similarity.
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 128;

    // Create AgenticDB with default hash embeddings
    let db = AgenticDB::new(options).unwrap();

    assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");

    // Test storing a reflexion episode (task, actions, observations, critique)
    let episode_id = db
        .store_episode(
            "Solve a math problem".to_string(),
            vec!["read problem".to_string(), "calculate".to_string()],
            vec!["got answer 42".to_string()],
            "Should have shown intermediate steps".to_string(),
        )
        .unwrap();

    // Test retrieving similar episodes; the stored one must come back first.
    let episodes = db
        .retrieve_similar_episodes("math problem solving", 5)
        .unwrap();
    assert!(!episodes.is_empty());
    assert_eq!(episodes[0].id, episode_id);
}
|
||||
|
||||
/// Same flow as the default-provider test, but with the hash provider
/// injected explicitly through `with_embedding_provider`, exercising the
/// skill-creation and skill-search APIs.
#[test]
fn test_agenticdb_with_custom_hash_provider() {
    let dir = tempdir().unwrap();

    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 256;

    // Provider dimensionality must agree with the configured options.
    let hash_provider = Arc::new(HashEmbedding::new(256));
    let db = AgenticDB::with_embedding_provider(options, hash_provider).unwrap();
    assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");

    // Register a skill with a single declared parameter.
    let params: std::collections::HashMap<String, String> =
        [(String::from("input"), String::from("string"))].into();

    let skill_id = db
        .create_skill(
            String::from("Parse JSON"),
            String::from("Parse JSON from string"),
            params,
            vec![String::from("json.parse()")],
        )
        .unwrap();

    // A semantically related query should surface the stored skill first.
    let skills = db.search_skills("parse json data", 5).unwrap();
    assert!(!skills.is_empty());
    assert_eq!(skills[0].id, skill_id);
}
|
||||
|
||||
/// Constructing an `AgenticDB` with a provider whose dimensionality differs
/// from `DbOptions::dimensions` must be rejected with a descriptive error.
#[test]
fn test_dimension_mismatch_validation() {
    let dir = tempdir().unwrap();

    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 128;

    // 256-dimensional provider vs. 128-dimensional options: incompatible.
    let mismatched = Arc::new(HashEmbedding::new(256));
    let result = AgenticDB::with_embedding_provider(options, mismatched);

    assert!(result.is_err(), "Should fail when dimensions don't match");

    match result {
        Err(err) => assert!(
            err.to_string().contains("do not match"),
            "Error should mention dimension mismatch"
        ),
        // Unreachable in practice: the assert above already failed the test.
        Ok(_) => {}
    }
}
|
||||
|
||||
/// Checks that each API-backed embedding constructor reports the documented
/// model dimensionality; construction alone does not hit the network, so the
/// dummy keys are fine.
#[test]
fn test_api_embedding_provider_construction() {
    // (provider, expected dimensionality) for every hosted model covered.
    let cases = [
        (ApiEmbedding::openai("sk-test", "text-embedding-3-small"), 1536),
        (ApiEmbedding::openai("sk-test", "text-embedding-3-large"), 3072),
        (ApiEmbedding::cohere("co-test", "embed-english-v3.0"), 1024),
        (ApiEmbedding::voyage("vo-test", "voyage-2"), 1024),
        (ApiEmbedding::voyage("vo-test", "voyage-large-2"), 1536),
    ];

    for (provider, expected_dims) in &cases {
        assert_eq!(provider.dimensions(), *expected_dims);
    }

    // All API-backed providers share the same provider name.
    assert_eq!(
        ApiEmbedding::openai("sk-test", "text-embedding-3-small").name(),
        "ApiEmbedding"
    );
}
|
||||
|
||||
/// Live OpenAI round-trip (ignored by default): embeds two texts and checks
/// the returned dimensionality, and that distinct inputs yield distinct
/// vectors.
#[test]
#[ignore] // Requires API key and network access
fn test_api_embedding_openai() {
    let api_key = std::env::var("OPENAI_API_KEY")
        .expect("OPENAI_API_KEY environment variable required for this test");

    let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");

    // text-embedding-3-small returns 1536-dimensional vectors.
    let first = provider.embed("hello world").unwrap();
    assert_eq!(first.len(), 1536);

    // Different inputs must not produce identical embeddings.
    let second = provider.embed("goodbye world").unwrap();
    assert_ne!(first, second);
}
|
||||
|
||||
/// End-to-end AgenticDB flow against the live OpenAI embedding API
/// (ignored by default: needs OPENAI_API_KEY and network access).
///
/// Stores two episodes and checks that a semantic query returns results.
/// The expected ranking ("derivative" matching the calculus episode) is
/// only printed, not asserted, since API embedding values are outside this
/// crate's control.
#[test]
#[ignore] // Requires API key and network access
fn test_agenticdb_with_openai_embeddings() {
    let api_key = std::env::var("OPENAI_API_KEY")
        .expect("OPENAI_API_KEY environment variable required for this test");

    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 1536; // OpenAI text-embedding-3-small dimensions

    let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
    let db = AgenticDB::with_embedding_provider(options, provider).unwrap();

    assert_eq!(db.embedding_provider_name(), "ApiEmbedding");

    // Test with real semantic embeddings.
    // Episode ids are intentionally unused (underscore-prefixed): this test
    // only checks that retrieval returns something, not which episode wins.
    let _episode1_id = db
        .store_episode(
            "Solve calculus problem".to_string(),
            vec![
                "identify function".to_string(),
                "take derivative".to_string(),
            ],
            vec!["computed derivative".to_string()],
            "Should explain chain rule application".to_string(),
        )
        .unwrap();

    let _episode2_id = db
        .store_episode(
            "Solve algebra problem".to_string(),
            vec!["simplify equation".to_string(), "solve for x".to_string()],
            vec!["found x = 5".to_string()],
            "Should show all steps".to_string(),
        )
        .unwrap();

    // Search with semantic query - should find calculus episode first
    let episodes = db
        .retrieve_similar_episodes("derivative calculation", 2)
        .unwrap();
    assert!(!episodes.is_empty());

    // With real embeddings, "derivative" should match calculus better than algebra
    println!(
        "Found episodes: {:?}",
        episodes.iter().map(|e| &e.task).collect::<Vec<_>>()
    );
}
|
||||
|
||||
/// Sanity-checks the local Candle transformer provider (ignored by default:
/// downloads the MiniLM model on first use): dimensionality, provider name,
/// unit normalization, and a basic semantic-similarity ordering.
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_candle_embedding_provider() {
    use ruvector_core::CandleEmbedding;

    let provider =
        CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap();

    // MiniLM-L6-v2 produces 384-dimensional sentence embeddings.
    assert_eq!(provider.dimensions(), 384);
    assert_eq!(provider.name(), "CandleEmbedding (transformer)");

    let embedding = provider.embed("hello world").unwrap();
    assert_eq!(embedding.len(), 384);

    // Check normalization
    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 1e-3, "Embedding should be normalized");

    // Test semantic similarity
    let emb_dog = provider.embed("dog").unwrap();
    let emb_cat = provider.embed("cat").unwrap();
    let emb_car = provider.embed("car").unwrap();

    // Cosine similarity: a plain dot product suffices because the
    // embeddings are unit-normalized (asserted above).
    let similarity_dog_cat: f32 = emb_dog.iter().zip(emb_cat.iter()).map(|(a, b)| a * b).sum();

    let similarity_dog_car: f32 = emb_dog.iter().zip(emb_car.iter()).map(|(a, b)| a * b).sum();

    // "dog" and "cat" should be more similar than "dog" and "car"
    assert!(
        similarity_dog_cat > similarity_dog_car,
        "Semantic embeddings should show dog-cat more similar than dog-car"
    );
}
|
||||
|
||||
/// AgenticDB workflow with the local Candle transformer provider (ignored
/// by default: downloads the MiniLM model). Creates two skills and runs a
/// semantic search; the expected ranking (file I/O matching a file query)
/// is printed rather than asserted.
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_agenticdb_with_candle_embeddings() {
    use ruvector_core::CandleEmbedding;

    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    // Must match MiniLM-L6-v2's 384-dimensional output.
    options.dimensions = 384;

    let provider = Arc::new(
        CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap(),
    );

    let db = AgenticDB::with_embedding_provider(options, provider).unwrap();

    assert_eq!(
        db.embedding_provider_name(),
        "CandleEmbedding (transformer)"
    );

    // Test with real semantic embeddings
    // NOTE(review): skill ids are bound but never asserted against the
    // search results below — the test only checks non-emptiness.
    let skill1_id = db
        .create_skill(
            "File I/O".to_string(),
            "Read and write files to disk".to_string(),
            std::collections::HashMap::new(),
            vec![
                "open()".to_string(),
                "read()".to_string(),
                "write()".to_string(),
            ],
        )
        .unwrap();

    let skill2_id = db
        .create_skill(
            "Network I/O".to_string(),
            "Send and receive data over network".to_string(),
            std::collections::HashMap::new(),
            vec![
                "connect()".to_string(),
                "send()".to_string(),
                "recv()".to_string(),
            ],
        )
        .unwrap();

    // Search with semantic query
    let skills = db.search_skills("reading files from storage", 2).unwrap();
    assert!(!skills.is_empty());

    // With real embeddings, file I/O should match better
    println!(
        "Found skills: {:?}",
        skills.iter().map(|s| &s.name).collect::<Vec<_>>()
    );
}
|
||||
495
crates/ruvector-core/tests/hnsw_integration_test.rs
Normal file
495
crates/ruvector-core/tests/hnsw_integration_test.rs
Normal file
@@ -0,0 +1,495 @@
|
||||
//! Comprehensive HNSW integration tests with different index sizes
|
||||
|
||||
use ruvector_core::index::hnsw::HnswIndex;
|
||||
use ruvector_core::index::VectorIndex;
|
||||
use ruvector_core::types::{DistanceMetric, HnswConfig};
|
||||
use ruvector_core::Result;
|
||||
|
||||
fn generate_random_vectors(count: usize, dimensions: usize, seed: u64) -> Vec<Vec<f32>> {
|
||||
use rand::{Rng, SeedableRng};
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
|
||||
(0..count)
|
||||
.map(|_| {
|
||||
(0..dimensions)
|
||||
.map(|_| rng.gen::<f32>() * 2.0 - 1.0)
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns `v` scaled to unit L2 length. A zero vector (norm 0) cannot be
/// normalized and is returned unchanged.
fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let squared_sum: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm <= 0.0 {
        return v.to_vec();
    }
    v.iter().map(|x| x / norm).collect()
}
|
||||
|
||||
/// Fraction of `ground_truth` ids that appear in `results` (recall@k).
///
/// Returns 1.0 for an empty ground-truth set (vacuously perfect recall)
/// instead of the NaN that `0 as f32 / 0 as f32` would produce — a NaN here
/// would silently poison every averaged recall value and make threshold
/// asserts fail with no obvious cause.
fn calculate_recall(ground_truth: &[String], results: &[String]) -> f32 {
    if ground_truth.is_empty() {
        return 1.0;
    }
    // Hash the ground truth once so each membership check is O(1).
    let gt_set: std::collections::HashSet<_> = ground_truth.iter().collect();
    let found = results.iter().filter(|id| gt_set.contains(id)).count();
    found as f32 / ground_truth.len() as f32
}
|
||||
|
||||
fn brute_force_search(
|
||||
query: &[f32],
|
||||
vectors: &[(String, Vec<f32>)],
|
||||
k: usize,
|
||||
metric: DistanceMetric,
|
||||
) -> Vec<String> {
|
||||
use ruvector_core::distance::distance;
|
||||
|
||||
let mut distances: Vec<_> = vectors
|
||||
.iter()
|
||||
.map(|(id, v)| {
|
||||
let dist = distance(query, v, metric).unwrap();
|
||||
(id.clone(), dist)
|
||||
})
|
||||
.collect();
|
||||
|
||||
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
distances.into_iter().take(k).map(|(id, _)| id).collect()
|
||||
}
|
||||
|
||||
/// Recall test on a small (100-vector) index: HNSW results are compared
/// against exact brute-force ground truth and must average >= 90% recall@10.
#[test]
fn test_hnsw_100_vectors() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 100;
    let k = 10;

    let config = HnswConfig {
        m: 16,
        ef_construction: 100,
        ef_search: 200,
        max_elements: 1000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    // Deterministic, unit-normalized dataset, inserted one vector at a time
    // (the single-insert path, as opposed to the batch path tested below).
    let vectors = generate_random_vectors(num_vectors, dimensions, 42);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    for (i, vector) in normalized_vectors.iter().enumerate() {
        index.add(format!("vec_{}", i), vector.clone())?;
    }

    assert_eq!(index.len(), num_vectors);

    // (id, vector) pairs for brute-force ground truth, built ONCE here.
    // The previous version rebuilt (and re-cloned) this whole list inside
    // every query iteration — pure loop-invariant work.
    let vectors_with_ids: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
        .collect();

    // Average recall@k over evenly spaced queries drawn from the indexed
    // vectors themselves.
    let num_queries = 10;
    let mut total_recall = 0.0;

    for i in 0..num_queries {
        let query_idx = i * (num_vectors / num_queries);
        let query = &normalized_vectors[query_idx];

        // Get HNSW results
        let results = index.search(query, k)?;
        let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

        // Get ground truth with brute force
        let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);

        total_recall += calculate_recall(&ground_truth, &result_ids);
    }

    let avg_recall = total_recall / num_queries as f32;
    println!(
        "100 vectors - Average recall@{}: {:.2}%",
        k,
        avg_recall * 100.0
    );

    // For small datasets, we expect very high recall
    assert!(
        avg_recall >= 0.90,
        "Recall should be at least 90% for 100 vectors"
    );

    Ok(())
}
|
||||
|
||||
/// Recall test on a 1K-vector index built through the batch-insert path:
/// must average >= 95% recall@10 against brute-force ground truth with
/// ef_search=200.
#[test]
fn test_hnsw_1k_vectors() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 1000;
    let k = 10;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 200,
        max_elements: 10000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    // Generate and insert vectors
    let vectors = generate_random_vectors(num_vectors, dimensions, 12345);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    // Use batch insert for better performance
    let entries: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    index.add_batch(entries)?;
    assert_eq!(index.len(), num_vectors);

    // (id, vector) pairs for brute-force ground truth, built ONCE. The
    // previous version rebuilt and re-cloned this list on every query
    // iteration — pure loop-invariant work.
    let vectors_with_ids: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
        .collect();

    // Test search accuracy
    let num_queries = 20;
    let mut total_recall = 0.0;

    for i in 0..num_queries {
        let query_idx = i * (num_vectors / num_queries);
        let query = &normalized_vectors[query_idx];

        let results = index.search(query, k)?;
        let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

        let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
        total_recall += calculate_recall(&ground_truth, &result_ids);
    }

    let avg_recall = total_recall / num_queries as f32;
    println!(
        "1K vectors - Average recall@{}: {:.2}%",
        k,
        avg_recall * 100.0
    );

    // Should achieve at least 95% recall with ef_search=200
    assert!(
        avg_recall >= 0.95,
        "Recall should be at least 95% for 1K vectors with ef_search=200"
    );

    Ok(())
}
|
||||
|
||||
/// Larger-scale recall test: 10K vectors inserted in 1K batches, then 20
/// sampled queries checked against full brute-force ground truth. The
/// recall floor is intentionally lower (70%) than the smaller tests since
/// approximate search degrades with index size at fixed ef_search.
#[test]
fn test_hnsw_10k_vectors() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 10000;
    let k = 10;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 200,
        max_elements: 100000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    println!("Generating {} vectors...", num_vectors);
    let vectors = generate_random_vectors(num_vectors, dimensions, 98765);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    println!("Inserting vectors in batches...");
    // Insert in batches for better performance
    let batch_size = 1000;
    for batch_start in (0..num_vectors).step_by(batch_size) {
        let batch_end = (batch_start + batch_size).min(num_vectors);
        // Ids are offset by batch_start so they stay globally unique.
        let entries: Vec<_> = normalized_vectors[batch_start..batch_end]
            .iter()
            .enumerate()
            .map(|(i, v)| (format!("vec_{}", batch_start + i), v.clone()))
            .collect();

        index.add_batch(entries)?;
    }

    assert_eq!(index.len(), num_vectors);
    println!("Index built with {} vectors", index.len());

    // Prepare all vectors for ground truth computation
    let all_vectors: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    // Test search accuracy with a sample of queries
    let num_queries = 20; // Reduced for faster testing
    let mut total_recall = 0.0;

    println!("Running {} queries...", num_queries);
    for i in 0..num_queries {
        // Evenly spaced sample across the dataset.
        let query_idx = i * (num_vectors / num_queries);
        let query = &normalized_vectors[query_idx];

        let results = index.search(query, k)?;
        let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

        // Compare against all vectors for accurate ground truth
        let ground_truth = brute_force_search(query, &all_vectors, k, DistanceMetric::Cosine);
        let recall = calculate_recall(&ground_truth, &result_ids);
        total_recall += recall;
    }

    let avg_recall = total_recall / num_queries as f32;
    println!(
        "10K vectors - Average recall@{}: {:.2}%",
        k,
        avg_recall * 100.0
    );

    // With ef_search=200 and m=32, we should achieve good recall
    assert!(
        avg_recall >= 0.70,
        "Recall should be at least 70% for 10K vectors, got {:.2}%",
        avg_recall * 100.0
    );

    Ok(())
}
|
||||
|
||||
/// Sweeps the `ef_search` parameter over the same 500-vector index to show
/// the recall/latency trade-off, then asserts that ef_search=200 reaches at
/// least 95% recall@10.
#[test]
fn test_hnsw_ef_search_tuning() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 500;
    let k = 10;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 50, // Start with lower ef_search
        max_elements: 10000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    let vectors = generate_random_vectors(num_vectors, dimensions, 54321);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    let entries: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    index.add_batch(entries)?;

    // Ground-truth (id, vector) pairs, built ONCE. The previous version
    // rebuilt this list inside every query of every ef sweep — pure
    // loop-invariant work.
    let vectors_with_ids: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
        .collect();

    // Average recall@k over 10 fixed queries at a given ef_search value.
    // Shared by the sweep below and the final assertion, which previously
    // duplicated this entire loop verbatim.
    let measure_recall = |ef: usize| -> Result<f32> {
        let num_queries = 10;
        let mut total_recall = 0.0;

        for i in 0..num_queries {
            // Every 50th vector serves as a query (i*50 < 500).
            let query_idx = i * 50;
            let query = &normalized_vectors[query_idx];

            let results = index.search_with_ef(query, k, ef)?;
            let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();

            let ground_truth =
                brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
            total_recall += calculate_recall(&ground_truth, &result_ids);
        }

        Ok(total_recall / num_queries as f32)
    };

    // Sweep: larger ef_search should trend toward higher recall.
    for ef in [50, 100, 200, 500] {
        let avg_recall = measure_recall(ef)?;
        println!(
            "ef_search={} - Average recall@{}: {:.2}%",
            ef,
            k,
            avg_recall * 100.0
        );
    }

    // Verify that ef_search=200 achieves at least 95% recall
    let avg_recall = measure_recall(200)?;
    assert!(
        avg_recall >= 0.95,
        "ef_search=200 should achieve at least 95% recall"
    );

    Ok(())
}
|
||||
|
||||
/// Round-trips a 500-vector index through `serialize`/`deserialize` and
/// checks the restored index answers queries.
///
/// NOTE(review): only the result *count* is compared between the original
/// and restored index, not the ids or scores — despite the comment below
/// claiming "identical". Tightening this to compare ids would strengthen
/// the test; confirm deserialization is expected to be exact first.
#[test]
fn test_hnsw_serialization_large() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 500;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 100,
        max_elements: 10000,
    };

    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    let vectors = generate_random_vectors(num_vectors, dimensions, 11111);
    let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();

    let entries: Vec<_> = normalized_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    index.add_batch(entries)?;

    // Serialize
    println!("Serializing index with {} vectors...", num_vectors);
    let bytes = index.serialize()?;
    println!(
        "Serialized size: {} bytes ({:.2} KB)",
        bytes.len(),
        bytes.len() as f32 / 1024.0
    );

    // Deserialize
    println!("Deserializing index...");
    let restored_index = HnswIndex::deserialize(&bytes)?;

    assert_eq!(restored_index.len(), num_vectors);

    // Test that search works on restored index
    let query = &normalized_vectors[0];
    let original_results = index.search(query, 10)?;
    let restored_results = restored_index.search(query, 10)?;

    // Results should be identical
    assert_eq!(original_results.len(), restored_results.len());

    println!("Serialization test passed!");

    Ok(())
}
|
||||
|
||||
/// HNSW search should return results under each supported distance metric.
///
/// DotProduct is deliberately skipped: on normalized vectors it can yield
/// negative distances, which causes issues with the underlying hnsw_rs
/// library. Cosine and Euclidean are the most commonly used metrics.
#[test]
fn test_hnsw_different_metrics() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 200;
    let k = 5;

    for metric in [DistanceMetric::Cosine, DistanceMetric::Euclidean] {
        println!("Testing metric: {:?}", metric);

        let config = HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 100,
            max_elements: 1000,
        };
        let mut index = HnswIndex::new(dimensions, metric, config)?;

        // Deterministic unit-length dataset, inserted one vector at a time.
        let raw = generate_random_vectors(num_vectors, dimensions, 99999);
        let unit_vectors: Vec<_> = raw.iter().map(|v| normalize_vector(v)).collect();
        for (i, vector) in unit_vectors.iter().enumerate() {
            index.add(format!("vec_{}", i), vector.clone())?;
        }

        // Querying with an indexed vector must return at least one hit.
        let results = index.search(&unit_vectors[0], k)?;
        assert!(!results.is_empty());
        println!(" Found {} results for metric {:?}", results.len(), metric);
    }

    Ok(())
}
|
||||
|
||||
/// Bulk-load smoke test: batch-inserts 2000 vectors, reports the insertion
/// throughput, and confirms the index still answers queries afterwards.
#[test]
fn test_hnsw_parallel_batch_insert() -> Result<()> {
    let dimensions = 128;
    let num_vectors = 2000;

    let config = HnswConfig {
        m: 32,
        ef_construction: 200,
        ef_search: 100,
        max_elements: 10000,
    };
    let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;

    let raw = generate_random_vectors(num_vectors, dimensions, 77777);
    let unit_vectors: Vec<_> = raw.iter().map(|v| normalize_vector(v)).collect();

    let entries: Vec<_> = unit_vectors
        .iter()
        .enumerate()
        .map(|(i, v)| (format!("vec_{}", i), v.clone()))
        .collect();

    // Time only the insertion itself, not the data generation above.
    let started = std::time::Instant::now();
    index.add_batch(entries)?;
    let duration = started.elapsed();

    println!("Batch inserted {} vectors in {:?}", num_vectors, duration);
    println!(
        "Throughput: {:.0} vectors/sec",
        num_vectors as f64 / duration.as_secs_f64()
    );

    assert_eq!(index.len(), num_vectors);

    // The freshly built index must still serve searches.
    let probe = &unit_vectors[0];
    let results = index.search(probe, 10)?;
    assert!(!results.is_empty());

    Ok(())
}
|
||||
453
crates/ruvector-core/tests/integration_tests.rs
Normal file
453
crates/ruvector-core/tests/integration_tests.rs
Normal file
@@ -0,0 +1,453 @@
|
||||
//! Integration tests for end-to-end workflows
|
||||
//!
|
||||
//! These tests verify that all components work together correctly.
|
||||
|
||||
use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, SearchQuery};
|
||||
use ruvector_core::{VectorDB, VectorEntry};
|
||||
use std::collections::HashMap;
|
||||
use tempfile::tempdir;
|
||||
|
||||
// ============================================================================
|
||||
// End-to-End Workflow Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Full insert-then-search workflow on an HNSW-backed `VectorDB`: 100
/// vectors with JSON metadata go in via batch insert, and a k=10 search
/// must return hits that carry both vector payload and metadata.
#[test]
fn test_complete_insert_search_workflow() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 128;
    options.distance_metric = DistanceMetric::Cosine;
    options.hnsw_config = Some(HnswConfig {
        m: 16,
        ef_construction: 100,
        ef_search: 50,
        max_elements: 100_000,
    });

    let db = VectorDB::new(options).unwrap();

    // Insert training data
    let vectors: Vec<VectorEntry> = (0..100)
        .map(|i| {
            let mut metadata = HashMap::new();
            metadata.insert("index".to_string(), serde_json::json!(i));

            // Deterministic ramp vector: component j of entry i is (i+j)*0.01.
            VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: (0..128).map(|j| ((i + j) as f32) * 0.01).collect(),
                metadata: Some(metadata),
            }
        })
        .collect();

    let ids = db.insert_batch(vectors).unwrap();
    assert_eq!(ids.len(), 100);

    // Search for similar vectors
    // The query equals entry 0's vector (the i=0 ramp), so matches exist.
    let query: Vec<f32> = (0..128).map(|j| (j as f32) * 0.01).collect();
    let results = db
        .search(SearchQuery {
            vector: query,
            k: 10,
            filter: None,
            ef_search: Some(100), // override the index's default ef_search=50
        })
        .unwrap();

    assert_eq!(results.len(), 10);
    // Hits must carry the stored vector and metadata back to the caller.
    assert!(results[0].vector.is_some());
    assert!(results[0].metadata.is_some());
}
|
||||
|
||||
/// Stress-style batch test: 10K 384-dimensional vectors batch-inserted into
/// a Euclidean HNSW index, followed by 10 searches that must each return a
/// full top-10. Insert timing is printed but not asserted.
#[test]
fn test_batch_operations_10k_vectors() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 384;
    options.distance_metric = DistanceMetric::Euclidean;
    options.hnsw_config = Some(HnswConfig::default());

    let db = VectorDB::new(options).unwrap();

    // Generate 10K vectors
    println!("Generating 10K vectors...");
    let vectors: Vec<VectorEntry> = (0..10_000)
        .map(|i| VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..384).map(|j| ((i + j) as f32) * 0.001).collect(),
            metadata: None,
        })
        .collect();

    // Batch insert
    println!("Batch inserting 10K vectors...");
    let start = std::time::Instant::now();
    let ids = db.insert_batch(vectors).unwrap();
    let duration = start.elapsed();
    println!("Batch insert took: {:?}", duration);

    // Both the returned id list and the db's own count must agree.
    assert_eq!(ids.len(), 10_000);
    assert_eq!(db.len().unwrap(), 10_000);

    // Perform multiple searches
    println!("Performing searches...");
    for i in 0..10 {
        // Each query matches the ramp of entry i*100, so neighbours exist.
        let query: Vec<f32> = (0..384).map(|j| ((i * 100 + j) as f32) * 0.001).collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: None,
            })
            .unwrap();

        assert_eq!(results.len(), 10);
    }
}
|
||||
|
||||
/// Verifies on-disk persistence: a database is populated and dropped (the
/// first scope closes it), then reopened at the same path and checked for
/// the same count and exact vector contents.
#[test]
fn test_persistence_and_reload() {
    let dir = tempdir().unwrap();
    let db_path = dir
        .path()
        .join("persistent.db")
        .to_string_lossy()
        .to_string();

    // Create and populate database
    // The block scope ensures `db` is dropped (and flushed) before reopening.
    {
        let mut options = DbOptions::default();
        options.storage_path = db_path.clone();
        options.dimensions = 3;
        options.hnsw_config = None; // Use flat index for simpler persistence test

        let db = VectorDB::new(options).unwrap();

        for i in 0..10 {
            // vec_i stores [i, 2i, 3i] so reload values are easy to predict.
            db.insert(VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: vec![i as f32, (i * 2) as f32, (i * 3) as f32],
                metadata: None,
            })
            .unwrap();
        }

        assert_eq!(db.len().unwrap(), 10);
    }

    // Reload database
    {
        let mut options = DbOptions::default();
        options.storage_path = db_path.clone();
        options.dimensions = 3;
        options.hnsw_config = None;

        let db = VectorDB::new(options).unwrap();

        // Verify data persisted
        assert_eq!(db.len().unwrap(), 10);

        // Spot-check one entry's exact contents: vec_5 = [5, 10, 15].
        let entry = db.get("vec_5").unwrap().unwrap();
        assert_eq!(entry.vector, vec![5.0, 10.0, 15.0]);
    }
}
|
||||
|
||||
/// Interleaved workflow: batch insert 50, delete 10, insert 10 more one at
/// a time, then search. The length assertions depend on this exact order of
/// operations (50 -> 40 -> 50).
#[test]
fn test_mixed_operations_workflow() {
    let dir = tempdir().unwrap();
    let mut options = DbOptions::default();
    options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    options.dimensions = 64;

    let db = VectorDB::new(options).unwrap();

    // Insert initial batch
    let initial: Vec<VectorEntry> = (0..50)
        .map(|i| VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .collect();

    db.insert_batch(initial).unwrap();
    assert_eq!(db.len().unwrap(), 50);

    // Delete some vectors
    for i in 0..10 {
        db.delete(&format!("vec_{}", i)).unwrap();
    }
    assert_eq!(db.len().unwrap(), 40);

    // Insert more individual vectors
    // Ids 50..60 don't collide with anything previously inserted.
    for i in 50..60 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .unwrap();
    }
    assert_eq!(db.len().unwrap(), 50);

    // Search
    // Only non-emptiness is asserted: deleted ids must not make search fail.
    let query: Vec<f32> = (0..64).map(|j| (j as f32) * 0.1).collect();
    let results = db
        .search(SearchQuery {
            vector: query,
            k: 20,
            filter: None,
            ef_search: None,
        })
        .unwrap();

    assert!(results.len() > 0);
}
|
||||
|
||||
// ============================================================================
|
||||
// Different Distance Metrics
|
||||
// ============================================================================
|
||||
|
||||
/// Every supported distance metric must yield exactly k results, sorted by
/// ascending score (lower = closer), on a flat (non-HNSW) index.
#[test]
fn test_all_distance_metrics() {
    let metrics = vec![
        DistanceMetric::Euclidean,
        DistanceMetric::Cosine,
        DistanceMetric::DotProduct,
        DistanceMetric::Manhattan,
    ];

    for metric in metrics {
        // Fresh database per metric, flat index (hnsw_config = None).
        let dir = tempdir().unwrap();
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 32;
        options.distance_metric = metric;
        options.hnsw_config = None;

        let db = VectorDB::new(options).unwrap();

        // Twenty deterministic ramp vectors, shifted per index.
        for i in 0..20 {
            let entry = VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
                metadata: None,
            };
            db.insert(entry).unwrap();
        }

        let query: Vec<f32> = (0..32).map(|j| (j as f32) * 0.1).collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 5,
                filter: None,
                ef_search: None,
            })
            .unwrap();

        assert_eq!(results.len(), 5, "Failed for metric {:?}", metric);

        // Scores must be non-decreasing: lower is better for distances.
        for pair in results.windows(2) {
            assert!(
                pair[0].score <= pair[1].score,
                "Results not sorted for metric {:?}: {} > {}",
                metric,
                pair[0].score,
                pair[1].score
            );
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// HNSW Configuration Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_hnsw_different_configurations() {
    // Each HNSW parameter set should index 100 vectors and answer a top-10 query.
    let configs = vec![
        HnswConfig {
            m: 8,
            ef_construction: 50,
            ef_search: 50,
            max_elements: 1000,
        },
        HnswConfig {
            m: 16,
            ef_construction: 100,
            ef_search: 100,
            max_elements: 1000,
        },
        HnswConfig {
            m: 32,
            ef_construction: 200,
            ef_search: 200,
            max_elements: 1000,
        },
    ];

    for config in configs {
        let tmp = tempdir().unwrap();
        let mut opts = DbOptions::default();
        opts.storage_path = tmp.path().join("test.db").to_string_lossy().to_string();
        opts.dimensions = 64;
        opts.hnsw_config = Some(config.clone());

        let db = VectorDB::new(opts).unwrap();

        // 100 deterministic vectors, inserted in one batch.
        let batch: Vec<VectorEntry> = (0..100)
            .map(|i| VectorEntry {
                id: Some(format!("vec_{}", i)),
                vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
                metadata: None,
            })
            .collect();
        db.insert_batch(batch).unwrap();

        // Query using this configuration's own ef_search value.
        let query: Vec<f32> = (0..64).map(|j| (j as f32) * 0.01).collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: Some(config.ef_search),
            })
            .unwrap();

        assert_eq!(results.len(), 10, "Failed for config M={}", config.m);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Metadata Filtering Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_complex_metadata_filtering() {
    // Verify that metadata filters restrict search results to matching entries.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    opts.dimensions = 16;
    opts.hnsw_config = None;

    let db = VectorDB::new(opts).unwrap();

    // 50 vectors tagged with a cyclic "category" (i % 3) and a banded "value" (i / 10).
    for i in 0..50 {
        let metadata = HashMap::from([
            ("category".to_string(), serde_json::json!(i % 3)),
            ("value".to_string(), serde_json::json!(i / 10)),
        ]);

        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..16).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: Some(metadata),
        })
        .unwrap();
    }

    let query: Vec<f32> = (0..16).map(|j| (j as f32) * 0.1).collect();

    // Filter on category == 0.
    let category_filter = HashMap::from([("category".to_string(), serde_json::json!(0))]);
    let by_category = db
        .search(SearchQuery {
            vector: query.clone(),
            k: 100,
            filter: Some(category_filter),
            ef_search: None,
        })
        .unwrap();

    // Only vectors with i % 3 == 0 may appear.
    for hit in &by_category {
        let meta = hit.metadata.as_ref().unwrap();
        assert_eq!(meta.get("category").unwrap(), &serde_json::json!(0));
    }

    // Filter on value == 2 (i.e. i in 20..30).
    let value_filter = HashMap::from([("value".to_string(), serde_json::json!(2))]);
    let by_value = db
        .search(SearchQuery {
            vector: query,
            k: 100,
            filter: Some(value_filter),
            ef_search: None,
        })
        .unwrap();

    for hit in &by_value {
        let meta = hit.metadata.as_ref().unwrap();
        assert_eq!(meta.get("value").unwrap(), &serde_json::json!(2));
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handling Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_dimension_validation() {
    // A vector whose length disagrees with the configured dimensionality
    // must be rejected at insert time.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    opts.dimensions = 64;

    let db = VectorDB::new(opts).unwrap();

    let outcome = db.insert(VectorEntry {
        id: None,
        vector: vec![1.0, 2.0, 3.0], // Only 3 dimensions, should be 64
        metadata: None,
    });

    assert!(outcome.is_err());
}
|
||||
|
||||
#[test]
fn test_search_with_wrong_dimension() {
    // A query of the wrong dimensionality must not panic; erroring or
    // returning empty results are both acceptable outcomes.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    opts.dimensions = 64;
    opts.hnsw_config = None;

    let db = VectorDB::new(opts).unwrap();

    // Seed one well-formed vector.
    db.insert(VectorEntry {
        id: Some("v1".to_string()),
        vector: (0..64).map(|i| i as f32).collect(),
        metadata: None,
    })
    .unwrap();

    // 3-element query against a 64-dimensional store.
    let outcome = db.search(SearchQuery {
        vector: vec![1.0, 2.0, 3.0],
        k: 10,
        filter: None,
        ef_search: None,
    });

    // Whatever the implementation chooses, reaching this point without a
    // panic is the property under test.
    let _ = outcome;
}
|
||||
345
crates/ruvector-core/tests/property_tests.rs
Normal file
345
crates/ruvector-core/tests/property_tests.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
//! Property-based tests using proptest
|
||||
//!
|
||||
//! These tests verify mathematical properties and invariants that should hold
|
||||
//! for all inputs within a given domain.
|
||||
|
||||
use proptest::prelude::*;
|
||||
use ruvector_core::distance::*;
|
||||
use ruvector_core::quantization::*;
|
||||
use ruvector_core::types::DistanceMetric;
|
||||
|
||||
// ============================================================================
|
||||
// Distance Metric Properties
|
||||
// ============================================================================
|
||||
|
||||
// Strategy to generate valid vectors with bounded values to prevent overflow
|
||||
// Using range that won't overflow when squared: sqrt(f32::MAX) ≈ 1.84e19
|
||||
// We use a more conservative range for numerical stability in distance calculations
|
||||
fn vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
|
||||
prop::collection::vec(-1000.0f32..1000.0f32, dim)
|
||||
}
|
||||
|
||||
// Strategy for normalized vectors (for cosine similarity)
|
||||
fn normalized_vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
|
||||
vector_strategy(dim).prop_map(move |v| {
|
||||
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
v.iter().map(|x| x / norm).collect()
|
||||
} else {
|
||||
vec![1.0 / (dim as f32).sqrt(); dim]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
proptest! {
    // d(v, v) must be (numerically) zero.
    #[test]
    fn test_euclidean_self_distance_zero(v in vector_strategy(128)) {
        let d = euclidean_distance(&v, &v);
        prop_assert!(d < 0.001, "Distance to self should be ~0, got {}", d);
    }

    // d(a, b) == d(b, a).
    #[test]
    fn test_euclidean_symmetry(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let forward = euclidean_distance(&a, &b);
        let backward = euclidean_distance(&b, &a);
        prop_assert!((forward - backward).abs() < 0.001, "Distance should be symmetric");
    }

    // d(a, c) <= d(a, b) + d(b, c).
    #[test]
    fn test_euclidean_triangle_inequality(
        a in vector_strategy(32),
        b in vector_strategy(32),
        c in vector_strategy(32)
    ) {
        let ab = euclidean_distance(&a, &b);
        let bc = euclidean_distance(&b, &c);
        let ac = euclidean_distance(&a, &c);

        // Small epsilon tolerates floating-point accumulation.
        prop_assert!(
            ac <= ab + bc + 0.01,
            "Triangle inequality violated: {} > {} + {}",
            ac, ab, bc
        );
    }

    // Distances are never negative.
    #[test]
    fn test_euclidean_non_negative(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let d = euclidean_distance(&a, &b);
        prop_assert!(d >= 0.0, "Distance must be non-negative, got {}", d);
    }

    // Cosine distance is symmetric on normalized inputs.
    #[test]
    fn test_cosine_symmetry(
        a in normalized_vector_strategy(64),
        b in normalized_vector_strategy(64)
    ) {
        let forward = cosine_distance(&a, &b);
        let backward = cosine_distance(&b, &a);
        prop_assert!((forward - backward).abs() < 0.01, "Cosine distance should be symmetric");
    }

    // Cosine distance of a vector to itself is (numerically) zero.
    #[test]
    fn test_cosine_self_distance(v in normalized_vector_strategy(64)) {
        let d = cosine_distance(&v, &v);
        prop_assert!(d < 0.01, "Cosine distance to self should be ~0, got {}", d);
    }

    // Manhattan distance is symmetric.
    #[test]
    fn test_manhattan_symmetry(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let forward = manhattan_distance(&a, &b);
        let backward = manhattan_distance(&b, &a);
        prop_assert!((forward - backward).abs() < 0.001);
    }

    // Manhattan distance is never negative.
    #[test]
    fn test_manhattan_non_negative(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let d = manhattan_distance(&a, &b);
        prop_assert!(d >= 0.0, "Manhattan distance must be non-negative");
    }

    // Dot-product distance is symmetric.
    #[test]
    fn test_dot_product_symmetry(
        a in vector_strategy(64),
        b in vector_strategy(64)
    ) {
        let forward = dot_product_distance(&a, &b);
        let backward = dot_product_distance(&b, &a);
        prop_assert!((forward - backward).abs() < 0.01);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Quantization Round-Trip Properties
|
||||
// ============================================================================
|
||||
|
||||
proptest! {
    // Quantize → reconstruct should stay within a bounded error of the input.
    #[test]
    fn test_scalar_quantization_roundtrip(
        v in prop::collection::vec(0.0f32..100.0f32, 64)
    ) {
        let encoded = ScalarQuantized::quantize(&v);
        let decoded = encoded.reconstruct();

        prop_assert_eq!(v.len(), decoded.len());

        // Accept either a small relative error or a small absolute error
        // (the latter covers values near zero).
        for (orig, recon) in v.iter().zip(decoded.iter()) {
            let abs_err = (orig - recon).abs();
            let rel_err = if *orig != 0.0 { abs_err / orig.abs() } else { abs_err };
            prop_assert!(
                rel_err < 0.5 || abs_err < 1.0,
                "Reconstruction error too large: {} vs {}",
                orig, recon
            );
        }
    }

    // Binary quantization keeps each component's sign (zeros excluded).
    #[test]
    fn test_binary_quantization_sign_preservation(
        v in prop::collection::vec(-10.0f32..10.0f32, 64)
            .prop_filter("No zeros", |v| v.iter().all(|x| *x != 0.0))
    ) {
        let encoded = BinaryQuantized::quantize(&v);
        let decoded = encoded.reconstruct();

        for (orig, recon) in v.iter().zip(decoded.iter()) {
            prop_assert_eq!(
                orig.signum(),
                *recon,
                "Sign not preserved for {} -> {}",
                orig, recon
            );
        }
    }

    // Hamming distance of a code to itself is zero.
    #[test]
    fn test_binary_quantization_self_distance(
        v in prop::collection::vec(-10.0f32..10.0f32, 64)
    ) {
        let encoded = BinaryQuantized::quantize(&v);
        let d = encoded.distance(&encoded);
        prop_assert_eq!(d, 0.0, "Distance to self should be 0");
    }

    // Hamming distance between codes is symmetric.
    #[test]
    fn test_binary_quantization_symmetry(
        a in prop::collection::vec(-10.0f32..10.0f32, 64),
        b in prop::collection::vec(-10.0f32..10.0f32, 64)
    ) {
        let code_a = BinaryQuantized::quantize(&a);
        let code_b = BinaryQuantized::quantize(&b);

        let forward = code_a.distance(&code_b);
        let backward = code_b.distance(&code_a);

        prop_assert_eq!(forward, backward, "Distance should be symmetric");
    }

    // Hamming distance over 64 bits must lie in [0, 64].
    #[test]
    fn test_binary_quantization_distance_bounded(
        a in prop::collection::vec(-10.0f32..10.0f32, 64),
        b in prop::collection::vec(-10.0f32..10.0f32, 64)
    ) {
        let code_a = BinaryQuantized::quantize(&a);
        let code_b = BinaryQuantized::quantize(&b);

        let d = code_a.distance(&code_b);

        prop_assert!(d >= 0.0 && d <= 64.0, "Distance {} out of bounds", d);
    }

    // Scalar-quantized distance is never negative.
    #[test]
    fn test_scalar_quantization_distance_non_negative(
        a in prop::collection::vec(0.0f32..100.0f32, 64),
        b in prop::collection::vec(0.0f32..100.0f32, 64)
    ) {
        let code_a = ScalarQuantized::quantize(&a);
        let code_b = ScalarQuantized::quantize(&b);

        let d = code_a.distance(&code_b);
        prop_assert!(d >= 0.0, "Distance must be non-negative");
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Vector Operation Properties
|
||||
// ============================================================================
|
||||
|
||||
proptest! {
    // Cosine distance ignores positive scaling of either argument.
    #[test]
    fn test_cosine_scale_invariance(
        v in normalized_vector_strategy(32),
        scale in 0.1f32..10.0f32
    ) {
        let stretched: Vec<f32> = v.iter().map(|x| x * scale).collect();
        let base = cosine_distance(&v, &v);
        let against_scaled = cosine_distance(&v, &stretched);

        prop_assert!(
            (base - against_scaled).abs() < 0.1,
            "Cosine distance should be scale invariant: {} vs {}",
            base, against_scaled
        );
    }

    // Shifting both points by the same offset leaves Euclidean distance unchanged.
    #[test]
    fn test_euclidean_translation_invariance(
        a in vector_strategy(32),
        b in vector_strategy(32),
        offset in vector_strategy(32)
    ) {
        let shifted_a: Vec<f32> = a.iter().zip(&offset).map(|(x, o)| x + o).collect();
        let shifted_b: Vec<f32> = b.iter().zip(&offset).map(|(x, o)| x + o).collect();

        let before = euclidean_distance(&a, &b);
        let after = euclidean_distance(&shifted_a, &shifted_b);

        prop_assert!(
            (before - after).abs() < 0.01,
            "Euclidean distance should be translation invariant"
        );
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Batch Operations Properties
|
||||
// ============================================================================
|
||||
|
||||
proptest! {
    // batch_distances must agree with one-at-a-time euclidean_distance calls.
    #[test]
    fn test_batch_distances_consistency(
        query in vector_strategy(32),
        vectors in prop::collection::vec(vector_strategy(32), 10..20)
    ) {
        let batched = batch_distances(&query, &vectors, DistanceMetric::Euclidean).unwrap();

        let one_by_one: Vec<f32> = vectors.iter()
            .map(|v| euclidean_distance(&query, v))
            .collect();

        prop_assert_eq!(batched.len(), one_by_one.len());

        for (batch, individual) in batched.iter().zip(one_by_one.iter()) {
            prop_assert!(
                (batch - individual).abs() < 0.01,
                "Batch and individual distances should match: {} vs {}",
                batch, individual
            );
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Dimension Handling Properties
|
||||
// ============================================================================
|
||||
|
||||
proptest! {
    // Property: Distance calculation fails on dimension mismatch.
    #[test]
    fn test_dimension_mismatch_error(
        dim1 in 1usize..100,
        dim2 in 1usize..100
    ) {
        prop_assume!(dim1 != dim2); // Only test when dimensions differ

        let a = vec![1.0f32; dim1];
        let b = vec![1.0f32; dim2];

        let result = distance(&a, &b, DistanceMetric::Euclidean);
        prop_assert!(result.is_err(), "Should error on dimension mismatch");
    }

    // Property: Distance calculation succeeds on matching dimensions.
    //
    // FIX: the original also generated two random vectors `a` and `b` and then
    // ignored them entirely, building fixed `a_resized`/`b_resized` instead.
    // The unused strategies only slowed generation/shrinking and misled the
    // reader; the property depends on `dim` alone.
    #[test]
    fn test_dimension_match_success(dim in 1usize..200) {
        let a = vec![1.0f32; dim];
        let b = vec![1.0f32; dim];

        let result = distance(&a, &b, DistanceMetric::Euclidean);
        prop_assert!(result.is_ok(), "Should succeed on matching dimensions");
    }
}
|
||||
486
crates/ruvector-core/tests/stress_tests.rs
Normal file
486
crates/ruvector-core/tests/stress_tests.rs
Normal file
@@ -0,0 +1,486 @@
|
||||
//! Stress tests for scalability, concurrency, and resilience
|
||||
//!
|
||||
//! These tests push the system to its limits to verify robustness.
|
||||
|
||||
use ruvector_core::types::{DbOptions, HnswConfig, SearchQuery};
|
||||
use ruvector_core::{VectorDB, VectorEntry};
|
||||
use std::sync::{Arc, Barrier};
|
||||
use std::thread;
|
||||
use tempfile::tempdir;
|
||||
|
||||
// ============================================================================
|
||||
// Large-Scale Insertion Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
#[ignore] // Run with: cargo test --test stress_tests -- --ignored --test-threads=1
fn test_million_vector_insertion() {
    // Inserts one million 128-dimensional vectors in 10k-sized batches,
    // then spot-checks that search still answers with full result sets.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir.path().join("million.db").to_string_lossy().to_string();
    opts.dimensions = 128;
    opts.hnsw_config = Some(HnswConfig {
        m: 16,
        ef_construction: 100,
        ef_search: 50,
        max_elements: 2_000_000,
    });

    let db = VectorDB::new(opts).unwrap();

    println!("Starting million-vector insertion test...");
    let batch_size = 10_000;
    let num_batches = 100; // Total: 1M vectors

    for batch_idx in 0..num_batches {
        println!("Inserting batch {}/{}...", batch_idx + 1, num_batches);

        let batch: Vec<VectorEntry> = (0..batch_size)
            .map(|i| {
                let global_idx = batch_idx * batch_size + i;
                VectorEntry {
                    id: Some(format!("vec_{}", global_idx)),
                    vector: (0..128)
                        .map(|j| ((global_idx + j) as f32) * 0.0001)
                        .collect(),
                    metadata: None,
                }
            })
            .collect();

        let started = std::time::Instant::now();
        db.insert_batch(batch).unwrap();
        println!("Batch {} took: {:?}", batch_idx + 1, started.elapsed());
    }

    println!("Final database size: {}", db.len().unwrap());
    assert_eq!(db.len().unwrap(), 1_000_000);

    // A handful of searches to confirm the index is functional at scale.
    println!("Testing search on 1M vectors...");
    for i in 0..10 {
        let query: Vec<f32> = (0..128)
            .map(|j| ((i * 10000 + j) as f32) * 0.0001)
            .collect();
        let started = std::time::Instant::now();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: Some(50),
            })
            .unwrap();
        let took = started.elapsed();

        println!(
            "Search {} took: {:?}, found {} results",
            i + 1,
            took,
            results.len()
        );
        assert_eq!(results.len(), 10);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Concurrent Query Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_concurrent_queries() {
    // Ten threads each run 100 searches against a shared database handle.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir
        .path()
        .join("concurrent.db")
        .to_string_lossy()
        .to_string();
    opts.dimensions = 64;
    opts.hnsw_config = Some(HnswConfig::default());

    let db = Arc::new(VectorDB::new(opts).unwrap());

    // Seed the database with 1000 vectors.
    println!("Inserting test data...");
    let seed: Vec<VectorEntry> = (0..1000)
        .map(|i| VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
            metadata: None,
        })
        .collect();
    db.insert_batch(seed).unwrap();

    println!("Starting 10 concurrent query threads...");
    let num_threads = 10;
    let queries_per_thread = 100;

    let barrier = Arc::new(Barrier::new(num_threads));

    let handles: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let db = Arc::clone(&db);
            let barrier = Arc::clone(&barrier);

            thread::spawn(move || {
                // Line all threads up before the clock starts.
                barrier.wait();
                let started = std::time::Instant::now();

                for i in 0..queries_per_thread {
                    let query: Vec<f32> = (0..64)
                        .map(|j| ((thread_id * 1000 + i + j) as f32) * 0.01)
                        .collect();

                    let results = db
                        .search(SearchQuery {
                            vector: query,
                            k: 10,
                            filter: None,
                            ef_search: None,
                        })
                        .unwrap();

                    assert_eq!(results.len(), 10);
                }

                let took = started.elapsed();
                println!(
                    "Thread {} completed {} queries in {:?}",
                    thread_id, queries_per_thread, took
                );
                took
            })
        })
        .collect();

    // Join every thread and accumulate per-thread timings.
    let mut total_duration = std::time::Duration::ZERO;
    for handle in handles {
        total_duration += handle.join().unwrap();
    }

    let total_queries = num_threads * queries_per_thread;
    println!("Total queries: {}", total_queries);
    println!(
        "Average duration per thread: {:?}",
        total_duration / num_threads as u32
    );
}
|
||||
|
||||
#[test]
fn test_concurrent_inserts_and_queries() {
    // Five reader threads search while two writer threads insert concurrently.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir
        .path()
        .join("mixed_concurrent.db")
        .to_string_lossy()
        .to_string();
    opts.dimensions = 32;
    opts.hnsw_config = Some(HnswConfig::default());

    let db = Arc::new(VectorDB::new(opts).unwrap());

    // Seed with 100 vectors so readers always have something to find.
    let seed: Vec<VectorEntry> = (0..100)
        .map(|i| VectorEntry {
            id: Some(format!("initial_{}", i)),
            vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
            metadata: None,
        })
        .collect();
    db.insert_batch(seed).unwrap();

    let num_readers = 5;
    let num_writers = 2;
    let barrier = Arc::new(Barrier::new(num_readers + num_writers));
    let mut handles = Vec::new();

    // Reader threads: 50 searches each.
    for reader_id in 0..num_readers {
        let db = Arc::clone(&db);
        let barrier = Arc::clone(&barrier);

        handles.push(thread::spawn(move || {
            barrier.wait();

            for i in 0..50 {
                let query: Vec<f32> = (0..32)
                    .map(|j| ((reader_id * 100 + i + j) as f32) * 0.1)
                    .collect();
                let results = db
                    .search(SearchQuery {
                        vector: query,
                        k: 5,
                        filter: None,
                        ef_search: None,
                    })
                    .unwrap();

                assert!(!results.is_empty() && results.len() <= 5);
            }

            println!("Reader {} completed", reader_id);
        }));
    }

    // Writer threads: 20 inserts each.
    for writer_id in 0..num_writers {
        let db = Arc::clone(&db);
        let barrier = Arc::clone(&barrier);

        handles.push(thread::spawn(move || {
            barrier.wait();

            for i in 0..20 {
                db.insert(VectorEntry {
                    id: Some(format!("writer_{}_{}", writer_id, i)),
                    vector: (0..32)
                        .map(|j| ((writer_id * 1000 + i + j) as f32) * 0.1)
                        .collect(),
                    metadata: None,
                })
                .unwrap();
            }

            println!("Writer {} completed", writer_id);
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // Verify the database survived the mixed load.
    let final_len = db.len().unwrap();
    println!("Final database size: {}", final_len);
    assert!(final_len >= 100); // At least initial data should remain
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Pressure Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
#[ignore] // Run with: cargo test --test stress_tests -- --ignored
fn test_memory_pressure_large_vectors() {
    // 10k vectors of 2048 dimensions — exercises memory handling with
    // unusually large per-vector payloads.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir
        .path()
        .join("large_vectors.db")
        .to_string_lossy()
        .to_string();
    opts.dimensions = 2048; // Very large vectors
    opts.hnsw_config = Some(HnswConfig {
        m: 8,
        ef_construction: 50,
        ef_search: 50,
        max_elements: 100_000,
    });

    let db = VectorDB::new(opts).unwrap();

    println!("Testing with large 2048-dimensional vectors...");
    let num_vectors = 10_000;
    let batch_size = 1000;
    let total_batches = num_vectors / batch_size;

    for batch_idx in 0..total_batches {
        let batch: Vec<VectorEntry> = (0..batch_size)
            .map(|i| {
                let global_idx = batch_idx * batch_size + i;
                VectorEntry {
                    id: Some(format!("vec_{}", global_idx)),
                    vector: (0..2048)
                        .map(|j| ((global_idx + j) as f32) * 0.0001)
                        .collect(),
                    metadata: None,
                }
            })
            .collect();

        db.insert_batch(batch).unwrap();
        println!(
            "Inserted batch {}/{}",
            batch_idx + 1,
            total_batches
        );
    }

    println!("Database size: {}", db.len().unwrap());
    assert_eq!(db.len().unwrap(), num_vectors);

    // Searches should still return full result sets.
    for i in 0..5 {
        let query: Vec<f32> = (0..2048)
            .map(|j| ((i * 1000 + j) as f32) * 0.0001)
            .collect();
        let results = db
            .search(SearchQuery {
                vector: query,
                k: 10,
                filter: None,
                ef_search: None,
            })
            .unwrap();

        assert_eq!(results.len(), 10);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Recovery Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_invalid_operations_dont_crash() {
    // A battery of invalid or degenerate calls; none may panic and the
    // database must remain usable afterwards.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    opts.dimensions = 32;

    let db = VectorDB::new(opts).unwrap();

    // Deleting and fetching IDs that were never inserted.
    let _ = db.delete("nonexistent");
    let _ = db.get("nonexistent");

    // k = 0 should either yield nothing or error — never panic.
    let _ = db.search(SearchQuery {
        vector: vec![0.0; 32],
        k: 0,
        filter: None,
        ef_search: None,
    });

    // Rapid insert-then-delete churn.
    for i in 0..100 {
        let id = db
            .insert(VectorEntry {
                id: Some(format!("temp_{}", i)),
                vector: vec![1.0; 32],
                metadata: None,
            })
            .unwrap();
        db.delete(&id).unwrap();
    }

    // The database must still accept writes and serve reads.
    db.insert(VectorEntry {
        id: Some("final".to_string()),
        vector: vec![1.0; 32],
        metadata: None,
    })
    .unwrap();

    assert!(db.get("final").unwrap().is_some());
}
|
||||
|
||||
#[test]
fn test_repeated_operations() {
    // Hammer the same ID and the same query repeatedly; nothing should panic.
    let dir = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
    opts.dimensions = 16;
    opts.hnsw_config = None;

    let db = VectorDB::new(opts).unwrap();

    // Re-inserting an existing ID may replace or error; either is tolerated.
    for _ in 0..10 {
        let _ = db.insert(VectorEntry {
            id: Some("same_id".to_string()),
            vector: vec![1.0; 16],
            metadata: None,
        });
    }

    // Deleting an already-deleted ID must be tolerated as well.
    for _ in 0..5 {
        let _ = db.delete("same_id");
    }

    // Identical query, many times over.
    let query = vec![1.0; 16];
    for _ in 0..100 {
        let _ = db.search(SearchQuery {
            vector: query.clone(),
            k: 10,
            filter: None,
            ef_search: None,
        });
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Extreme Parameter Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Checks search behaviour at both extremes of k: far larger than the number
/// of stored vectors, and the minimum useful value of 1.
#[test]
fn test_extreme_k_values() {
    let tmp = tempdir().unwrap();
    let mut opts = DbOptions::default();
    opts.storage_path = tmp.path().join("test.db").to_string_lossy().to_string();
    opts.dimensions = 16;
    opts.hnsw_config = None;

    let db = VectorDB::new(opts).unwrap();

    // Seed ten easily distinguishable vectors.
    for i in 0..10 {
        db.insert(VectorEntry {
            id: Some(format!("vec_{}", i)),
            vector: vec![i as f32; 16],
            metadata: None,
        })
        .unwrap();
    }

    // k far beyond the database size: the result set is capped at what exists.
    let oversized = db
        .search(SearchQuery {
            vector: vec![1.0; 16],
            k: 1000,
            filter: None,
            ef_search: None,
        })
        .unwrap();
    assert!(oversized.len() <= 10);

    // k = 1: exactly one nearest neighbour comes back.
    let single = db
        .search(SearchQuery {
            vector: vec![1.0; 16],
            k: 1,
            filter: None,
            ef_search: None,
        })
        .unwrap();
    assert_eq!(single.len(), 1);
}
|
||||
772
crates/ruvector-core/tests/test_memory_pool.rs
Normal file
772
crates/ruvector-core/tests/test_memory_pool.rs
Normal file
@@ -0,0 +1,772 @@
|
||||
//! Memory Pool and Allocation Tests
|
||||
//!
|
||||
//! This module tests the arena allocator and cache-optimized storage
|
||||
//! for correct memory management, eviction, and performance characteristics.
|
||||
|
||||
use ruvector_core::arena::{Arena, ArenaVec};
|
||||
use ruvector_core::cache_optimized::SoAVectorStorage;
|
||||
use std::sync::{Arc, Barrier};
|
||||
use std::thread;
|
||||
|
||||
// ============================================================================
|
||||
// Arena Allocator Tests
|
||||
// ============================================================================
|
||||
|
||||
mod arena_tests {
    use super::*;

    /// A fresh ArenaVec reports the requested capacity, starts empty, and
    /// supports push followed by indexed reads.
    #[test]
    fn test_arena_basic_allocation() {
        let arena = Arena::new(1024);
        let mut values: ArenaVec<f32> = arena.alloc_vec(10);

        assert_eq!(values.capacity(), 10);
        assert_eq!(values.len(), 0);
        assert!(values.is_empty());

        for v in [1.0, 2.0, 3.0] {
            values.push(v);
        }

        assert_eq!(values.len(), 3);
        assert!(!values.is_empty());
        assert_eq!(values[0], 1.0);
        assert_eq!(values[1], 2.0);
        assert_eq!(values[2], 3.0);
    }

    /// Several vectors of different element types can coexist in one arena,
    /// each keeping its own requested capacity.
    #[test]
    fn test_arena_multiple_allocations() {
        let arena = Arena::new(4096);

        let a: ArenaVec<f32> = arena.alloc_vec(100);
        let b: ArenaVec<f64> = arena.alloc_vec(50);
        let c: ArenaVec<u32> = arena.alloc_vec(200);
        let d: ArenaVec<i64> = arena.alloc_vec(75);

        assert_eq!(a.capacity(), 100);
        assert_eq!(b.capacity(), 50);
        assert_eq!(c.capacity(), 200);
        assert_eq!(d.capacity(), 75);
    }

    /// Writes to vectors of four distinct element types must not interfere
    /// with one another.
    #[test]
    fn test_arena_different_types() {
        let arena = Arena::new(2048);

        let mut floats: ArenaVec<f32> = arena.alloc_vec(10);
        let mut doubles: ArenaVec<f64> = arena.alloc_vec(10);
        let mut ints: ArenaVec<i32> = arena.alloc_vec(10);
        let mut bytes: ArenaVec<u8> = arena.alloc_vec(10);

        for i in 0..10 {
            floats.push(i as f32);
            doubles.push(i as f64);
            ints.push(i);
            bytes.push(i as u8);
        }

        for i in 0..10 {
            assert_eq!(floats[i], i as f32);
            assert_eq!(doubles[i], i as f64);
            assert_eq!(ints[i], i as i32);
            assert_eq!(bytes[i], i as u8);
        }
    }

    /// `reset` rewinds the used-bytes cursor to zero while keeping the
    /// backing chunks alive so subsequent allocations reuse the same memory.
    #[test]
    fn test_arena_reset() {
        let arena = Arena::new(4096);

        // First allocation cycle populates two vectors, then drops them.
        {
            let mut first: ArenaVec<f32> = arena.alloc_vec(100);
            let mut second: ArenaVec<f32> = arena.alloc_vec(100);

            for i in 0..50 {
                first.push(i as f32);
                second.push(i as f32 * 2.0);
            }
        }

        let used_before = arena.used_bytes();
        assert!(used_before > 0, "Should have used some bytes");

        arena.reset();

        assert_eq!(arena.used_bytes(), 0, "Reset should set used bytes to 0");

        // The chunks themselves are retained — reused, not freed.
        let retained = arena.allocated_bytes();
        assert!(retained > 0, "Allocated bytes should remain after reset");

        // A second cycle must fit into the retained chunks.
        let mut third: ArenaVec<f32> = arena.alloc_vec(50);
        for i in 0..50 {
            third.push(i as f32);
        }

        assert!(
            arena.allocated_bytes() == retained,
            "Should reuse existing allocation"
        );
    }

    /// Requesting more memory than one chunk holds forces the arena to grow
    /// by appending further chunks.
    #[test]
    fn test_arena_chunk_growth() {
        // Deliberately tiny initial chunk so growth is unavoidable.
        let arena = Arena::new(64);

        let a: ArenaVec<f32> = arena.alloc_vec(100);
        let b: ArenaVec<f32> = arena.alloc_vec(100);
        let c: ArenaVec<f32> = arena.alloc_vec(100);

        assert_eq!(a.capacity(), 100);
        assert_eq!(b.capacity(), 100);
        assert_eq!(c.capacity(), 100);

        assert!(
            arena.allocated_bytes() > 64 * 3,
            "Should have grown beyond initial chunk"
        );
    }

    /// `as_slice` exposes exactly the pushed prefix, in push order.
    #[test]
    fn test_arena_as_slice() {
        let arena = Arena::new(1024);
        let mut values: ArenaVec<f32> = arena.alloc_vec(10);

        for i in 0..5 {
            values.push((i * 10) as f32);
        }

        assert_eq!(values.as_slice(), &[0.0, 10.0, 20.0, 30.0, 40.0]);
    }

    /// Mutations made through `as_mut_slice` are visible via indexed reads.
    #[test]
    fn test_arena_as_mut_slice() {
        let arena = Arena::new(1024);
        let mut values: ArenaVec<f32> = arena.alloc_vec(10);

        for i in 0..5 {
            values.push((i * 10) as f32);
        }

        {
            let view = values.as_mut_slice();
            view[0] = 100.0;
            view[4] = 500.0;
        }

        assert_eq!(values[0], 100.0);
        assert_eq!(values[4], 500.0);
    }

    /// ArenaVec derefs to a slice, so slice adapters such as `iter` work.
    #[test]
    fn test_arena_deref() {
        let arena = Arena::new(1024);
        let mut values: ArenaVec<f32> = arena.alloc_vec(10);

        for v in [1.0, 2.0, 3.0] {
            values.push(v);
        }

        assert_eq!(values.len(), 3);
        assert_eq!(values.iter().sum::<f32>(), 6.0);
    }

    /// A single request bigger than the chunk size still succeeds; the arena
    /// grows to accommodate it.
    #[test]
    fn test_arena_large_allocation() {
        let arena = Arena::new(1024);

        let big: ArenaVec<f32> = arena.alloc_vec(10000);
        assert_eq!(big.capacity(), 10000);

        assert!(arena.allocated_bytes() >= 10000 * std::mem::size_of::<f32>());
    }

    /// Byte counters start at zero and turn positive after one allocation.
    #[test]
    fn test_arena_statistics() {
        let arena = Arena::new(1024);

        assert_eq!(arena.allocated_bytes(), 0);
        assert_eq!(arena.used_bytes(), 0);

        let _vec: ArenaVec<f32> = arena.alloc_vec(100);

        assert!(arena.allocated_bytes() > 0);
        assert!(arena.used_bytes() > 0);
    }

    /// Pushing past the fixed capacity must panic with the documented message.
    #[test]
    #[should_panic(expected = "ArenaVec capacity exceeded")]
    fn test_arena_capacity_exceeded() {
        let arena = Arena::new(1024);
        let mut values: ArenaVec<f32> = arena.alloc_vec(5);

        for i in 0..10 {
            values.push(i as f32);
        }
    }

    /// The default-chunk constructor reserves at least 1 MB once exercised.
    #[test]
    fn test_arena_with_default_chunk_size() {
        let arena = Arena::with_default_chunk_size();

        let _vec: ArenaVec<f32> = arena.alloc_vec(1000);
        assert!(arena.allocated_bytes() >= 1024 * 1024);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Cache-Optimized Storage (SoA) Tests
|
||||
// ============================================================================
|
||||
|
||||
mod soa_tests {
    use super::*;

    /// New storage is empty, reports its dimensionality, and counts pushes.
    #[test]
    fn test_soa_basic_operations() {
        let mut store = SoAVectorStorage::new(3, 4);

        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
        assert_eq!(store.dimensions(), 3);

        store.push(&[1.0, 2.0, 3.0]);
        store.push(&[4.0, 5.0, 6.0]);

        assert_eq!(store.len(), 2);
        assert!(!store.is_empty());
    }

    /// `get` reassembles each stored vector into a caller-provided buffer.
    #[test]
    fn test_soa_get_vector() {
        let mut store = SoAVectorStorage::new(4, 8);

        store.push(&[1.0, 2.0, 3.0, 4.0]);
        store.push(&[5.0, 6.0, 7.0, 8.0]);
        store.push(&[9.0, 10.0, 11.0, 12.0]);

        let mut row = vec![0.0; 4];

        store.get(0, &mut row);
        assert_eq!(row, vec![1.0, 2.0, 3.0, 4.0]);

        store.get(1, &mut row);
        assert_eq!(row, vec![5.0, 6.0, 7.0, 8.0]);

        store.get(2, &mut row);
        assert_eq!(row, vec![9.0, 10.0, 11.0, 12.0]);
    }

    /// Each dimension slice holds that coordinate from every stored vector.
    #[test]
    fn test_soa_dimension_slice() {
        let mut store = SoAVectorStorage::new(3, 8);

        store.push(&[1.0, 10.0, 100.0]);
        store.push(&[2.0, 20.0, 200.0]);
        store.push(&[3.0, 30.0, 300.0]);
        store.push(&[4.0, 40.0, 400.0]);

        // Column 0: every vector's first component, and so on.
        assert_eq!(store.dimension_slice(0), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(store.dimension_slice(1), &[10.0, 20.0, 30.0, 40.0]);
        assert_eq!(store.dimension_slice(2), &[100.0, 200.0, 300.0, 400.0]);
    }

    /// Writes through a mutable dimension slice show up in per-vector reads.
    #[test]
    fn test_soa_dimension_slice_mut() {
        let mut store = SoAVectorStorage::new(3, 8);

        store.push(&[1.0, 2.0, 3.0]);
        store.push(&[4.0, 5.0, 6.0]);

        {
            let col = store.dimension_slice_mut(0);
            col[0] = 100.0;
            col[1] = 400.0;
        }

        let mut row = vec![0.0; 3];
        store.get(0, &mut row);
        assert_eq!(row, vec![100.0, 2.0, 3.0]);

        store.get(1, &mut row);
        assert_eq!(row, vec![400.0, 5.0, 6.0]);
    }

    /// Storage created with a tiny capacity grows transparently and keeps
    /// every previously pushed value intact.
    #[test]
    fn test_soa_auto_growth() {
        let mut store = SoAVectorStorage::new(4, 2);

        for i in 0..100 {
            store.push(&[i as f32, (i * 2) as f32, (i * 3) as f32, (i * 4) as f32]);
        }

        assert_eq!(store.len(), 100);

        let mut row = vec![0.0; 4];
        for i in 0..100 {
            store.get(i, &mut row);
            assert_eq!(
                row,
                vec![i as f32, (i * 2) as f32, (i * 3) as f32, (i * 4) as f32]
            );
        }
    }

    /// Euclidean distances from a unit vector: ~0 to itself, sqrt(2) to each
    /// orthogonal unit vector.
    #[test]
    fn test_soa_batch_euclidean_distances() {
        let mut store = SoAVectorStorage::new(3, 4);

        store.push(&[1.0, 0.0, 0.0]);
        store.push(&[0.0, 1.0, 0.0]);
        store.push(&[0.0, 0.0, 1.0]);

        let probe = vec![1.0, 0.0, 0.0];
        let mut dists = vec![0.0; 3];

        store.batch_euclidean_distances(&probe, &mut dists);

        assert!(dists[0] < 0.001, "Distance to self should be ~0");

        let sqrt2 = (2.0_f32).sqrt();
        assert!(
            (dists[1] - sqrt2).abs() < 0.01,
            "Expected sqrt(2), got {}",
            dists[1]
        );
        assert!(
            (dists[2] - sqrt2).abs() < 0.01,
            "Expected sqrt(2), got {}",
            dists[2]
        );
    }

    /// A larger batch over 1000 x 128-d vectors yields only finite,
    /// non-negative distances.
    #[test]
    fn test_soa_batch_distances_large() {
        let dim = 128;
        let num_vectors = 1000;

        let mut store = SoAVectorStorage::new(dim, 16);

        for i in 0..num_vectors {
            let v: Vec<f32> = (0..dim)
                .map(|j| ((i * dim + j) % 100) as f32 * 0.01)
                .collect();
            store.push(&v);
        }

        let probe: Vec<f32> = (0..dim).map(|j| (j % 50) as f32 * 0.02).collect();
        let mut dists = vec![0.0; num_vectors];

        store.batch_euclidean_distances(&probe, &mut dists);

        for (i, &d) in dists.iter().enumerate() {
            assert!(
                d >= 0.0 && d.is_finite(),
                "Distance {} is invalid: {}",
                i,
                d
            );
        }
    }

    /// Round-trips one vector at each embedding size seen in practice.
    #[test]
    fn test_soa_common_embedding_dimensions() {
        for dim in [128, 256, 384, 512, 768, 1024, 1536] {
            let mut store = SoAVectorStorage::new(dim, 4);

            let v: Vec<f32> = (0..dim).map(|i| i as f32 * 0.001).collect();
            store.push(&v);

            let mut row = vec![0.0; dim];
            store.get(0, &mut row);

            assert_eq!(row, v);
        }
    }

    /// Zero dimensions are rejected at construction time.
    #[test]
    #[should_panic(expected = "dimensions must be between")]
    fn test_soa_zero_dimensions() {
        let _ = SoAVectorStorage::new(0, 4);
    }

    /// Pushing a vector with the wrong number of components must panic.
    #[test]
    #[should_panic]
    fn test_soa_wrong_vector_length() {
        let mut store = SoAVectorStorage::new(3, 4);
        store.push(&[1.0, 2.0]); // Wrong dimension
    }

    /// Reading index 0 from empty storage must panic.
    #[test]
    #[should_panic]
    fn test_soa_get_out_of_bounds() {
        let store = SoAVectorStorage::new(3, 4);
        let mut row = vec![0.0; 3];
        store.get(0, &mut row); // No vectors added
    }

    /// Asking for a dimension index beyond `dimensions()` must panic.
    #[test]
    #[should_panic]
    fn test_soa_dimension_slice_out_of_bounds() {
        let mut store = SoAVectorStorage::new(3, 4);
        store.push(&[1.0, 2.0, 3.0]);
        let _ = store.dimension_slice(5); // Invalid dimension
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Memory Pressure Tests
|
||||
// ============================================================================
|
||||
|
||||
mod memory_pressure_tests {
    use super::*;

    /// Ten thousand tiny allocations from one arena must all succeed.
    #[test]
    fn test_arena_many_small_allocations() {
        let arena = Arena::new(1024 * 1024); // 1MB

        for _ in 0..10000 {
            let _vec: ArenaVec<f32> = arena.alloc_vec(10);
        }

        assert!(arena.allocated_bytes() > 0);
    }

    /// Alternating small and large requests should not break chunk handling.
    #[test]
    fn test_arena_alternating_sizes() {
        let arena = Arena::new(4096);

        for i in 0..100 {
            let size = if i % 2 == 0 { 10 } else { 1000 };
            let _vec: ArenaVec<f32> = arena.alloc_vec(size);
        }
    }

    /// Ten thousand 128-d vectors survive a bulk load, with correct random
    /// access afterwards.
    #[test]
    fn test_soa_large_capacity() {
        let mut store = SoAVectorStorage::new(128, 10000);

        for i in 0..10000 {
            let v: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.0001).collect();
            store.push(&v);
        }

        assert_eq!(store.len(), 10000);

        // Spot-check an entry from the middle of the storage.
        let mut row = vec![0.0; 128];
        store.get(5000, &mut row);
        assert!((row[0] - (5000 * 128) as f32 * 0.0001).abs() < 0.0001);
    }

    /// Batch distance computation over 5000 x 512-d vectors stays finite and
    /// non-negative throughout.
    #[test]
    fn test_soa_batch_operations_under_pressure() {
        let dim = 512;
        let num_vectors = 5000;

        let mut store = SoAVectorStorage::new(dim, 128);

        for i in 0..num_vectors {
            let v: Vec<f32> = (0..dim).map(|j| ((i + j) % 1000) as f32 * 0.001).collect();
            store.push(&v);
        }

        let probe: Vec<f32> = (0..dim).map(|j| (j % 500) as f32 * 0.002).collect();
        let mut dists = vec![0.0; num_vectors];

        store.batch_euclidean_distances(&probe, &mut dists);

        for d in &dists {
            assert!(d.is_finite() && *d >= 0.0);
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Concurrent Access Tests
|
||||
// ============================================================================
|
||||
|
||||
mod concurrent_tests {
    use super::*;

    /// Shared read-only access: eight threads read dimension slices from the
    /// same `Arc<SoAVectorStorage>` simultaneously. A barrier releases all
    /// threads at once to maximise actual overlap.
    #[test]
    fn test_soa_concurrent_reads() {
        // Create and populate storage
        let mut storage = SoAVectorStorage::new(64, 16);

        for i in 0..1000 {
            let vec: Vec<f32> = (0..64).map(|j| (i * 64 + j) as f32 * 0.01).collect();
            storage.push(&vec);
        }

        // Freeze the storage behind an Arc; from here on it is only read.
        let storage = Arc::new(storage);
        let num_threads = 8;
        let barrier = Arc::new(Barrier::new(num_threads));
        let mut handles = vec![];

        for thread_id in 0..num_threads {
            let storage_clone = Arc::clone(&storage);
            let barrier_clone = Arc::clone(&barrier);

            let handle = thread::spawn(move || {
                // Wait until every thread is ready so reads genuinely overlap.
                barrier_clone.wait();

                // Each thread performs many reads
                for i in 0..100 {
                    // idx varies per thread/iteration; only idx % 64 is used
                    // to pick the dimension being read.
                    let idx = (thread_id * 100 + i) % 1000;

                    // Read dimension slices
                    let dim_slice = storage_clone.dimension_slice(idx % 64);
                    assert!(!dim_slice.is_empty());
                }
            });

            handles.push(handle);
        }

        // Propagate any assertion failure or panic from a worker thread.
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
    }

    /// Concurrent batch distance queries: four threads each run 50 full
    /// batch computations against the shared storage and check that every
    /// computed distance is finite.
    #[test]
    fn test_soa_concurrent_batch_distances() {
        let mut storage = SoAVectorStorage::new(32, 16);

        for i in 0..500 {
            let vec: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect();
            storage.push(&vec);
        }

        let storage = Arc::new(storage);
        let num_threads = 4;
        let barrier = Arc::new(Barrier::new(num_threads));
        let mut handles = vec![];

        for thread_id in 0..num_threads {
            let storage_clone = Arc::clone(&storage);
            let barrier_clone = Arc::clone(&barrier);

            let handle = thread::spawn(move || {
                // Synchronised start, as above.
                barrier_clone.wait();

                for i in 0..50 {
                    // Query derived from thread id + iteration so every
                    // thread exercises distinct inputs.
                    let query: Vec<f32> = (0..32)
                        .map(|j| ((thread_id * 50 + i) * 32 + j) as f32 * 0.01)
                        .collect();
                    let mut distances = vec![0.0; 500];

                    storage_clone.batch_euclidean_distances(&query, &mut distances);

                    // Verify results
                    for dist in &distances {
                        assert!(dist.is_finite());
                    }
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().expect("Thread panicked");
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Cases
|
||||
// ============================================================================
|
||||
|
||||
mod edge_cases {
    use super::*;

    /// Storage holding exactly one vector round-trips it.
    #[test]
    fn test_soa_single_vector() {
        let mut store = SoAVectorStorage::new(3, 1);
        store.push(&[1.0, 2.0, 3.0]);

        assert_eq!(store.len(), 1);

        let mut row = vec![0.0; 3];
        store.get(0, &mut row);
        assert_eq!(row, vec![1.0, 2.0, 3.0]);
    }

    /// A 1-dimensional storage degenerates to a flat list of scalars.
    #[test]
    fn test_soa_single_dimension() {
        let mut store = SoAVectorStorage::new(1, 4);

        store.push(&[1.0]);
        store.push(&[2.0]);
        store.push(&[3.0]);

        assert_eq!(store.dimension_slice(0), &[1.0, 2.0, 3.0]);
    }

    /// Filling an ArenaVec to exactly its capacity is allowed (no panic).
    #[test]
    fn test_arena_exact_capacity() {
        let arena = Arena::new(1024);
        let mut values: ArenaVec<f32> = arena.alloc_vec(5);

        for i in 0..5 {
            values.push(i as f32);
        }

        assert_eq!(values.len(), 5);
        assert_eq!(values.capacity(), 5);
    }

    /// All-zero vectors sit at distance ~0 from an all-zero query.
    #[test]
    fn test_soa_zeros() {
        let mut store = SoAVectorStorage::new(4, 4);

        store.push(&[0.0, 0.0, 0.0, 0.0]);
        store.push(&[0.0, 0.0, 0.0, 0.0]);

        let probe = vec![0.0; 4];
        let mut dists = vec![0.0; 2];

        store.batch_euclidean_distances(&probe, &mut dists);

        assert!(dists[0] < 1e-6);
        assert!(dists[1] < 1e-6);
    }

    /// Negative components are stored and retrieved unchanged.
    #[test]
    fn test_soa_negative_values() {
        let mut store = SoAVectorStorage::new(3, 4);

        store.push(&[-1.0, -2.0, -3.0]);
        store.push(&[-4.0, -5.0, -6.0]);

        let mut row = vec![0.0; 3];
        store.get(0, &mut row);
        assert_eq!(row, vec![-1.0, -2.0, -3.0]);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Characteristics Tests
|
||||
// ============================================================================
|
||||
|
||||
mod performance_tests {
    use super::*;

    /// Smoke-tests allocation throughput: 100k small arena allocations must
    /// finish in under a second.
    /// NOTE(review): wall-clock thresholds like this can be flaky on loaded
    /// CI machines or debug builds — confirm these budgets hold in CI.
    #[test]
    fn test_arena_allocation_performance() {
        // This test verifies that arena allocation is efficient
        let arena = Arena::new(1024 * 1024); // 1MB

        let start = std::time::Instant::now();

        for _ in 0..100000 {
            let _vec: ArenaVec<f32> = arena.alloc_vec(10);
        }

        let duration = start.elapsed();

        // Should complete quickly (< 1 second for 100k allocations)
        assert!(
            duration.as_millis() < 1000,
            "Arena allocation took too long: {:?}",
            duration
        );
    }

    /// Sums every dimension column; the structure-of-arrays layout should
    /// make this sequential access pattern fast.
    #[test]
    fn test_soa_dimension_access_pattern() {
        let mut storage = SoAVectorStorage::new(128, 16);

        for i in 0..1000 {
            let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32).collect();
            storage.push(&vec);
        }

        // Test dimension-wise access (this should be cache-efficient)
        let start = std::time::Instant::now();

        for dim in 0..128 {
            let slice = storage.dimension_slice(dim);
            // Sum forces the whole column to actually be read.
            let _sum: f32 = slice.iter().sum();
        }

        let duration = start.elapsed();

        // Dimension-wise access should be fast due to cache locality
        assert!(
            duration.as_millis() < 100,
            "Dimension access took too long: {:?}",
            duration
        );
    }

    /// 100 repeated batch-distance passes over 1000 x 128-d vectors must
    /// stay under half a second (distance buffer is reused across passes).
    #[test]
    fn test_soa_batch_distance_performance() {
        let mut storage = SoAVectorStorage::new(128, 128);

        for i in 0..1000 {
            let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.001).collect();
            storage.push(&vec);
        }

        let query: Vec<f32> = (0..128).map(|j| j as f32 * 0.001).collect();
        let mut distances = vec![0.0; 1000];

        let start = std::time::Instant::now();

        for _ in 0..100 {
            storage.batch_euclidean_distances(&query, &mut distances);
        }

        let duration = start.elapsed();

        // 100 batch operations on 1000 vectors should be fast
        assert!(
            duration.as_millis() < 500,
            "Batch distance took too long: {:?}",
            duration
        );
    }
}
|
||||
767
crates/ruvector-core/tests/test_quantization.rs
Normal file
767
crates/ruvector-core/tests/test_quantization.rs
Normal file
@@ -0,0 +1,767 @@
|
||||
//! Quantization Accuracy Tests
|
||||
//!
|
||||
//! This module provides comprehensive tests for quantization techniques,
|
||||
//! verifying accuracy, compression ratios, and distance calculations.
|
||||
|
||||
use ruvector_core::quantization::*;
|
||||
|
||||
// ============================================================================
|
||||
// Scalar Quantization Tests
|
||||
// ============================================================================
|
||||
|
||||
mod scalar_quantization_tests {
    use super::*;

    /// Quantizing preserves element count and produces a positive scale.
    #[test]
    fn test_scalar_quantization_basic() {
        let vector = vec![0.0, 0.5, 1.0, 1.5, 2.0];
        let quantized = ScalarQuantized::quantize(&vector);

        assert_eq!(quantized.data.len(), 5);
        assert!(quantized.scale > 0.0, "Scale should be positive");
    }

    /// The stored `min` and `scale` must map the input range onto the full
    /// u8 range [0, 255].
    #[test]
    fn test_scalar_quantization_min_max() {
        let vector = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
        let quantized = ScalarQuantized::quantize(&vector);

        // Min should be -10.0
        assert!((quantized.min - (-10.0)).abs() < 0.001);

        // Scale should map range 20 to 255
        let expected_scale = 20.0 / 255.0;
        assert!(
            (quantized.scale - expected_scale).abs() < 0.001,
            "Scale mismatch: expected {}, got {}",
            expected_scale,
            quantized.scale
        );
    }

    /// Quantize-then-reconstruct error stays within (range / 128) per
    /// element across vectors of varying magnitude and sign.
    #[test]
    fn test_scalar_quantization_reconstruction_accuracy() {
        let test_vectors = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 0.25, 0.5, 0.75, 1.0],
            vec![-100.0, 0.0, 100.0],
            vec![0.001, 0.002, 0.003, 0.004, 0.005],
        ];

        for vector in test_vectors {
            let quantized = ScalarQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();

            assert_eq!(vector.len(), reconstructed.len());

            // Calculate max error based on range
            let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
            let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let max_allowed_error = (max - min) / 128.0; // Allow 2 quantization steps error

            for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
                let error = (orig - recon).abs();
                assert!(
                    error <= max_allowed_error,
                    "Reconstruction error {} exceeds max {} for value {}",
                    error,
                    max_allowed_error,
                    orig
                );
            }
        }
    }

    /// A zero-range (constant) vector is the degenerate case: every element
    /// must still reconstruct close to the original value.
    #[test]
    fn test_scalar_quantization_constant_values() {
        let constant = vec![5.0, 5.0, 5.0, 5.0, 5.0];
        let quantized = ScalarQuantized::quantize(&constant);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in constant.iter().zip(reconstructed.iter()) {
            assert!(
                (orig - recon).abs() < 0.1,
                "Constant value reconstruction failed"
            );
        }
    }

    /// Metric sanity: distance from a quantized vector to itself is ~0.
    #[test]
    fn test_scalar_quantization_distance_self() {
        let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let quantized = ScalarQuantized::quantize(&vector);

        let distance = quantized.distance(&quantized);
        assert!(
            distance < 0.001,
            "Distance to self should be ~0, got {}",
            distance
        );
    }

    /// Metric sanity: d(a, b) ≈ d(b, a) within quantization tolerance.
    #[test]
    fn test_scalar_quantization_distance_symmetry() {
        let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let v2 = vec![5.0, 4.0, 3.0, 2.0, 1.0];

        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);

        let dist_ab = q1.distance(&q2);
        let dist_ba = q2.distance(&q1);

        assert!(
            (dist_ab - dist_ba).abs() < 0.1,
            "Distance not symmetric: {} vs {}",
            dist_ab,
            dist_ba
        );
    }

    /// Metric sanity: the triangle inequality holds on orthogonal unit
    /// vectors, with 0.5 of slack for quantization error.
    #[test]
    fn test_scalar_quantization_distance_triangle_inequality() {
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0, 0.0];
        let v3 = vec![0.0, 0.0, 1.0, 0.0];

        let q1 = ScalarQuantized::quantize(&v1);
        let q2 = ScalarQuantized::quantize(&v2);
        let q3 = ScalarQuantized::quantize(&v3);

        let d12 = q1.distance(&q2);
        let d23 = q2.distance(&q3);
        let d13 = q1.distance(&q3);

        // Triangle inequality: d(1,3) <= d(1,2) + d(2,3)
        // Allow some slack for quantization errors
        assert!(
            d13 <= d12 + d23 + 0.5,
            "Triangle inequality violated: {} > {} + {}",
            d13,
            d12,
            d23
        );
    }

    /// Quantization must shrink the representation (f32 -> u8 plus small
    /// header) at every common embedding dimensionality.
    #[test]
    fn test_scalar_quantization_common_embedding_sizes() {
        for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
            let vector: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
            let quantized = ScalarQuantized::quantize(&vector);
            let reconstructed = quantized.reconstruct();

            assert_eq!(quantized.data.len(), dim);
            assert_eq!(reconstructed.len(), dim);

            // Verify compression ratio (4x for f32 -> u8)
            let original_size = dim * std::mem::size_of::<f32>();
            let quantized_size = quantized.data.len() + std::mem::size_of::<f32>() * 2; // data + min + scale
            assert!(
                quantized_size < original_size,
                "No compression achieved for dim {}",
                dim
            );
        }
    }

    /// Accuracy at the extremes: ~2% relative error tolerated for 1e10-scale
    /// magnitudes, and absolute error bounded by range/100 for 1e-5-scale.
    #[test]
    fn test_scalar_quantization_extreme_values() {
        // Test with large values
        let large = vec![1e10, 2e10, 3e10];
        let quantized = ScalarQuantized::quantize(&large);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in large.iter().zip(reconstructed.iter()) {
            let relative_error = (orig - recon).abs() / orig.abs();
            assert!(
                relative_error < 0.02,
                "Large value reconstruction error too high: {}",
                relative_error
            );
        }

        // Test with small values
        let small = vec![1e-5, 2e-5, 3e-5, 4e-5, 5e-5];
        let quantized = ScalarQuantized::quantize(&small);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in small.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            // Input spans [1e-5, 5e-5], so the value range is 4e-5.
            let range = 4e-5;
            assert!(
                error < range / 100.0,
                "Small value reconstruction error too high: {}",
                error
            );
        }
    }

    /// Vectors that are entirely negative round-trip within tolerance.
    #[test]
    fn test_scalar_quantization_negative_values() {
        let negative = vec![-5.0, -4.0, -3.0, -2.0, -1.0];
        let quantized = ScalarQuantized::quantize(&negative);
        let reconstructed = quantized.reconstruct();

        for (orig, recon) in negative.iter().zip(reconstructed.iter()) {
            assert!(
                (orig - recon).abs() < 0.1,
                "Negative value reconstruction failed: {} vs {}",
                orig,
                recon
            );
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Binary Quantization Tests
|
||||
// ============================================================================
|
||||
|
||||
mod binary_quantization_tests {
    use super::*;

    /// Smallest sanity check: 5 values pack into a single byte.
    #[test]
    fn test_binary_quantization_basic() {
        let input = vec![1.0, -1.0, 0.5, -0.5, 0.1];
        let q = BinaryQuantized::quantize(&input);

        assert_eq!(q.dimensions, 5);
        assert_eq!(q.bits.len(), 1); // 5 bits fit in 1 byte
    }

    /// Byte packing must use ceil(dim / 8) bytes for every width 1..=32.
    #[test]
    fn test_binary_quantization_packing() {
        // Test byte packing
        for dim in 1..=32 {
            let mut input: Vec<f32> = Vec::with_capacity(dim);
            for i in 0..dim {
                input.push(if i % 2 == 0 { 1.0 } else { -1.0 });
            }
            let q = BinaryQuantized::quantize(&input);

            let expected_bytes = (dim + 7) / 8;
            assert_eq!(
                q.bits.len(),
                expected_bytes,
                "Wrong byte count for dim {}",
                dim
            );
            assert_eq!(q.dimensions, dim);
        }
    }

    /// The sign of every input must survive the 1-bit round trip.
    #[test]
    fn test_binary_quantization_sign_preservation() {
        let cases = vec![
            vec![1.0, -1.0, 2.0, -2.0],
            vec![0.001, -0.001, 100.0, -100.0],
            vec![f32::MAX / 2.0, f32::MIN / 2.0],
        ];

        for input in cases {
            let round_trip = BinaryQuantized::quantize(&input).reconstruct();

            for (orig, recon) in input.iter().zip(round_trip.iter()) {
                if *orig > 0.0 {
                    assert_eq!(*recon, 1.0, "Positive value should reconstruct to 1.0");
                } else if *orig < 0.0 {
                    assert_eq!(*recon, -1.0, "Negative value should reconstruct to -1.0");
                }
            }
        }
    }

    /// Exact zeros fall on the "negative" side of the sign threshold.
    #[test]
    fn test_binary_quantization_zero_handling() {
        let input = vec![0.0, 0.0, 0.0, 0.0];
        let round_trip = BinaryQuantized::quantize(&input).reconstruct();

        // Zero maps to negative bit (0), which reconstructs to -1.0
        for val in round_trip {
            assert_eq!(val, -1.0);
        }
    }

    /// Hand-computed Hamming distances for known bit patterns.
    #[test]
    fn test_binary_quantization_hamming_distance() {
        // (v1, v2, expected_distance)
        let cases = vec![
            (vec![1.0, 1.0, 1.0, 1.0], vec![1.0, 1.0, 1.0, 1.0], 0.0), // identical
            (vec![1.0, 1.0, 1.0, 1.0], vec![-1.0, -1.0, -1.0, -1.0], 4.0), // opposite
            (vec![1.0, 1.0, -1.0, -1.0], vec![1.0, -1.0, -1.0, 1.0], 2.0), // 2 bits differ
            (vec![1.0, -1.0, 1.0, -1.0], vec![-1.0, 1.0, -1.0, 1.0], 4.0), // all differ
        ];

        for (v1, v2, expected) in cases {
            let q1 = BinaryQuantized::quantize(&v1);
            let q2 = BinaryQuantized::quantize(&v2);
            let distance = q1.distance(&q2);

            assert!(
                (distance - expected).abs() < 0.001,
                "Hamming distance mismatch: expected {}, got {}",
                expected,
                distance
            );
        }
    }

    /// d(a, b) == d(b, a) for the Hamming metric.
    #[test]
    fn test_binary_quantization_distance_symmetry() {
        let v1 = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let v2 = vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0];

        let q1 = BinaryQuantized::quantize(&v1);
        let q2 = BinaryQuantized::quantize(&v2);

        let d12 = q1.distance(&q2);
        let d21 = q2.distance(&q1);
        assert_eq!(d12, d21, "Binary distance should be symmetric");
    }

    /// Hamming distance is bounded by [0, dim] for several widths.
    #[test]
    fn test_binary_quantization_distance_bounds() {
        for dim in [8, 16, 32, 64, 128, 256] {
            let v1: Vec<f32> = (0..dim)
                .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
                .collect();
            let v2: Vec<f32> = (0..dim)
                .map(|i| if i % 3 == 0 { 1.0 } else { -1.0 })
                .collect();

            let distance = BinaryQuantized::quantize(&v1).distance(&BinaryQuantized::quantize(&v2));

            // Distance should be in [0, dim]
            assert!(
                distance >= 0.0 && distance <= dim as f32,
                "Distance {} out of bounds [0, {}]",
                distance,
                dim
            );
        }
    }

    /// 1 bit per f32 should yield ~32x compression on the payload alone.
    #[test]
    fn test_binary_quantization_compression_ratio() {
        for dim in [128, 256, 512, 1024] {
            let input: Vec<f32> = (0..dim)
                .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
                .collect();
            let q = BinaryQuantized::quantize(&input);

            // f32 to 1 bit = theoretical 32x compression for data only
            // Actual ratio depends on overhead but should be significant
            let original_data_size = dim * std::mem::size_of::<f32>();
            let quantized_data_size = q.bits.len();
            let data_compression_ratio = original_data_size as f32 / quantized_data_size as f32;

            assert!(
                data_compression_ratio >= 31.0,
                "Data compression ratio {} less than expected ~32x for dim {}",
                data_compression_ratio,
                dim
            );

            // Verify bits.len() is correct: ceil(dim / 8)
            assert_eq!(q.bits.len(), (dim + 7) / 8);
        }
    }

    /// Round trips at typical embedding widths produce only +/-1 values.
    #[test]
    fn test_binary_quantization_common_embedding_sizes() {
        for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
            let input: Vec<f32> = (0..dim).map(|i| (i as f32 - dim as f32 / 2.0)).collect();
            let round_trip = BinaryQuantized::quantize(&input).reconstruct();

            assert_eq!(round_trip.len(), dim);

            // Check all values are +1 or -1
            for val in &round_trip {
                assert!(*val == 1.0 || *val == -1.0);
            }
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Product Quantization Tests
|
||||
// ============================================================================
|
||||
|
||||
mod product_quantization_tests {
    use super::*;

    /// Training yields one codebook per subspace, each of the requested size.
    #[test]
    fn test_product_quantization_training() {
        let training: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect())
            .collect();

        let num_subspaces = 4;
        let codebook_size = 16;
        let pq = ProductQuantized::train(&training, num_subspaces, codebook_size, 10).unwrap();

        assert_eq!(pq.codebooks.len(), num_subspaces);
        for codebook in &pq.codebooks {
            assert_eq!(codebook.len(), codebook_size);
        }
    }

    /// Encoding produces one in-range code per subspace.
    #[test]
    fn test_product_quantization_encode() {
        let training: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect())
            .collect();

        let num_subspaces = 4;
        let codebook_size = 16;
        let pq = ProductQuantized::train(&training, num_subspaces, codebook_size, 10).unwrap();

        let probe: Vec<f32> = (0..32).map(|i| i as f32 * 0.02).collect();
        let codes = pq.encode(&probe);

        assert_eq!(codes.len(), num_subspaces);
        for code in &codes {
            assert!(*code < codebook_size as u8);
        }
    }

    /// Training on an empty dataset must fail rather than panic.
    #[test]
    fn test_product_quantization_empty_input_error() {
        assert!(ProductQuantized::train(&[], 4, 16, 10).is_err());
    }

    /// Codes are stored as u8, so codebooks larger than 256 are rejected.
    #[test]
    fn test_product_quantization_codebook_size_limit() {
        let training: Vec<Vec<f32>> = (0..10)
            .map(|i| (0..16).map(|j| (i * 16 + j) as f32).collect())
            .collect();

        // Codebook size > 256 should error
        assert!(ProductQuantized::train(&training, 4, 300, 10).is_err());
    }

    /// Centroid length must equal dim / num_subspaces for several splits.
    #[test]
    fn test_product_quantization_various_subspaces() {
        let dim = 64;
        let training: Vec<Vec<f32>> = (0..200)
            .map(|i| (0..dim).map(|j| (i * dim + j) as f32 * 0.001).collect())
            .collect();

        for num_subspaces in [1, 2, 4, 8, 16] {
            let pq = ProductQuantized::train(&training, num_subspaces, 16, 5).unwrap();
            assert_eq!(pq.codebooks.len(), num_subspaces);

            let subspace_dim = dim / num_subspaces;
            for codebook in &pq.codebooks {
                for centroid in codebook {
                    assert_eq!(centroid.len(), subspace_dim);
                }
            }
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Comparative Tests
|
||||
// ============================================================================
|
||||
|
||||
mod comparative_tests {
    use super::*;

    /// Scalar quantization keeps magnitudes; binary keeps only signs.
    #[test]
    fn test_scalar_vs_binary_reconstruction() {
        let input = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0];

        let scalar_recon = ScalarQuantized::quantize(&input).reconstruct();
        let binary_recon = BinaryQuantized::quantize(&input).reconstruct();

        // Scalar should have better accuracy
        let scalar_error: f32 = input
            .iter()
            .zip(scalar_recon.iter())
            .map(|(o, r)| (o - r).abs())
            .sum::<f32>()
            / input.len() as f32;

        // Binary only preserves sign
        for (orig, recon) in input.iter().zip(binary_recon.iter()) {
            assert_eq!(orig.signum(), recon.signum());
        }

        // Scalar error should be small
        assert!(
            scalar_error < 0.5,
            "Scalar reconstruction error {} too high",
            scalar_error
        );
    }

    /// Near/far relationships in the original space must survive quantization.
    #[test]
    fn test_quantization_preserves_relative_ordering() {
        // Test that vectors closest in original space are also closest in quantized space
        let v1 = vec![1.0, 0.0, 0.0, 0.0];
        let v2 = vec![0.9, 0.1, 0.0, 0.0]; // close to v1
        let v3 = vec![0.0, 0.0, 0.0, 1.0]; // far from v1

        // For scalar quantization
        let q1_s = ScalarQuantized::quantize(&v1);
        let q2_s = ScalarQuantized::quantize(&v2);
        let q3_s = ScalarQuantized::quantize(&v3);
        let d12_s = q1_s.distance(&q2_s);
        let d13_s = q1_s.distance(&q3_s);

        // v2 should be closer to v1 than v3
        assert!(
            d12_s < d13_s,
            "Scalar: v2 should be closer to v1 than v3: {} vs {}",
            d12_s,
            d13_s
        );

        // For binary quantization
        let q1_b = BinaryQuantized::quantize(&v1);
        let q2_b = BinaryQuantized::quantize(&v2);
        let q3_b = BinaryQuantized::quantize(&v3);
        let d12_b = q1_b.distance(&q2_b);
        let d13_b = q1_b.distance(&q3_b);

        // Same relative ordering should hold
        assert!(
            d12_b <= d13_b,
            "Binary: v2 should be at most as far as v3: {} vs {}",
            d12_b,
            d13_b
        );
    }

    /// Report and bound the end-to-end compression of both schemes.
    #[test]
    fn test_compression_ratios() {
        let dim = 512;
        let input: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01).collect();

        // Original size
        let original_size = dim * std::mem::size_of::<f32>(); // 2048 bytes

        // Scalar quantization: u8 per element + 2 floats for min/scale
        let scalar = ScalarQuantized::quantize(&input);
        let scalar_size = scalar.data.len() + 2 * std::mem::size_of::<f32>(); // ~520 bytes
        let scalar_ratio = original_size as f32 / scalar_size as f32;

        // Binary quantization: 1 bit per element + usize for dimensions
        let binary = BinaryQuantized::quantize(&input);
        let binary_size = binary.bits.len() + std::mem::size_of::<usize>(); // ~72 bytes
        let binary_ratio = original_size as f32 / binary_size as f32;

        println!("Original: {} bytes", original_size);
        println!(
            "Scalar: {} bytes ({:.1}x compression)",
            scalar_size, scalar_ratio
        );
        println!(
            "Binary: {} bytes ({:.1}x compression)",
            binary_size, binary_ratio
        );

        // Verify expected ratios
        assert!(scalar_ratio > 3.5, "Scalar should achieve ~4x compression");
        assert!(
            binary_ratio > 25.0,
            "Binary should achieve ~32x compression"
        );
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Cases and Error Handling
|
||||
// ============================================================================
|
||||
|
||||
mod edge_cases {
    use super::*;

    /// A one-element vector quantizes to exactly one byte in both schemes.
    #[test]
    fn test_single_element_vector() {
        let input = vec![42.0];

        let scalar = ScalarQuantized::quantize(&input);
        let binary = BinaryQuantized::quantize(&input);

        assert_eq!(scalar.data.len(), 1);
        assert_eq!(binary.bits.len(), 1);
        assert_eq!(binary.dimensions, 1);
    }

    /// Quantization handles vectors far larger than typical embeddings.
    #[test]
    fn test_large_vector() {
        let dim = 8192;
        let input: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();

        let scalar = ScalarQuantized::quantize(&input);
        let binary = BinaryQuantized::quantize(&input);

        assert_eq!(scalar.data.len(), dim);
        assert_eq!(binary.dimensions, dim);
    }

    /// Strictly positive inputs reconstruct to all +1.0.
    #[test]
    fn test_all_positive() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let round_trip = BinaryQuantized::quantize(&input).reconstruct();

        // All values should reconstruct to 1.0
        for val in round_trip {
            assert_eq!(val, 1.0);
        }
    }

    /// Strictly negative inputs reconstruct to all -1.0.
    #[test]
    fn test_all_negative() {
        let input = vec![-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0];
        let round_trip = BinaryQuantized::quantize(&input).reconstruct();

        // All values should reconstruct to -1.0
        for val in round_trip {
            assert_eq!(val, -1.0);
        }
    }

    /// An alternating +1/-1 pattern survives the bit round trip exactly.
    #[test]
    fn test_alternating_pattern() {
        let input: Vec<f32> = (0..100)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let round_trip = BinaryQuantized::quantize(&input).reconstruct();

        for (i, val) in round_trip.iter().enumerate() {
            let expected = if i % 2 == 0 { 1.0 } else { -1.0 };
            assert_eq!(*val, expected);
        }
    }

    /// Quantizing the same input twice yields identical parameters and data.
    #[test]
    fn test_quantization_deterministic() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // Quantize multiple times - should get same result
        let first = ScalarQuantized::quantize(&input);
        let second = ScalarQuantized::quantize(&input);

        assert_eq!(first.data, second.data);
        assert_eq!(first.min, second.min);
        assert_eq!(first.scale, second.scale);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Performance Characteristic Tests
|
||||
// ============================================================================
|
||||
|
||||
mod performance_tests {
    use super::*;

    /// Coarse throughput guard for scalar quantization (not a benchmark —
    /// the 5s budget only catches catastrophic regressions).
    #[test]
    fn test_scalar_quantization_speed() {
        let input: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();

        let start = std::time::Instant::now();
        for _ in 0..10000 {
            let _ = ScalarQuantized::quantize(&input);
        }
        let duration = start.elapsed();

        let ops_per_sec = 10000.0 / duration.as_secs_f64();
        println!(
            "Scalar quantization: {:.0} ops/sec for 1024-dim vectors",
            ops_per_sec
        );

        // Should be fast
        assert!(
            duration.as_millis() < 5000,
            "Scalar quantization too slow: {:?}",
            duration
        );
    }

    /// Coarse throughput guard for binary quantization (not a benchmark).
    #[test]
    fn test_binary_quantization_speed() {
        let input: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();

        let start = std::time::Instant::now();
        for _ in 0..10000 {
            let _ = BinaryQuantized::quantize(&input);
        }
        let duration = start.elapsed();

        let ops_per_sec = 10000.0 / duration.as_secs_f64();
        println!(
            "Binary quantization: {:.0} ops/sec for 1024-dim vectors",
            ops_per_sec
        );

        // Should be fast
        assert!(
            duration.as_millis() < 5000,
            "Binary quantization too slow: {:?}",
            duration
        );
    }

    /// Distance kernels for both schemes must stay well under 1s for 100k calls.
    #[test]
    fn test_distance_calculation_speed() {
        let v1: Vec<f32> = (0..512).map(|i| i as f32 * 0.01).collect();
        let v2: Vec<f32> = (0..512).map(|i| (i as f32 * 0.01) + 0.5).collect();

        let q1_s = ScalarQuantized::quantize(&v1);
        let q2_s = ScalarQuantized::quantize(&v2);
        let q1_b = BinaryQuantized::quantize(&v1);
        let q2_b = BinaryQuantized::quantize(&v2);

        // Scalar distance
        let start = std::time::Instant::now();
        for _ in 0..100000 {
            let _ = q1_s.distance(&q2_s);
        }
        let scalar_duration = start.elapsed();

        // Binary distance (Hamming)
        let start = std::time::Instant::now();
        for _ in 0..100000 {
            let _ = q1_b.distance(&q2_b);
        }
        let binary_duration = start.elapsed();

        println!("Scalar distance: {:?} for 100k ops", scalar_duration);
        println!("Binary distance: {:?} for 100k ops", binary_duration);

        // Binary should be faster (just XOR and popcount)
        // But both should be fast
        assert!(scalar_duration.as_millis() < 1000);
        assert!(binary_duration.as_millis() < 1000);
    }
}
|
||||
552
crates/ruvector-core/tests/test_simd_correctness.rs
Normal file
552
crates/ruvector-core/tests/test_simd_correctness.rs
Normal file
@@ -0,0 +1,552 @@
|
||||
//! SIMD Correctness Tests
|
||||
//!
|
||||
//! This module verifies that SIMD implementations produce identical results
|
||||
//! to scalar fallback implementations across various input sizes and edge cases.
|
||||
|
||||
use ruvector_core::simd_intrinsics::*;
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions for Scalar Computations (Ground Truth)
|
||||
// ============================================================================
|
||||
|
||||
/// Reference (non-SIMD) Euclidean distance between two equal-length slices;
/// serves as ground truth for the SIMD kernel.
fn scalar_euclidean(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate squared differences in input order, then take the root.
    let mut sum_sq = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
|
||||
|
||||
/// Reference (non-SIMD) dot product; ground truth for the SIMD kernel.
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
|
||||
|
||||
/// Reference (non-SIMD) cosine similarity; ground truth for the SIMD kernel.
/// Returns 0.0 when either vector has (near-)zero norm, matching the SIMD
/// implementation's degenerate-input convention.
fn scalar_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Single pass accumulating the dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
    }
    for x in a {
        sq_a += x * x;
    }
    for y in b {
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a > f32::EPSILON && norm_b > f32::EPSILON {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
|
||||
|
||||
/// Reference (non-SIMD) Manhattan (L1) distance; ground truth for the SIMD kernel.
fn scalar_manhattan(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .fold(0.0, |acc, (x, y)| acc + (x - y).abs())
}
|
||||
|
||||
// ============================================================================
|
||||
// Euclidean Distance Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Tiny vectors exercise only the scalar tail of the SIMD kernel.
#[test]
fn test_euclidean_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let fast = euclidean_distance_simd(&a, &b);
    let reference = scalar_euclidean(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-5,
        "Euclidean mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}

#[test]
fn test_euclidean_simd_vs_scalar_exact_simd_width() {
    // Test with exact AVX2 width (8 floats)
    let a: Vec<f32> = (0..8).map(|i| i as f32).collect();
    let b: Vec<f32> = (0..8).map(|i| (i + 1) as f32).collect();

    let fast = euclidean_distance_simd(&a, &b);
    let reference = scalar_euclidean(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-5,
        "8-element Euclidean mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}

#[test]
fn test_euclidean_simd_vs_scalar_non_aligned() {
    // Test with non-SIMD-aligned sizes
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65, 127, 129] {
        let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();

        let fast = euclidean_distance_simd(&a, &b);
        let reference = scalar_euclidean(&a, &b);
        assert!(
            (fast - reference).abs() < 0.01,
            "Size {} Euclidean mismatch: SIMD={}, scalar={}",
            size,
            fast,
            reference
        );
    }
}

#[test]
fn test_euclidean_simd_vs_scalar_common_embedding_sizes() {
    // Test common embedding dimensions
    for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
        let a: Vec<f32> = (0..dim).map(|i| ((i % 100) as f32) * 0.01).collect();
        let b: Vec<f32> = (0..dim).map(|i| (((i + 50) % 100) as f32) * 0.01).collect();

        let fast = euclidean_distance_simd(&a, &b);
        let reference = scalar_euclidean(&a, &b);
        assert!(
            (fast - reference).abs() < 0.1,
            "Dim {} Euclidean mismatch: SIMD={}, scalar={}",
            dim,
            fast,
            reference
        );
    }
}

/// d(v, v) must be (numerically) zero.
#[test]
fn test_euclidean_simd_identical_vectors() {
    let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = euclidean_distance_simd(&v, &v);
    assert!(
        result < 1e-6,
        "Distance to self should be ~0, got {}",
        result
    );
}

#[test]
fn test_euclidean_simd_zero_vectors() {
    let zeros = vec![0.0; 16];
    let result = euclidean_distance_simd(&zeros, &zeros);
    assert!(result < 1e-6, "Distance between zeros should be 0");
}

#[test]
fn test_euclidean_simd_negative_values() {
    let a = vec![-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0];
    let b = vec![-5.0, -6.0, -7.0, -8.0, -9.0, -10.0, -11.0, -12.0];

    let fast = euclidean_distance_simd(&a, &b);
    let reference = scalar_euclidean(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-5,
        "Negative values Euclidean mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}

#[test]
fn test_euclidean_simd_mixed_signs() {
    let a = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
    let b = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0];

    let fast = euclidean_distance_simd(&a, &b);
    let reference = scalar_euclidean(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-4,
        "Mixed signs Euclidean mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Dot Product Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_dot_product_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let fast = dot_product_simd(&a, &b);
    let reference = scalar_dot_product(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-4,
        "Dot product mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}

#[test]
fn test_dot_product_simd_vs_scalar_exact_simd_width() {
    // Exactly one AVX2 lane's worth of elements (8 floats).
    let a: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let b: Vec<f32> = (1..=8).map(|i| i as f32).collect();

    let fast = dot_product_simd(&a, &b);
    let reference = scalar_dot_product(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-4,
        "8-element dot product mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}

#[test]
fn test_dot_product_simd_vs_scalar_non_aligned() {
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65, 127, 129] {
        let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();

        let fast = dot_product_simd(&a, &b);
        let reference = scalar_dot_product(&a, &b);
        assert!(
            (fast - reference).abs() < 0.1,
            "Size {} dot product mismatch: SIMD={}, scalar={}",
            size,
            fast,
            reference
        );
    }
}

#[test]
fn test_dot_product_simd_common_embedding_sizes() {
    for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
        let a: Vec<f32> = (0..dim).map(|i| ((i % 10) as f32) * 0.1).collect();
        let b: Vec<f32> = (0..dim).map(|i| (((i + 5) % 10) as f32) * 0.1).collect();

        let fast = dot_product_simd(&a, &b);
        let reference = scalar_dot_product(&a, &b);
        assert!(
            (fast - reference).abs() < 0.5,
            "Dim {} dot product mismatch: SIMD={}, scalar={}",
            dim,
            fast,
            reference
        );
    }
}

/// Orthogonal basis vectors must produce a zero dot product.
#[test]
fn test_dot_product_simd_orthogonal_vectors() {
    let a = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    let b = vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];

    let result = dot_product_simd(&a, &b);
    assert!(result.abs() < 1e-6, "Orthogonal dot product should be 0");
}
|
||||
|
||||
// ============================================================================
|
||||
// Cosine Similarity Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_cosine_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let fast = cosine_similarity_simd(&a, &b);
    let reference = scalar_cosine_similarity(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-4,
        "Cosine mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}

#[test]
fn test_cosine_simd_vs_scalar_non_aligned() {
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65] {
        let a: Vec<f32> = (1..=size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (1..=size).map(|i| (i as f32) * 0.2).collect();

        let fast = cosine_similarity_simd(&a, &b);
        let reference = scalar_cosine_similarity(&a, &b);
        assert!(
            (fast - reference).abs() < 0.01,
            "Size {} cosine mismatch: SIMD={}, scalar={}",
            size,
            fast,
            reference
        );
    }
}

/// Identical vectors have cosine similarity 1.
#[test]
fn test_cosine_simd_identical_vectors() {
    let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = cosine_similarity_simd(&v, &v);
    assert!(
        (result - 1.0).abs() < 1e-5,
        "Identical vectors should have similarity 1.0, got {}",
        result
    );
}

/// Antiparallel vectors have cosine similarity -1.
#[test]
fn test_cosine_simd_opposite_vectors() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![-1.0, -2.0, -3.0, -4.0];

    let result = cosine_similarity_simd(&a, &b);
    assert!(
        (result + 1.0).abs() < 1e-5,
        "Opposite vectors should have similarity -1.0, got {}",
        result
    );
}

/// Orthogonal vectors have cosine similarity 0.
#[test]
fn test_cosine_simd_orthogonal_vectors() {
    let a = vec![1.0, 0.0, 0.0, 0.0];
    let b = vec![0.0, 1.0, 0.0, 0.0];

    let result = cosine_similarity_simd(&a, &b);
    assert!(
        result.abs() < 1e-5,
        "Orthogonal vectors should have similarity 0, got {}",
        result
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Manhattan Distance Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
fn test_manhattan_simd_vs_scalar_small() {
    let a = vec![1.0, 2.0, 3.0, 4.0];
    let b = vec![5.0, 6.0, 7.0, 8.0];

    let fast = manhattan_distance_simd(&a, &b);
    let reference = scalar_manhattan(&a, &b);
    assert!(
        (fast - reference).abs() < 1e-4,
        "Manhattan mismatch: SIMD={}, scalar={}",
        fast,
        reference
    );
}

#[test]
fn test_manhattan_simd_vs_scalar_non_aligned() {
    for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65] {
        let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();

        let fast = manhattan_distance_simd(&a, &b);
        let reference = scalar_manhattan(&a, &b);
        assert!(
            (fast - reference).abs() < 0.01,
            "Size {} Manhattan mismatch: SIMD={}, scalar={}",
            size,
            fast,
            reference
        );
    }
}

/// L1 distance from a vector to itself is zero.
#[test]
fn test_manhattan_simd_identical_vectors() {
    let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let result = manhattan_distance_simd(&v, &v);
    assert!(
        result < 1e-6,
        "Manhattan to self should be 0, got {}",
        result
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Numerical Stability Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_simd_large_values() {
|
||||
// Test with large but finite values
|
||||
let large_val = 1e10;
|
||||
let a: Vec<f32> = (0..16).map(|i| large_val + (i as f32)).collect();
|
||||
let b: Vec<f32> = (0..16).map(|i| large_val + (i as f32) + 1.0).collect();
|
||||
|
||||
let simd_result = euclidean_distance_simd(&a, &b);
|
||||
let scalar_result = scalar_euclidean(&a, &b);
|
||||
|
||||
assert!(
|
||||
simd_result.is_finite() && scalar_result.is_finite(),
|
||||
"Results should be finite for large values"
|
||||
);
|
||||
assert!(
|
||||
(simd_result - scalar_result).abs() < 0.1,
|
||||
"Large values mismatch: SIMD={}, scalar={}",
|
||||
simd_result,
|
||||
scalar_result
|
||||
);
|
||||
}
|
||||
|
||||
/// Distances over very small (1e-10 scale) values must remain finite.
#[test]
fn test_simd_small_values() {
    let small_val = 1e-10;
    let a: Vec<f32> = (0..16).map(|i| small_val * (i as f32 + 1.0)).collect();
    let b: Vec<f32> = (0..16).map(|i| small_val * (i as f32 + 2.0)).collect();

    let simd = euclidean_distance_simd(&a, &b);
    let reference = scalar_euclidean(&a, &b);

    assert!(
        simd.is_finite() && reference.is_finite(),
        "Results should be finite for small values"
    );
}
|
||||
|
||||
/// Distance computations must not produce NaN/inf on subnormal (denormalized) input.
///
/// NOTE(fix): the previous version used `f32::MIN_POSITIVE`, which is the smallest
/// *normal* float — despite the test's name, no denormalized value was exercised.
/// Halving it yields a genuine subnormal.
#[test]
fn test_simd_denormalized_values() {
    let subnormal = f32::MIN_POSITIVE / 2.0;
    // Sanity-check that we really produced a positive subnormal.
    assert!(subnormal > 0.0 && !subnormal.is_normal());

    let a = vec![subnormal; 8];
    let b = vec![subnormal * 2.0; 8]; // == f32::MIN_POSITIVE, smallest normal

    let simd_result = euclidean_distance_simd(&a, &b);
    let scalar_result = scalar_euclidean(&a, &b);

    assert!(
        simd_result.is_finite() && scalar_result.is_finite(),
        "Results should be finite for denormalized values"
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Legacy Alias Tests
|
||||
// ============================================================================
|
||||
|
||||
/// The legacy `*_avx2` entry points are aliases and must return bit-identical
/// results to their `*_simd` counterparts.
#[test]
fn test_legacy_avx2_aliases_match_simd() {
    let a: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let b: Vec<f32> = (9..=16).map(|i| i as f32).collect();

    // Exact float equality is intentional here: the legacy names are expected
    // to delegate to exactly the same code path.
    assert_eq!(
        euclidean_distance_avx2(&a, &b),
        euclidean_distance_simd(&a, &b)
    );
    assert_eq!(dot_product_avx2(&a, &b), dot_product_simd(&a, &b));
    assert_eq!(
        cosine_similarity_avx2(&a, &b),
        cosine_similarity_simd(&a, &b)
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Batch Operation Tests
|
||||
// ============================================================================
|
||||
|
||||
/// SIMD distances must agree with the scalar reference for every vector in a
/// 100 x 64 collection (batch consistency).
#[test]
fn test_simd_batch_consistency() {
    let query: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
    let vectors: Vec<Vec<f32>> = (0..100)
        .map(|j| (0..64).map(|i| ((i + j) as f32) * 0.1).collect())
        .collect();

    // Compare the two implementations vector by vector instead of collecting
    // two intermediate distance lists.
    for (i, candidate) in vectors.iter().enumerate() {
        let simd = euclidean_distance_simd(&query, candidate);
        let scalar = scalar_euclidean(&query, candidate);
        assert!(
            (simd - scalar).abs() < 0.01,
            "Vector {} mismatch: SIMD={}, scalar={}",
            i,
            simd,
            scalar
        );
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Edge Case Tests
|
||||
// ============================================================================
|
||||
|
||||
/// Every kernel must handle length-1 vectors (shorter than any SIMD lane width).
#[test]
fn test_simd_single_element() {
    let a = vec![1.0];
    let b = vec![2.0];

    // |1 - 2| = 1, 1 * 2 = 2, and L1 == L2 for a single component.
    assert!((euclidean_distance_simd(&a, &b) - 1.0).abs() < 1e-6);
    assert!((dot_product_simd(&a, &b) - 2.0).abs() < 1e-6);
    assert!((manhattan_distance_simd(&a, &b) - 1.0).abs() < 1e-6);
}
|
||||
|
||||
/// The Euclidean distance between the two 2-D unit basis vectors is sqrt(2).
#[test]
fn test_simd_two_elements() {
    let a = vec![1.0, 0.0];
    let b = vec![0.0, 1.0];

    let expected = 2.0_f32.sqrt();
    let actual = euclidean_distance_simd(&a, &b);

    assert!(
        (actual - expected).abs() < 1e-5,
        "Two element test: got {}, expected {}",
        actual,
        expected
    );
}
|
||||
|
||||
// ============================================================================
|
||||
// Stress Tests for SIMD
|
||||
// ============================================================================
|
||||
|
||||
/// Soak test: invoking every SIMD kernel 1000 times must not corrupt state or
/// yield non-finite results.
#[test]
fn test_simd_many_operations() {
    let a: Vec<f32> = (0..512).map(|i| (i as f32) * 0.001).collect();
    let b: Vec<f32> = (0..512).map(|i| ((i + 256) as f32) * 0.001).collect();

    // Hammer each kernel; intermediate results are deliberately discarded.
    (0..1000).for_each(|_| {
        let _ = euclidean_distance_simd(&a, &b);
        let _ = dot_product_simd(&a, &b);
        let _ = cosine_similarity_simd(&a, &b);
        let _ = manhattan_distance_simd(&a, &b);
    });

    // Final verification after the stress loop.
    let result = euclidean_distance_simd(&a, &b);
    assert!(
        result.is_finite(),
        "Result should be finite after stress test"
    );
}
|
||||
555
crates/ruvector-core/tests/unit_tests.rs
Normal file
555
crates/ruvector-core/tests/unit_tests.rs
Normal file
@@ -0,0 +1,555 @@
|
||||
//! Unit tests with mocking using mockall (London School TDD)
|
||||
//!
|
||||
//! These tests use mocks to isolate components and test behavior in isolation.
|
||||
|
||||
use mockall::mock;
|
||||
use mockall::predicate::*;
|
||||
use ruvector_core::error::{Result, RuvectorError};
|
||||
use ruvector_core::types::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ============================================================================
|
||||
// Mock Definitions
|
||||
// ============================================================================
|
||||
|
||||
// Mock for storage operations
// mockall generates `MockStorage` from this block; it mirrors the persistence
// surface (insert/get/delete/len/all_ids) so higher layers can be tested
// without a real on-disk backend.
mock! {
    pub Storage {
        // Persist a single entry, returning the id it was stored under.
        fn insert(&self, entry: &VectorEntry) -> Result<VectorId>;
        // Persist many entries in one call; presumably ids come back in input
        // order — TODO confirm against the real implementation.
        fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>>;
        // Look up an entry by id; Ok(None) means "not found".
        fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
        // Remove an entry; the bool reports whether anything was deleted.
        fn delete(&self, id: &str) -> Result<bool>;
        // Number of stored entries.
        fn len(&self) -> Result<usize>;
        // Whether the store holds no entries.
        fn is_empty(&self) -> Result<bool>;
        // Every id currently stored.
        fn all_ids(&self) -> Result<Vec<VectorId>>;
    }
}
|
||||
|
||||
// Mock for index operations
// mockall generates `MockIndex` from this block; it mirrors the index surface
// (add/search/remove) the database layer drives.
mock! {
    pub Index {
        // Add one vector under the given id.
        fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()>;
        // Bulk-add (id, vector) pairs.
        fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()>;
        // Return up to `k` results for `query` — presumably nearest-first;
        // TODO confirm ordering contract with the concrete index.
        fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
        // Drop a vector; the bool reports whether it was present.
        fn remove(&mut self, id: &VectorId) -> Result<bool>;
        // Number of indexed vectors.
        fn len(&self) -> usize;
        // Whether the index holds no vectors.
        fn is_empty(&self) -> bool;
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Distance Metric Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod distance_tests {
    use super::*;
    use ruvector_core::distance::*;

    /// A vector has (numerically) zero Euclidean distance to itself.
    #[test]
    fn test_euclidean_same_vector() {
        let point = vec![1.0, 2.0, 3.0];
        assert!(
            euclidean_distance(&point, &point) < 0.001,
            "Distance to self should be ~0"
        );
    }

    /// Orthogonal unit vectors are sqrt(2) apart.
    #[test]
    fn test_euclidean_orthogonal() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];
        let dist = euclidean_distance(&x_axis, &y_axis);
        assert!((dist - 1.414).abs() < 0.01, "Expected sqrt(2), got {}", dist);
    }

    /// Scaling a vector does not change its direction, so cosine distance is ~0.
    #[test]
    fn test_cosine_parallel_vectors() {
        let base = vec![1.0, 2.0, 3.0];
        let scaled: Vec<f32> = base.iter().map(|x| x * 2.0).collect(); // parallel to base
        let dist = cosine_distance(&base, &scaled);
        assert!(
            dist < 0.01,
            "Parallel vectors should have ~0 cosine distance, got {}",
            dist
        );
    }

    /// Orthogonal vectors have a cosine distance of ~1.
    #[test]
    fn test_cosine_orthogonal() {
        let unit_x = vec![1.0, 0.0, 0.0];
        let unit_y = vec![0.0, 1.0, 0.0];
        let dist = cosine_distance(&unit_x, &unit_y);
        assert!(
            dist > 0.9 && dist < 1.1,
            "Orthogonal vectors should have distance ~1, got {}",
            dist
        );
    }

    /// Dot-product distance is negated similarity, so equal vectors score below 0.
    #[test]
    fn test_dot_product_positive() {
        let a = vec![1.0, 2.0, 3.0];
        let b = a.clone();
        let dist = dot_product_distance(&a, &b);
        assert!(
            dist < 0.0,
            "Dot product distance should be negative for similar vectors"
        );
    }

    /// The L1 distance between the origin and (1,1,1) is 3.
    #[test]
    fn test_manhattan_distance() {
        let origin = vec![0.0; 3];
        let ones = vec![1.0; 3];
        let dist = manhattan_distance(&origin, &ones);
        assert!(
            (dist - 3.0).abs() < 0.001,
            "Manhattan distance should be 3.0, got {}",
            dist
        );
    }

    /// Mixing 2-D and 3-D vectors must fail with `DimensionMismatch` carrying
    /// both sizes.
    #[test]
    fn test_dimension_mismatch() {
        let two_d = vec![1.0, 2.0];
        let three_d = vec![1.0, 2.0, 3.0];
        let result = distance(&two_d, &three_d, DistanceMetric::Euclidean);
        assert!(result.is_err(), "Should error on dimension mismatch");

        match result {
            Err(RuvectorError::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 2);
                assert_eq!(actual, 3);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Quantization Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod quantization_tests {
    use super::*;
    use ruvector_core::quantization::*;

    /// Scalar quantization round-trips within a small per-element error.
    #[test]
    fn test_scalar_quantization_reconstruction() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let reconstructed = ScalarQuantized::quantize(&original).reconstruct();

        assert_eq!(original.len(), reconstructed.len());
        for (orig, recon) in original.iter().zip(reconstructed.iter()) {
            let error = (orig - recon).abs();
            assert!(error < 0.1, "Reconstruction error {} too large", error);
        }
    }

    /// Distances between scalar-quantized vectors are well-defined (>= 0).
    #[test]
    fn test_scalar_quantization_distance() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 2.1, 3.1];

        let qa = ScalarQuantized::quantize(&a);
        let qb = ScalarQuantized::quantize(&b);

        assert!(qa.distance(&qb) >= 0.0, "Distance should be non-negative");
    }

    /// Binary quantization keeps exactly the sign of each component.
    #[test]
    fn test_binary_quantization_sign_preservation() {
        let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5, -0.5];
        let reconstructed = BinaryQuantized::quantize(&vector).reconstruct();

        for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
            assert_eq!(orig.signum(), *recon, "Sign should be preserved");
        }
    }

    /// Two sign flips out of four bits means Hamming distance 2.
    #[test]
    fn test_binary_quantization_hamming() {
        let a = vec![1.0, 1.0, 1.0, 1.0];
        let b = vec![1.0, 1.0, -1.0, -1.0];

        let dist = BinaryQuantized::quantize(&a).distance(&BinaryQuantized::quantize(&b));
        assert_eq!(dist, 2.0, "Hamming distance should be 2.0");
    }

    /// A code word is at Hamming distance 0 from itself.
    #[test]
    fn test_binary_quantization_zero_distance() {
        let vector = vec![1.0, 2.0, 3.0, 4.0];
        let quantized = BinaryQuantized::quantize(&vector);
        assert_eq!(
            quantized.distance(&quantized),
            0.0,
            "Distance to self should be 0"
        );
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Storage Layer Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod storage_tests {
    use super::*;
    use ruvector_core::storage::VectorStorage;
    use tempfile::tempdir;

    /// Open a fresh 3-dimensional storage backed by a temp directory.
    /// The `TempDir` guard is returned so the directory outlives the storage.
    fn open_storage() -> (tempfile::TempDir, VectorStorage) {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3).unwrap();
        (dir, storage)
    }

    /// Build a metadata-less `VectorEntry` from an optional id and a vector.
    fn entry(id: Option<&str>, vector: Vec<f32>) -> VectorEntry {
        VectorEntry {
            id: id.map(str::to_string),
            vector,
            metadata: None,
        }
    }

    /// A caller-supplied id is stored verbatim.
    #[test]
    fn test_insert_with_explicit_id() -> Result<()> {
        let (_dir, storage) = open_storage();
        let id = storage.insert(&entry(Some("explicit_id"), vec![1.0, 2.0, 3.0]))?;
        assert_eq!(id, "explicit_id");
        Ok(())
    }

    /// Omitting the id makes the storage mint a UUID.
    #[test]
    fn test_insert_auto_generates_id() -> Result<()> {
        let (_dir, storage) = open_storage();
        let id = storage.insert(&entry(None, vec![1.0, 2.0, 3.0]))?;
        assert!(!id.is_empty(), "Should generate a UUID");
        assert!(id.contains('-'), "Should be a valid UUID format");
        Ok(())
    }

    /// Metadata survives an insert/get round trip.
    #[test]
    fn test_insert_with_metadata() -> Result<()> {
        let (_dir, storage) = open_storage();

        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), serde_json::json!("value"));

        let mut stored = entry(Some("meta_test"), vec![1.0, 2.0, 3.0]);
        stored.metadata = Some(metadata.clone());

        storage.insert(&stored)?;
        let retrieved = storage.get("meta_test")?.unwrap();
        assert_eq!(retrieved.metadata, Some(metadata));
        Ok(())
    }

    /// Inserting a vector of the wrong dimensionality is rejected with a
    /// `DimensionMismatch` carrying both sizes.
    #[test]
    fn test_dimension_mismatch_error() {
        let (_dir, storage) = open_storage();

        let result = storage.insert(&entry(None, vec![1.0, 2.0])); // wrong dimension
        assert!(result.is_err());

        match result {
            Err(RuvectorError::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 3);
                assert_eq!(actual, 2);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    /// Fetching an unknown id yields `None`, not an error.
    #[test]
    fn test_get_nonexistent() -> Result<()> {
        let (_dir, storage) = open_storage();
        assert!(storage.get("nonexistent")?.is_none());
        Ok(())
    }

    /// Deleting an unknown id reports `false`, not an error.
    #[test]
    fn test_delete_nonexistent() -> Result<()> {
        let (_dir, storage) = open_storage();
        assert!(!storage.delete("nonexistent")?);
        Ok(())
    }

    /// An empty batch insert is a no-op that returns no ids.
    #[test]
    fn test_batch_insert_empty() -> Result<()> {
        let (_dir, storage) = open_storage();
        assert_eq!(storage.insert_batch(&[])?.len(), 0);
        Ok(())
    }

    /// One wrongly-sized entry fails the whole batch.
    #[test]
    fn test_batch_insert_dimension_mismatch() {
        let (_dir, storage) = open_storage();

        let entries = vec![
            entry(None, vec![1.0, 2.0, 3.0]),
            entry(None, vec![1.0, 2.0]), // wrong dimension
        ];

        assert!(
            storage.insert_batch(&entries).is_err(),
            "Should error on dimension mismatch in batch"
        );
    }

    /// A fresh database exposes no ids.
    #[test]
    fn test_all_ids_empty() -> Result<()> {
        let (_dir, storage) = open_storage();
        assert_eq!(storage.all_ids()?.len(), 0);
        Ok(())
    }

    /// `all_ids` reflects every inserted id (order-independent).
    #[test]
    fn test_all_ids_after_insert() -> Result<()> {
        let (_dir, storage) = open_storage();

        storage.insert(&entry(Some("id1"), vec![1.0, 2.0, 3.0]))?;
        storage.insert(&entry(Some("id2"), vec![4.0, 5.0, 6.0]))?;

        let ids = storage.all_ids()?;
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"id1".to_string()));
        assert!(ids.contains(&"id2".to_string()));
        Ok(())
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// VectorDB Tests (High-level API)
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
mod vector_db_tests {
    use super::*;
    use ruvector_core::VectorDB;
    use tempfile::tempdir;

    /// Build 3-dimensional `DbOptions` rooted in `dir`; `flat_index` disables
    /// HNSW so searches fall back to the flat (brute-force) index.
    fn options_in(dir: &tempfile::TempDir, flat_index: bool) -> DbOptions {
        let mut options = DbOptions::default();
        options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
        options.dimensions = 3;
        if flat_index {
            options.hnsw_config = None;
        }
        options
    }

    /// A brand-new database is empty.
    #[test]
    fn test_empty_database() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(options_in(&dir, false))?;

        assert!(db.is_empty()?);
        assert_eq!(db.len()?, 0);
        Ok(())
    }

    /// Inserting one vector bumps `len` to 1.
    #[test]
    fn test_insert_updates_len() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(options_in(&dir, false))?;

        db.insert(VectorEntry {
            id: None,
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        assert_eq!(db.len()?, 1);
        assert!(!db.is_empty()?);
        Ok(())
    }

    /// Deleting the only vector brings `len` back to 0.
    #[test]
    fn test_delete_updates_len() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(options_in(&dir, false))?;

        let id = db.insert(VectorEntry {
            id: Some("test_id".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;
        assert_eq!(db.len()?, 1);

        assert!(db.delete(&id)?);
        assert_eq!(db.len()?, 0);
        Ok(())
    }

    /// Searching an empty database returns no hits.
    #[test]
    fn test_search_empty_database() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(options_in(&dir, true))?;

        let results = db.search(SearchQuery {
            vector: vec![1.0, 2.0, 3.0],
            k: 10,
            filter: None,
            ef_search: None,
        })?;

        assert_eq!(results.len(), 0);
        Ok(())
    }

    /// A metadata filter restricts results to matching entries only.
    #[test]
    fn test_search_with_filter() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(options_in(&dir, true))?;

        // Two near-identical vectors, distinguished only by metadata category.
        let fixtures = vec![
            ("v1", vec![1.0, 0.0, 0.0], "A"),
            ("v2", vec![0.9, 0.1, 0.0], "B"),
        ];
        for (id, vector, category) in fixtures {
            let mut metadata = HashMap::new();
            metadata.insert("category".to_string(), serde_json::json!(category));
            db.insert(VectorEntry {
                id: Some(id.to_string()),
                vector,
                metadata: Some(metadata),
            })?;
        }

        let mut filter = HashMap::new();
        filter.insert("category".to_string(), serde_json::json!("A"));

        let results = db.search(SearchQuery {
            vector: vec![1.0, 0.0, 0.0],
            k: 10,
            filter: Some(filter),
            ef_search: None,
        })?;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "v1");
        Ok(())
    }

    /// Batch insert stores every entry and returns one id per entry.
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let db = VectorDB::new(options_in(&dir, false))?;

        let entries: Vec<VectorEntry> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]
        .into_iter()
        .map(|vector| VectorEntry {
            id: None,
            vector,
            metadata: None,
        })
        .collect();

        let ids = db.insert_batch(entries)?;
        assert_eq!(ids.len(), 3);
        assert_eq!(db.len()?, 3);
        Ok(())
    }
}
|
||||
Reference in New Issue
Block a user