Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
commit d803bfe2b1
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,101 @@
[package]
name = "ruvector-core"
version.workspace = true
edition.workspace = true
rust-version.workspace = true
license.workspace = true
authors.workspace = true
repository.workspace = true
readme = "README.md"
description = "High-performance Rust vector database core with HNSW indexing"
[dependencies]
# Core functionality
redb = { workspace = true, optional = true }
memmap2 = { workspace = true, optional = true }
hnsw_rs = { workspace = true, optional = true }
simsimd = { workspace = true, optional = true }
rayon = { workspace = true, optional = true }
crossbeam = { workspace = true, optional = true }
# Serialization
rkyv = { workspace = true }
bincode = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
# Error handling
thiserror = { workspace = true }
anyhow = { workspace = true }
tracing = { workspace = true }
# Math and numerics
ndarray = { workspace = true, features = ["serde"] }
rand = { workspace = true }
rand_distr = { workspace = true }
# Performance
dashmap = { workspace = true }
parking_lot = { workspace = true }
once_cell = { workspace = true }
# Time and UUID
chrono = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
# HTTP client for API embeddings (not available in WASM)
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"], optional = true }
[dev-dependencies]
criterion = { workspace = true }
proptest = { workspace = true }
mockall = { workspace = true }
tempfile = "3.13"
tracing-subscriber = { workspace = true }
[[bench]]
name = "distance_metrics"
harness = false
[[bench]]
name = "hnsw_search"
harness = false
[[bench]]
name = "quantization_bench"
harness = false
[[bench]]
name = "batch_operations"
harness = false
[[bench]]
name = "comprehensive_bench"
harness = false
[[bench]]
name = "real_benchmark"
harness = false
[[bench]]
name = "bench_simd"
harness = false
[[bench]]
name = "bench_memory"
harness = false
[features]
default = ["simd", "storage", "hnsw", "api-embeddings", "parallel"]
simd = ["simsimd"] # SIMD acceleration (not available in WASM)
parallel = ["rayon", "crossbeam"] # Parallel processing (not available in WASM)
storage = ["redb", "memmap2"] # File-based storage (not available in WASM)
hnsw = ["hnsw_rs"] # HNSW indexing (not available in WASM due to mmap dependency)
memory-only = [] # Pure in-memory storage for WASM
uuid-support = [] # Deprecated: uuid is now always included
real-embeddings = [] # Feature flag for embedding provider API (use ApiEmbedding for production)
api-embeddings = ["reqwest"] # API-based embeddings (not available in WASM)
[lib]
crate-type = ["rlib"]
bench = false

View File

@@ -0,0 +1,471 @@
# Ruvector Core
[![Crates.io](https://img.shields.io/crates/v/ruvector-core.svg)](https://crates.io/crates/ruvector-core)
[![Documentation](https://docs.rs/ruvector-core/badge.svg)](https://docs.rs/ruvector-core)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![Rust](https://img.shields.io/badge/rust-1.77%2B-orange.svg)](https://www.rust-lang.org)
**The pure-Rust vector database engine behind RuVector -- HNSW indexing, quantization, and SIMD acceleration in a single crate.**
`ruvector-core` is the foundational library that powers the entire [RuVector](https://github.com/ruvnet/ruvector) ecosystem. It gives you a production-grade vector database you can embed directly into any Rust application: insert vectors, search them in under a millisecond, filter by metadata, and compress storage up to 32x -- all without external services. If you need vector search as a library instead of a server, this is the crate.
| | ruvector-core | Typical Vector Database |
|---|---|---|
| **Deployment** | Embed as a Rust dependency -- no server, no network calls | Run a separate service, manage connections |
| **Query latency** | <0.5 ms p50 at 1M vectors with HNSW | ~1-5 ms depending on network and index |
| **Memory compression** | Scalar (4x), Product (8-32x), Binary (32x) quantization built in | Often requires paid tiers or external tools |
| **SIMD acceleration** | SimSIMD hardware-optimized distance calculations, automatic | Manual tuning or not available |
| **Search modes** | Dense vectors, sparse BM25, hybrid, MMR diversity, filtered -- all in one API | Typically dense-only; hybrid and filtering are add-ons |
| **Storage** | Zero-copy mmap with `redb` -- instant loading, no deserialization | Load time scales with dataset size |
| **Concurrency** | Lock-free indexing with parallel batch processing via Rayon | Varies; many require single-writer locks |
| **Dependencies** | Minimal -- pure Rust, compiles anywhere `rustc` runs | Often depends on C/C++ libraries (BLAS, LAPACK) |
| **Cost** | Free forever -- open source (MIT) | Per-vector or per-query pricing on managed tiers |
## Installation
Add `ruvector-core` to your `Cargo.toml`:
```toml
[dependencies]
ruvector-core = "0.1.0"
```
### Feature Flags
```toml
[dependencies]
ruvector-core = { version = "0.1.0", features = ["simd", "uuid-support"] }
```
Available features:
- `simd` (default): Enable SIMD-optimized distance calculations
- `uuid-support` (default): Enable UUID generation for vector IDs
## Key Features
| Feature | What It Does | Why It Matters |
|---------|-------------|----------------|
| **HNSW Indexing** | Hierarchical Navigable Small World graphs for O(log n) approximate nearest neighbor search | Sub-millisecond queries at million-vector scale |
| **Multiple Distance Metrics** | Euclidean, Cosine, Dot Product, Manhattan | Match the metric to your embedding model without conversion |
| **Scalar Quantization** | Compress vectors to 8-bit integers (4x reduction) | Cut memory by 75% with 98% recall preserved |
| **Product Quantization** | Split vectors into subspaces with codebooks (8-32x reduction) | Store millions of vectors on a single machine |
| **Binary Quantization** | 1-bit representation (32x reduction) | Ultra-fast screening pass for massive datasets |
| **SIMD Distance** | Hardware-accelerated distance via SimSIMD | Up to 80K QPS on 8 cores without code changes |
| **Zero-Copy I/O** | Memory-mapped storage loads instantly | No deserialization step -- open a file and search immediately |
| **Hybrid Search** | Combine dense vector similarity with sparse BM25 text scoring | One query handles both semantic and keyword matching |
| **Metadata Filtering** | Apply key-value filters during search | No post-filtering needed -- results are already filtered |
| **MMR Diversification** | Maximal Marginal Relevance re-ranking | Avoid redundant results when top-K are too similar |
| **Conformal Prediction** | Uncertainty quantification on search results | Know when to trust (or distrust) a match |
| **Lock-Free Indexing** | Concurrent reads and writes without blocking | High-throughput ingestion while serving queries |
| **Batch Processing** | Parallel insert and search via Rayon | Saturate all cores for bulk operations |
## Quick Start
### Basic Usage
```rust
use ruvector_core::{VectorDB, DbOptions, VectorEntry, SearchQuery, DistanceMetric};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create a new vector database
let mut options = DbOptions::default();
options.dimensions = 384; // Vector dimensions
options.storage_path = "./my_vectors.db".to_string();
options.distance_metric = DistanceMetric::Cosine;
let db = VectorDB::new(options)?;
// Insert vectors
db.insert(VectorEntry {
id: Some("doc1".to_string()),
vector: vec![0.1, 0.2, 0.3, /* ... 384 dimensions */],
metadata: None,
})?;
db.insert(VectorEntry {
id: Some("doc2".to_string()),
vector: vec![0.4, 0.5, 0.6, /* ... 384 dimensions */],
metadata: None,
})?;
// Search for similar vectors
let results = db.search(SearchQuery {
vector: vec![0.1, 0.2, 0.3, /* ... 384 dimensions */],
k: 10, // Return top 10 results
filter: None,
ef_search: None,
})?;
for result in results {
println!("ID: {}, Score: {}", result.id, result.score);
}
Ok(())
}
```
### Batch Operations
```rust
use ruvector_core::{VectorDB, VectorEntry};
// Insert multiple vectors efficiently
let entries = vec![
VectorEntry {
id: Some("doc1".to_string()),
vector: vec![0.1, 0.2, 0.3],
metadata: None,
},
VectorEntry {
id: Some("doc2".to_string()),
vector: vec![0.4, 0.5, 0.6],
metadata: None,
},
];
let ids = db.insert_batch(entries)?;
println!("Inserted {} vectors", ids.len());
```
### With Metadata Filtering
```rust
use std::collections::HashMap;
use serde_json::json;
// Insert with metadata
db.insert(VectorEntry {
id: Some("product1".to_string()),
vector: vec![0.1, 0.2, 0.3],
metadata: Some(HashMap::from([
("category".to_string(), json!("electronics")),
("price".to_string(), json!(299.99)),
])),
})?;
// Search with metadata filter
let results = db.search(SearchQuery {
vector: vec![0.1, 0.2, 0.3],
k: 10,
filter: Some(HashMap::from([
("category".to_string(), json!("electronics")),
])),
ef_search: None,
})?;
```
### HNSW Configuration
```rust
use ruvector_core::{DbOptions, HnswConfig, DistanceMetric};
let mut options = DbOptions::default();
options.dimensions = 384;
options.distance_metric = DistanceMetric::Cosine;
// Configure HNSW index parameters
options.hnsw_config = Some(HnswConfig {
m: 32, // Connections per layer (16-64 typical)
ef_construction: 200, // Build-time accuracy (100-500 typical)
ef_search: 100, // Search-time accuracy (50-200 typical)
max_elements: 10_000_000, // Maximum vectors
});
let db = VectorDB::new(options)?;
```
### Quantization
```rust
use ruvector_core::{DbOptions, QuantizationConfig};
let mut options = DbOptions::default();
options.dimensions = 384;
// Enable scalar quantization (4x compression)
options.quantization = Some(QuantizationConfig::Scalar);
// Or product quantization (8-32x compression)
options.quantization = Some(QuantizationConfig::Product {
subspaces: 8, // Number of subspaces
k: 256, // Codebook size
});
let db = VectorDB::new(options)?;
```
## API Overview
### Core Types
```rust
// Main database interface
pub struct VectorDB { /* ... */ }
// Vector entry with optional ID and metadata
pub struct VectorEntry {
pub id: Option<VectorId>,
pub vector: Vec<f32>,
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
// Search query parameters
pub struct SearchQuery {
pub vector: Vec<f32>,
pub k: usize,
pub filter: Option<HashMap<String, serde_json::Value>>,
pub ef_search: Option<usize>,
}
// Search result with score
pub struct SearchResult {
pub id: VectorId,
pub score: f32,
pub vector: Option<Vec<f32>>,
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
```
### Main Operations
```rust
impl VectorDB {
// Create new database with options
pub fn new(options: DbOptions) -> Result<Self>;
// Create with just dimensions (uses defaults)
pub fn with_dimensions(dimensions: usize) -> Result<Self>;
// Insert single vector
pub fn insert(&self, entry: VectorEntry) -> Result<VectorId>;
// Insert multiple vectors
pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>>;
// Search for similar vectors
pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>>;
// Delete vector by ID
pub fn delete(&self, id: &str) -> Result<bool>;
// Get vector by ID
pub fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
// Get total count
pub fn len(&self) -> Result<usize>;
// Check if empty
pub fn is_empty(&self) -> Result<bool>;
}
```
### Distance Metrics
```rust
pub enum DistanceMetric {
Euclidean, // L2 distance - default for embeddings
Cosine, // Cosine similarity (1 - similarity)
DotProduct, // Negative dot product (for maximization)
Manhattan, // L1 distance
}
```
### Advanced Features
```rust
// Hybrid search (dense + sparse)
use ruvector_core::{HybridSearch, HybridConfig};
let hybrid = HybridSearch::new(HybridConfig {
alpha: 0.7, // Balance between dense (0.7) and sparse (0.3)
..Default::default()
});
// Filtered search with expressions
use ruvector_core::{FilteredSearch, FilterExpression};
let filtered = FilteredSearch::new(db);
let expr = FilterExpression::And(vec![
FilterExpression::Equals("category".to_string(), json!("books")),
FilterExpression::GreaterThan("price".to_string(), json!(10.0)),
]);
// MMR diversification
use ruvector_core::{MMRSearch, MMRConfig};
let mmr = MMRSearch::new(MMRConfig {
lambda: 0.5, // Balance relevance (0.5) and diversity (0.5)
..Default::default()
});
```
## Performance
### Latency (Single Query)
```
Operation Flat Index HNSW Index
---------------------------------------------
Search (1K vecs) ~0.1ms ~0.2ms
Search (100K vecs) ~10ms ~0.5ms
Search (1M vecs) ~100ms <1ms
Insert ~0.1ms ~1ms
Batch (1000) ~50ms ~500ms
```
### Memory Usage (1M Vectors, 384 Dimensions)
```
Configuration Memory Recall
---------------------------------------------
Full Precision (f32) ~1.5GB 100%
Scalar Quantization ~400MB 98%
Product Quantization ~200MB 95%
Binary Quantization ~50MB 85%
```
### Throughput (Queries Per Second)
```
Configuration QPS Latency (p50)
-----------------------------------------------------
Single Thread ~2,000 ~0.5ms
Multi-Thread (8 cores) ~50,000 <0.5ms
With SIMD ~80,000 <0.3ms
With Quantization ~100,000 <0.2ms
```
## Configuration Guide
### For Maximum Accuracy
```rust
let options = DbOptions {
dimensions: 384,
distance_metric: DistanceMetric::Cosine,
hnsw_config: Some(HnswConfig {
m: 64,
ef_construction: 500,
ef_search: 200,
max_elements: 10_000_000,
}),
quantization: None, // Full precision
..Default::default()
};
```
### For Maximum Speed
```rust
let options = DbOptions {
dimensions: 384,
distance_metric: DistanceMetric::DotProduct,
hnsw_config: Some(HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 50,
max_elements: 10_000_000,
}),
quantization: Some(QuantizationConfig::Binary),
..Default::default()
};
```
### For Balanced Performance
```rust
let options = DbOptions::default(); // Recommended defaults
```
## Building and Testing
### Build
```bash
# Build with default features
cargo build --release
# Build without SIMD
cargo build --release --no-default-features --features uuid-support
# Build for specific target with optimizations
RUSTFLAGS="-C target-cpu=native" cargo build --release
```
### Testing
```bash
# Run all tests
cargo test
# Run with specific features
cargo test --features simd
# Run with logging
RUST_LOG=debug cargo test
```
### Benchmarks
```bash
# Run all benchmarks
cargo bench
# Run specific benchmark
cargo bench --bench hnsw_search
# Run with features
cargo bench --features simd
```
Available benchmarks:
- `distance_metrics` - SIMD-optimized distance calculations
- `hnsw_search` - HNSW index search performance
- `quantization_bench` - Quantization techniques
- `batch_operations` - Batch insert/search operations
- `comprehensive_bench` - Full system benchmarks
## Related Crates
`ruvector-core` is the foundation for platform-specific bindings:
- **[ruvector-node](../ruvector-node/)** - Node.js bindings via NAPI-RS
- **[ruvector-wasm](../ruvector-wasm/)** - WebAssembly bindings for browsers
- **[ruvector-gnn](../ruvector-gnn/)** - Graph Neural Network layer for learned search
- **[ruvector-cli](../ruvector-cli/)** - Command-line interface
- **[ruvector-bench](../ruvector-bench/)** - Performance benchmarks
## Documentation
- **[Main README](../../README.md)** - Complete project overview
- **[Getting Started Guide](../../docs/guide/GETTING_STARTED.md)** - Quick start tutorial
- **[Rust API Reference](../../docs/api/RUST_API.md)** - Detailed API documentation
- **[Advanced Features Guide](../../docs/guide/ADVANCED_FEATURES.md)** - Quantization, indexing, tuning
- **[Performance Tuning](../../docs/optimization/PERFORMANCE_TUNING_GUIDE.md)** - Optimization strategies
- **[API Documentation](https://docs.rs/ruvector-core)** - Full API reference on docs.rs
## Acknowledgments
Built with state-of-the-art algorithms and libraries:
- **[hnsw_rs](https://crates.io/crates/hnsw_rs)** - HNSW implementation
- **[simsimd](https://crates.io/crates/simsimd)** - SIMD distance calculations
- **[redb](https://crates.io/crates/redb)** - Embedded database
- **[rayon](https://crates.io/crates/rayon)** - Data parallelism
- **[memmap2](https://crates.io/crates/memmap2)** - Memory-mapped files
## License
**MIT License** - see [LICENSE](../../LICENSE) for details.
---
<div align="center">
**Part of [RuVector](https://github.com/ruvnet/ruvector) - Built by [rUv](https://ruv.io)**
[![Star on GitHub](https://img.shields.io/github/stars/ruvnet/ruvector?style=social)](https://github.com/ruvnet/ruvector)
[Documentation](https://docs.rs/ruvector-core) | [Crates.io](https://crates.io/crates/ruvector-core) | [GitHub](https://github.com/ruvnet/ruvector)
</div>

View File

@@ -0,0 +1,204 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use ruvector_core::types::{DistanceMetric, SearchQuery};
use ruvector_core::{DbOptions, VectorDB, VectorEntry};
use tempfile::tempdir;
fn bench_batch_insert(c: &mut Criterion) {
let mut group = c.benchmark_group("batch_insert");
for batch_size in [100, 1000, 10000].iter() {
group.bench_with_input(
BenchmarkId::from_parameter(batch_size),
batch_size,
|bench, &size| {
bench.iter_batched(
|| {
// Setup: Create DB and vectors
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path =
dir.path().join("bench.db").to_string_lossy().to_string();
options.dimensions = 128;
options.hnsw_config = None; // Use flat index for faster insertion
let db = VectorDB::new(options).unwrap();
let vectors: Vec<VectorEntry> = (0..size)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..128).map(|j| ((i + j) as f32) * 0.01).collect(),
metadata: None,
})
.collect();
(db, vectors, dir)
},
|(db, vectors, _dir)| {
// Benchmark: Batch insert
db.insert_batch(black_box(vectors)).unwrap()
},
criterion::BatchSize::LargeInput,
);
},
);
}
group.finish();
}
fn bench_individual_insert_vs_batch(c: &mut Criterion) {
let mut group = c.benchmark_group("individual_vs_batch_insert");
let size = 1000;
// Individual inserts
group.bench_function("individual_1000", |bench| {
bench.iter_batched(
|| {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("bench.db").to_string_lossy().to_string();
options.dimensions = 64;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
let vectors: Vec<VectorEntry> = (0..size)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32; 64],
metadata: None,
})
.collect();
(db, vectors, dir)
},
|(db, vectors, _dir)| {
for vector in vectors {
db.insert(black_box(vector)).unwrap();
}
},
criterion::BatchSize::LargeInput,
);
});
// Batch insert
group.bench_function("batch_1000", |bench| {
bench.iter_batched(
|| {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("bench.db").to_string_lossy().to_string();
options.dimensions = 64;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
let vectors: Vec<VectorEntry> = (0..size)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32; 64],
metadata: None,
})
.collect();
(db, vectors, dir)
},
|(db, vectors, _dir)| db.insert_batch(black_box(vectors)).unwrap(),
criterion::BatchSize::LargeInput,
);
});
group.finish();
}
fn bench_parallel_searches(c: &mut Criterion) {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("search_bench.db")
.to_string_lossy()
.to_string();
options.dimensions = 128;
options.distance_metric = DistanceMetric::Cosine;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
// Insert test data
let vectors: Vec<VectorEntry> = (0..1000)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..128).map(|j| ((i + j) as f32) * 0.01).collect(),
metadata: None,
})
.collect();
db.insert_batch(vectors).unwrap();
// Benchmark multiple sequential searches
c.bench_function("sequential_searches_100", |bench| {
bench.iter(|| {
for i in 0..100 {
let query: Vec<f32> = (0..128).map(|j| ((i + j) as f32) * 0.01).collect();
let _ = db
.search(SearchQuery {
vector: black_box(query),
k: 10,
filter: None,
ef_search: None,
})
.unwrap();
}
});
});
}
fn bench_batch_delete(c: &mut Criterion) {
let mut group = c.benchmark_group("batch_delete");
for size in [100, 1000].iter() {
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, &size| {
bench.iter_batched(
|| {
// Setup: Create DB with vectors
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path =
dir.path().join("bench.db").to_string_lossy().to_string();
options.dimensions = 32;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
let vectors: Vec<VectorEntry> = (0..size)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32; 32],
metadata: None,
})
.collect();
let ids = db.insert_batch(vectors).unwrap();
(db, ids, dir)
},
|(db, ids, _dir)| {
// Benchmark: Delete all
for id in ids {
db.delete(black_box(&id)).unwrap();
}
},
criterion::BatchSize::LargeInput,
);
});
}
group.finish();
}
criterion_group!(
benches,
bench_batch_insert,
bench_individual_insert_vs_batch,
bench_parallel_searches,
bench_batch_delete
);
criterion_main!(benches);

View File

@@ -0,0 +1,474 @@
//! Memory Allocation and Pool Benchmarks
//!
//! This module benchmarks arena allocation, cache-optimized storage,
//! and memory access patterns.
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruvector_core::arena::Arena;
use ruvector_core::cache_optimized::SoAVectorStorage;
// ============================================================================
// Arena Allocation Benchmarks
// ============================================================================
fn bench_arena_allocation(c: &mut Criterion) {
let mut group = c.benchmark_group("arena_allocation");
for count in [10, 100, 1000, 10000] {
group.throughput(Throughput::Elements(count));
// Benchmark arena allocation
group.bench_with_input(BenchmarkId::new("arena", count), &count, |bench, &count| {
bench.iter(|| {
let arena = Arena::new(1024 * 1024);
for _ in 0..count {
let _vec = arena.alloc_vec::<f32>(black_box(64));
}
});
});
// Compare with standard Vec allocation
group.bench_with_input(
BenchmarkId::new("std_vec", count),
&count,
|bench, &count| {
bench.iter(|| {
let mut vecs = Vec::with_capacity(count as usize);
for _ in 0..count {
vecs.push(Vec::<f32>::with_capacity(black_box(64)));
}
vecs
});
},
);
}
group.finish();
}
fn bench_arena_allocation_sizes(c: &mut Criterion) {
let mut group = c.benchmark_group("arena_allocation_sizes");
for size in [8, 32, 64, 128, 256, 512, 1024, 4096] {
group.throughput(Throughput::Bytes(size as u64 * 4)); // f32 = 4 bytes
group.bench_with_input(BenchmarkId::new("alloc", size), &size, |bench, &size| {
bench.iter(|| {
let arena = Arena::new(1024 * 1024);
for _ in 0..1000 {
let _vec = arena.alloc_vec::<f32>(black_box(size));
}
});
});
}
group.finish();
}
fn bench_arena_reset_reuse(c: &mut Criterion) {
let mut group = c.benchmark_group("arena_reset_reuse");
for iterations in [10, 100, 1000] {
group.bench_with_input(
BenchmarkId::new("with_reset", iterations),
&iterations,
|bench, &iterations| {
bench.iter(|| {
let arena = Arena::new(1024 * 1024);
for _ in 0..iterations {
// Allocate
for _ in 0..100 {
let _vec = arena.alloc_vec::<f32>(64);
}
// Reset for reuse
arena.reset();
}
});
},
);
group.bench_with_input(
BenchmarkId::new("without_reset", iterations),
&iterations,
|bench, &iterations| {
bench.iter(|| {
for _ in 0..iterations {
let arena = Arena::new(1024 * 1024);
for _ in 0..100 {
let _vec = arena.alloc_vec::<f32>(64);
}
// No reset, create new arena each time
}
});
},
);
}
group.finish();
}
fn bench_arena_push_operations(c: &mut Criterion) {
let mut group = c.benchmark_group("arena_push");
for count in [100, 1000, 10000] {
group.throughput(Throughput::Elements(count));
// Arena push
group.bench_with_input(BenchmarkId::new("arena", count), &count, |bench, &count| {
bench.iter(|| {
let arena = Arena::new(1024 * 1024);
let mut vec = arena.alloc_vec::<f32>(count as usize);
for i in 0..count {
vec.push(black_box(i as f32));
}
vec
});
});
// Standard Vec push
group.bench_with_input(
BenchmarkId::new("std_vec", count),
&count,
|bench, &count| {
bench.iter(|| {
let mut vec = Vec::with_capacity(count as usize);
for i in 0..count {
vec.push(black_box(i as f32));
}
vec
});
},
);
}
group.finish();
}
// ============================================================================
// SoA Vector Storage Benchmarks
// ============================================================================
fn bench_soa_storage_push(c: &mut Criterion) {
let mut group = c.benchmark_group("soa_storage_push");
for dim in [64, 128, 256, 384, 512, 768] {
let vector: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01).collect();
group.throughput(Throughput::Elements(dim as u64));
group.bench_with_input(BenchmarkId::new("soa", dim), &dim, |bench, _| {
bench.iter(|| {
let mut storage = SoAVectorStorage::new(dim, 128);
for _ in 0..1000 {
storage.push(black_box(&vector));
}
storage
});
});
// Compare with Vec<Vec<f32>>
group.bench_with_input(BenchmarkId::new("vec_of_vec", dim), &dim, |bench, _| {
bench.iter(|| {
let mut storage: Vec<Vec<f32>> = Vec::with_capacity(1000);
for _ in 0..1000 {
storage.push(black_box(vector.clone()));
}
storage
});
});
}
group.finish();
}
fn bench_soa_storage_get(c: &mut Criterion) {
let mut group = c.benchmark_group("soa_storage_get");
for dim in [128, 384, 768] {
let mut storage = SoAVectorStorage::new(dim, 128);
for i in 0..10000 {
let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.001).collect();
storage.push(&vector);
}
let mut output = vec![0.0_f32; dim];
group.bench_with_input(BenchmarkId::new("sequential", dim), &dim, |bench, _| {
bench.iter(|| {
for i in 0..10000 {
storage.get(black_box(i), &mut output);
}
});
});
group.bench_with_input(BenchmarkId::new("random", dim), &dim, |bench, _| {
let indices: Vec<usize> = (0..10000).map(|i| (i * 37 + 13) % 10000).collect();
bench.iter(|| {
for &idx in &indices {
storage.get(black_box(idx), &mut output);
}
});
});
}
group.finish();
}
fn bench_soa_dimension_slice(c: &mut Criterion) {
let mut group = c.benchmark_group("soa_dimension_slice");
for dim in [64, 128, 256, 512] {
let mut storage = SoAVectorStorage::new(dim, 128);
for i in 0..10000 {
let vector: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 * 0.001).collect();
storage.push(&vector);
}
group.bench_with_input(
BenchmarkId::new("access_all_dims", dim),
&dim,
|bench, &dim| {
bench.iter(|| {
let mut sum = 0.0_f32;
for d in 0..dim {
let slice = storage.dimension_slice(black_box(d));
sum += slice.iter().sum::<f32>();
}
sum
});
},
);
group.bench_with_input(
BenchmarkId::new("access_single_dim", dim),
&dim,
|bench, _| {
bench.iter(|| {
let slice = storage.dimension_slice(black_box(0));
slice.iter().sum::<f32>()
});
},
);
}
group.finish();
}
fn bench_soa_batch_distances(c: &mut Criterion) {
let mut group = c.benchmark_group("soa_batch_distances");
for (dim, count) in [
(128, 1000),
(384, 1000),
(768, 1000),
(128, 10000),
(384, 5000),
] {
let mut storage = SoAVectorStorage::new(dim, 128);
for i in 0..count {
let vector: Vec<f32> = (0..dim)
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
.collect();
storage.push(&vector);
}
let query: Vec<f32> = (0..dim).map(|j| j as f32 * 0.002).collect();
let mut distances = vec![0.0_f32; count];
group.throughput(Throughput::Elements(count as u64));
group.bench_with_input(
BenchmarkId::new(format!("{}d_x{}", dim, count), dim),
&dim,
|bench, _| {
bench.iter(|| {
storage.batch_euclidean_distances(black_box(&query), &mut distances);
});
},
);
}
group.finish();
}
// ============================================================================
// Memory Access Pattern Benchmarks
// ============================================================================
fn bench_memory_layout_comparison(c: &mut Criterion) {
let mut group = c.benchmark_group("memory_layout");
let dim = 384;
let count = 10000;
// SoA layout
let mut soa_storage = SoAVectorStorage::new(dim, 128);
for i in 0..count {
let vector: Vec<f32> = (0..dim)
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
.collect();
soa_storage.push(&vector);
}
// AoS layout (Vec<Vec<f32>>)
let aos_storage: Vec<Vec<f32>> = (0..count)
.map(|i| {
(0..dim)
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
.collect()
})
.collect();
let query: Vec<f32> = (0..dim).map(|j| j as f32 * 0.002).collect();
let mut soa_distances = vec![0.0_f32; count];
group.bench_function("soa_batch_euclidean", |bench| {
bench.iter(|| {
soa_storage.batch_euclidean_distances(black_box(&query), &mut soa_distances);
});
});
group.bench_function("aos_naive_euclidean", |bench| {
bench.iter(|| {
let distances: Vec<f32> = aos_storage
.iter()
.map(|v| {
query
.iter()
.zip(v.iter())
.map(|(a, b)| (a - b) * (a - b))
.sum::<f32>()
.sqrt()
})
.collect();
distances
});
});
group.finish();
}
fn bench_cache_efficiency(c: &mut Criterion) {
let mut group = c.benchmark_group("cache_efficiency");
let dim = 512;
// Test with different vector counts to observe cache effects
for count in [100, 1000, 10000, 50000] {
let mut storage = SoAVectorStorage::new(dim, 128);
for i in 0..count {
let vector: Vec<f32> = (0..dim)
.map(|j| ((i * dim + j) % 1000) as f32 * 0.001)
.collect();
storage.push(&vector);
}
let query: Vec<f32> = (0..dim).map(|j| j as f32 * 0.001).collect();
let mut distances = vec![0.0_f32; count];
group.throughput(Throughput::Elements(count as u64));
group.bench_with_input(
BenchmarkId::new("batch_distance", count),
&count,
|bench, _| {
bench.iter(|| {
storage.batch_euclidean_distances(black_box(&query), &mut distances);
});
},
);
}
group.finish();
}
// ============================================================================
// Growth and Reallocation Benchmarks
// ============================================================================
fn bench_soa_growth(c: &mut Criterion) {
let mut group = c.benchmark_group("soa_growth");
// Test growth from small initial capacity
group.bench_function("grow_from_small", |bench| {
bench.iter(|| {
let mut storage = SoAVectorStorage::new(128, 4); // Very small initial
for i in 0..10000 {
let vector: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.001).collect();
storage.push(black_box(&vector));
}
storage
});
});
// Test with pre-allocated capacity
group.bench_function("preallocated", |bench| {
bench.iter(|| {
let mut storage = SoAVectorStorage::new(128, 16384); // Pre-allocate
for i in 0..10000 {
let vector: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.001).collect();
storage.push(black_box(&vector));
}
storage
});
});
group.finish();
}
// ============================================================================
// Mixed Type Allocation Benchmarks
// ============================================================================
fn bench_arena_mixed_types(c: &mut Criterion) {
let mut group = c.benchmark_group("arena_mixed_types");
group.bench_function("mixed_allocations", |bench| {
bench.iter(|| {
let arena = Arena::new(1024 * 1024);
for _ in 0..100 {
let _f32_vec = arena.alloc_vec::<f32>(black_box(64));
let _f64_vec = arena.alloc_vec::<f64>(black_box(32));
let _u32_vec = arena.alloc_vec::<u32>(black_box(128));
let _u8_vec = arena.alloc_vec::<u8>(black_box(256));
}
});
});
group.bench_function("uniform_allocations", |bench| {
bench.iter(|| {
let arena = Arena::new(1024 * 1024);
for _ in 0..400 {
let _f32_vec = arena.alloc_vec::<f32>(black_box(64));
}
});
});
group.finish();
}
// ============================================================================
// Criterion Groups
// ============================================================================
criterion_group!(
benches,
bench_arena_allocation,
bench_arena_allocation_sizes,
bench_arena_reset_reuse,
bench_arena_push_operations,
bench_soa_storage_push,
bench_soa_storage_get,
bench_soa_dimension_slice,
bench_soa_batch_distances,
bench_memory_layout_comparison,
bench_cache_efficiency,
bench_soa_growth,
bench_arena_mixed_types,
);
criterion_main!(benches);

View File

@@ -0,0 +1,335 @@
//! SIMD Performance Benchmarks
//!
//! This module benchmarks SIMD-optimized distance calculations
//! across various vector dimensions and operation types.
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruvector_core::simd_intrinsics::*;
// ============================================================================
// Helper Functions
// ============================================================================
fn generate_vectors(dim: usize) -> (Vec<f32>, Vec<f32>) {
let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
let b: Vec<f32> = (0..dim).map(|i| ((i + 100) as f32) * 0.01).collect();
(a, b)
}
fn generate_batch_vectors(dim: usize, count: usize) -> (Vec<f32>, Vec<Vec<f32>>) {
let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
let vectors: Vec<Vec<f32>> = (0..count)
.map(|j| (0..dim).map(|i| ((i + j * 10) as f32) * 0.01).collect())
.collect();
(query, vectors)
}
// ============================================================================
// Euclidean Distance Benchmarks
// ============================================================================
fn bench_euclidean_by_dimension(c: &mut Criterion) {
let mut group = c.benchmark_group("euclidean_by_dimension");
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
let (a, b) = generate_vectors(dim);
group.throughput(Throughput::Elements(dim as u64));
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
});
}
group.finish();
}
fn bench_euclidean_small_vectors(c: &mut Criterion) {
let mut group = c.benchmark_group("euclidean_small_vectors");
// Test small vector sizes that may not benefit from SIMD
for dim in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16] {
let (a, b) = generate_vectors(dim);
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
});
}
group.finish();
}
fn bench_euclidean_non_aligned(c: &mut Criterion) {
let mut group = c.benchmark_group("euclidean_non_aligned");
// Test non-SIMD-aligned sizes
for dim in [31, 33, 63, 65, 127, 129, 255, 257, 383, 385, 511, 513] {
let (a, b) = generate_vectors(dim);
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
});
}
group.finish();
}
// ============================================================================
// Dot Product Benchmarks
// ============================================================================
fn bench_dot_product_by_dimension(c: &mut Criterion) {
let mut group = c.benchmark_group("dot_product_by_dimension");
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
let (a, b) = generate_vectors(dim);
group.throughput(Throughput::Elements(dim as u64));
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
bench.iter(|| dot_product_simd(black_box(&a), black_box(&b)));
});
}
group.finish();
}
fn bench_dot_product_common_embeddings(c: &mut Criterion) {
let mut group = c.benchmark_group("dot_product_common_embeddings");
// Common embedding model dimensions
let dims = [
(128, "small"),
(384, "all-MiniLM-L6"),
(512, "e5-small"),
(768, "all-mpnet-base"),
(1024, "e5-large"),
(1536, "text-embedding-ada-002"),
(2048, "llama-7b-hidden"),
];
for (dim, name) in dims {
let (a, b) = generate_vectors(dim);
group.bench_with_input(BenchmarkId::new(name, dim), &dim, |bench, _| {
bench.iter(|| dot_product_simd(black_box(&a), black_box(&b)));
});
}
group.finish();
}
// ============================================================================
// Cosine Similarity Benchmarks
// ============================================================================
fn bench_cosine_by_dimension(c: &mut Criterion) {
let mut group = c.benchmark_group("cosine_by_dimension");
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
let (a, b) = generate_vectors(dim);
group.throughput(Throughput::Elements(dim as u64));
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
bench.iter(|| cosine_similarity_simd(black_box(&a), black_box(&b)));
});
}
group.finish();
}
// ============================================================================
// Manhattan Distance Benchmarks
// ============================================================================
fn bench_manhattan_by_dimension(c: &mut Criterion) {
let mut group = c.benchmark_group("manhattan_by_dimension");
for dim in [32, 64, 128, 256, 384, 512, 768, 1024, 1536, 2048] {
let (a, b) = generate_vectors(dim);
group.throughput(Throughput::Elements(dim as u64));
group.bench_with_input(BenchmarkId::new("simd", dim), &dim, |bench, _| {
bench.iter(|| manhattan_distance_simd(black_box(&a), black_box(&b)));
});
}
group.finish();
}
// ============================================================================
// Batch Operations Benchmarks
// ============================================================================
fn bench_batch_euclidean(c: &mut Criterion) {
let mut group = c.benchmark_group("batch_euclidean");
for count in [10, 100, 1000, 10000] {
let (query, vectors) = generate_batch_vectors(384, count);
group.throughput(Throughput::Elements(count as u64));
group.bench_with_input(BenchmarkId::new("384d", count), &count, |bench, _| {
bench.iter(|| {
for v in &vectors {
euclidean_distance_simd(black_box(&query), black_box(v));
}
});
});
}
group.finish();
}
fn bench_batch_dot_product(c: &mut Criterion) {
let mut group = c.benchmark_group("batch_dot_product");
for count in [10, 100, 1000, 10000] {
let (query, vectors) = generate_batch_vectors(768, count);
group.throughput(Throughput::Elements(count as u64));
group.bench_with_input(BenchmarkId::new("768d", count), &count, |bench, _| {
bench.iter(|| {
for v in &vectors {
dot_product_simd(black_box(&query), black_box(v));
}
});
});
}
group.finish();
}
// ============================================================================
// Comparison Benchmarks (All Metrics)
// ============================================================================
fn bench_all_metrics_comparison(c: &mut Criterion) {
let mut group = c.benchmark_group("metrics_comparison");
let dim = 384; // Common embedding dimension
let (a, b) = generate_vectors(dim);
group.bench_function("euclidean", |bench| {
bench.iter(|| euclidean_distance_simd(black_box(&a), black_box(&b)));
});
group.bench_function("dot_product", |bench| {
bench.iter(|| dot_product_simd(black_box(&a), black_box(&b)));
});
group.bench_function("cosine", |bench| {
bench.iter(|| cosine_similarity_simd(black_box(&a), black_box(&b)));
});
group.bench_function("manhattan", |bench| {
bench.iter(|| manhattan_distance_simd(black_box(&a), black_box(&b)));
});
group.finish();
}
// ============================================================================
// Memory Access Pattern Benchmarks
// ============================================================================
fn bench_sequential_vs_random_access(c: &mut Criterion) {
let mut group = c.benchmark_group("access_patterns");
let dim = 512;
let count = 1000;
// Generate vectors
let vectors: Vec<Vec<f32>> = (0..count)
.map(|j| (0..dim).map(|i| ((i + j * 10) as f32) * 0.01).collect())
.collect();
let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
// Sequential access indices
let sequential_indices: Vec<usize> = (0..count).collect();
// Random-ish access indices
let random_indices: Vec<usize> = (0..count)
.map(|i| (i * 37 + 13) % count) // Pseudo-random
.collect();
group.bench_function("sequential", |bench| {
bench.iter(|| {
for &idx in &sequential_indices {
euclidean_distance_simd(black_box(&query), black_box(&vectors[idx]));
}
});
});
group.bench_function("random", |bench| {
bench.iter(|| {
for &idx in &random_indices {
euclidean_distance_simd(black_box(&query), black_box(&vectors[idx]));
}
});
});
group.finish();
}
// ============================================================================
// Throughput Measurement
// ============================================================================
fn bench_throughput_ops_per_second(c: &mut Criterion) {
let mut group = c.benchmark_group("throughput");
group.sample_size(50);
for dim in [128, 384, 768, 1536] {
let (a, b) = generate_vectors(dim);
// Report throughput in operations/second
group.bench_with_input(BenchmarkId::new("euclidean_ops", dim), &dim, |bench, _| {
bench.iter(|| {
// Perform 100 operations per iteration
for _ in 0..100 {
euclidean_distance_simd(black_box(&a), black_box(&b));
}
});
});
group.bench_with_input(
BenchmarkId::new("dot_product_ops", dim),
&dim,
|bench, _| {
bench.iter(|| {
for _ in 0..100 {
dot_product_simd(black_box(&a), black_box(&b));
}
});
},
);
}
group.finish();
}
// ============================================================================
// Criterion Groups
// ============================================================================
criterion_group!(
benches,
bench_euclidean_by_dimension,
bench_euclidean_small_vectors,
bench_euclidean_non_aligned,
bench_dot_product_by_dimension,
bench_dot_product_common_embeddings,
bench_cosine_by_dimension,
bench_manhattan_by_dimension,
bench_batch_euclidean,
bench_batch_dot_product,
bench_all_metrics_comparison,
bench_sequential_vs_random_access,
bench_throughput_ops_per_second,
);
criterion_main!(benches);

View File

@@ -0,0 +1,262 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruvector_core::arena::Arena;
use ruvector_core::cache_optimized::SoAVectorStorage;
use ruvector_core::distance::*;
use ruvector_core::lockfree::{LockFreeCounter, LockFreeStats, ObjectPool};
use ruvector_core::simd_intrinsics::*;
use ruvector_core::types::DistanceMetric;
use std::sync::Arc;
use std::thread;
// Benchmark SIMD intrinsics vs SimSIMD
fn bench_simd_comparison(c: &mut Criterion) {
let mut group = c.benchmark_group("simd_comparison");
for size in [128, 384, 768, 1536].iter() {
let a: Vec<f32> = (0..*size).map(|i| i as f32 * 0.1).collect();
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32 * 0.1).collect();
// Euclidean distance
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(
BenchmarkId::new("euclidean_simsimd", size),
size,
|bench, _| {
bench.iter(|| euclidean_distance(black_box(&a), black_box(&b)));
},
);
group.bench_with_input(
BenchmarkId::new("euclidean_avx2", size),
size,
|bench, _| {
bench.iter(|| euclidean_distance_avx2(black_box(&a), black_box(&b)));
},
);
// Dot product
group.bench_with_input(
BenchmarkId::new("dot_product_simsimd", size),
size,
|bench, _| {
bench.iter(|| dot_product_distance(black_box(&a), black_box(&b)));
},
);
group.bench_with_input(
BenchmarkId::new("dot_product_avx2", size),
size,
|bench, _| {
bench.iter(|| dot_product_avx2(black_box(&a), black_box(&b)));
},
);
// Cosine similarity
group.bench_with_input(
BenchmarkId::new("cosine_simsimd", size),
size,
|bench, _| {
bench.iter(|| cosine_distance(black_box(&a), black_box(&b)));
},
);
group.bench_with_input(BenchmarkId::new("cosine_avx2", size), size, |bench, _| {
bench.iter(|| cosine_similarity_avx2(black_box(&a), black_box(&b)));
});
}
group.finish();
}
// Benchmark Structure-of-Arrays vs Array-of-Structures
fn bench_cache_optimization(c: &mut Criterion) {
let mut group = c.benchmark_group("cache_optimization");
let dimensions = 384;
let num_vectors = 10000;
// Prepare data
let vectors: Vec<Vec<f32>> = (0..num_vectors)
.map(|i| (0..dimensions).map(|j| (i * j) as f32 * 0.001).collect())
.collect();
let query: Vec<f32> = (0..dimensions).map(|i| i as f32 * 0.01).collect();
// Array-of-Structures (traditional Vec<Vec<f32>>)
group.bench_function("aos_batch_distance", |bench| {
bench.iter(|| {
let mut distances: Vec<f32> = Vec::with_capacity(num_vectors);
for vector in &vectors {
let dist = euclidean_distance(black_box(&query), black_box(vector));
distances.push(dist);
}
black_box(distances)
});
});
// Structure-of-Arrays
let mut soa_storage = SoAVectorStorage::new(dimensions, num_vectors);
for vector in &vectors {
soa_storage.push(vector);
}
group.bench_function("soa_batch_distance", |bench| {
bench.iter(|| {
let mut distances = vec![0.0; num_vectors];
soa_storage.batch_euclidean_distances(black_box(&query), &mut distances);
black_box(distances)
});
});
group.finish();
}
// Benchmark arena allocation vs standard allocation
fn bench_arena_allocation(c: &mut Criterion) {
let mut group = c.benchmark_group("arena_allocation");
let num_allocations = 1000;
let vec_size = 100;
group.bench_function("standard_allocation", |bench| {
bench.iter(|| {
let mut vecs = Vec::new();
for _ in 0..num_allocations {
let mut v = Vec::with_capacity(vec_size);
for j in 0..vec_size {
v.push(j as f32);
}
vecs.push(v);
}
black_box(vecs)
});
});
group.bench_function("arena_allocation", |bench| {
bench.iter(|| {
let arena = Arena::new(1024 * 1024);
let mut vecs = Vec::new();
for _ in 0..num_allocations {
let mut v = arena.alloc_vec::<f32>(vec_size);
for j in 0..vec_size {
v.push(j as f32);
}
vecs.push(v);
}
black_box(vecs)
});
});
group.finish();
}
// Benchmark lock-free operations vs locked operations
fn bench_lockfree(c: &mut Criterion) {
let mut group = c.benchmark_group("lockfree");
// Counter benchmark
group.bench_function("lockfree_counter_single_thread", |bench| {
let counter = LockFreeCounter::new(0);
bench.iter(|| {
for _ in 0..10000 {
counter.increment();
}
});
});
group.bench_function("lockfree_counter_multi_thread", |bench| {
bench.iter(|| {
let counter = Arc::new(LockFreeCounter::new(0));
let mut handles = vec![];
for _ in 0..4 {
let counter_clone = Arc::clone(&counter);
handles.push(thread::spawn(move || {
for _ in 0..2500 {
counter_clone.increment();
}
}));
}
for handle in handles {
handle.join().unwrap();
}
black_box(counter.get())
});
});
// Stats collector benchmark
group.bench_function("lockfree_stats", |bench| {
let stats = LockFreeStats::new();
bench.iter(|| {
for i in 0..1000 {
stats.record_query(i);
}
black_box(stats.snapshot())
});
});
// Object pool benchmark
group.bench_function("object_pool_acquire_release", |bench| {
let pool = ObjectPool::new(10, || Vec::<f32>::with_capacity(1000));
bench.iter(|| {
let mut obj = pool.acquire();
for i in 0..100 {
obj.push(i as f32);
}
black_box(&*obj);
});
});
group.finish();
}
// Benchmark thread scaling
fn bench_thread_scaling(c: &mut Criterion) {
let mut group = c.benchmark_group("thread_scaling");
let dimensions = 384;
let num_vectors = 10000;
let query: Vec<f32> = (0..dimensions).map(|i| i as f32 * 0.01).collect();
let vectors: Vec<Vec<f32>> = (0..num_vectors)
.map(|i| (0..dimensions).map(|j| (i * j) as f32 * 0.001).collect())
.collect();
for num_threads in [1, 2, 4, 8].iter() {
group.bench_with_input(
BenchmarkId::new("parallel_distance", num_threads),
num_threads,
|bench, &threads| {
bench.iter(|| {
rayon::ThreadPoolBuilder::new()
.num_threads(threads)
.build()
.unwrap()
.install(|| {
let result = batch_distances(
black_box(&query),
black_box(&vectors),
DistanceMetric::Euclidean,
)
.unwrap();
black_box(result)
})
});
},
);
}
group.finish();
}
criterion_group!(
benches,
bench_simd_comparison,
bench_cache_optimization,
bench_arena_allocation,
bench_lockfree,
bench_thread_scaling
);
criterion_main!(benches);

View File

@@ -0,0 +1,74 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use ruvector_core::distance::*;
use ruvector_core::types::DistanceMetric;
fn bench_euclidean(c: &mut Criterion) {
let mut group = c.benchmark_group("euclidean_distance");
for size in [128, 384, 768, 1536].iter() {
let a: Vec<f32> = (0..*size).map(|i| i as f32).collect();
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32).collect();
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, _| {
bench.iter(|| euclidean_distance(black_box(&a), black_box(&b)));
});
}
group.finish();
}
fn bench_cosine(c: &mut Criterion) {
let mut group = c.benchmark_group("cosine_distance");
for size in [128, 384, 768, 1536].iter() {
let a: Vec<f32> = (0..*size).map(|i| i as f32).collect();
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32).collect();
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, _| {
bench.iter(|| cosine_distance(black_box(&a), black_box(&b)));
});
}
group.finish();
}
fn bench_dot_product(c: &mut Criterion) {
let mut group = c.benchmark_group("dot_product_distance");
for size in [128, 384, 768, 1536].iter() {
let a: Vec<f32> = (0..*size).map(|i| i as f32).collect();
let b: Vec<f32> = (0..*size).map(|i| (i + 1) as f32).collect();
group.bench_with_input(BenchmarkId::from_parameter(size), size, |bench, _| {
bench.iter(|| dot_product_distance(black_box(&a), black_box(&b)));
});
}
group.finish();
}
fn bench_batch_distances(c: &mut Criterion) {
let query: Vec<f32> = (0..384).map(|i| i as f32).collect();
let vectors: Vec<Vec<f32>> = (0..1000)
.map(|_| (0..384).map(|i| (i as f32) * 1.1).collect())
.collect();
c.bench_function("batch_distances_1000x384", |b| {
b.iter(|| {
batch_distances(
black_box(&query),
black_box(&vectors),
DistanceMetric::Cosine,
)
});
});
}
criterion_group!(
benches,
bench_euclidean,
bench_cosine,
bench_dot_product,
bench_batch_distances
);
criterion_main!(benches);

View File

@@ -0,0 +1,56 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
fn bench_hnsw_search(c: &mut Criterion) {
let mut group = c.benchmark_group("hnsw_search");
// Create temp database
let temp_dir = tempfile::tempdir().unwrap();
let options = DbOptions {
dimensions: 128,
distance_metric: DistanceMetric::Cosine,
storage_path: temp_dir
.path()
.join("test.db")
.to_string_lossy()
.to_string(),
hnsw_config: Some(HnswConfig::default()),
quantization: None,
};
let db = VectorDB::new(options).unwrap();
// Insert test vectors
let vectors: Vec<VectorEntry> = (0..1000)
.map(|i| VectorEntry {
id: Some(format!("v{}", i)),
vector: (0..128).map(|j| ((i + j) as f32) * 0.1).collect(),
metadata: None,
})
.collect();
db.insert_batch(vectors).unwrap();
// Benchmark search
let query: Vec<f32> = (0..128).map(|i| i as f32).collect();
for k in [1, 10, 100].iter() {
group.bench_with_input(BenchmarkId::from_parameter(k), k, |bench, &k| {
bench.iter(|| {
db.search(SearchQuery {
vector: black_box(query.clone()),
k: black_box(k),
filter: None,
ef_search: None,
})
.unwrap()
});
});
}
group.finish();
}
criterion_group!(benches, bench_hnsw_search);
criterion_main!(benches);

View File

@@ -0,0 +1,77 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use ruvector_core::quantization::*;
fn bench_scalar_quantization(c: &mut Criterion) {
let mut group = c.benchmark_group("scalar_quantization");
for size in [128, 384, 768, 1536].iter() {
let vector: Vec<f32> = (0..*size).map(|i| i as f32 * 0.1).collect();
group.bench_with_input(BenchmarkId::new("encode", size), size, |bench, _| {
bench.iter(|| ScalarQuantized::quantize(black_box(&vector)));
});
let quantized = ScalarQuantized::quantize(&vector);
group.bench_with_input(BenchmarkId::new("decode", size), size, |bench, _| {
bench.iter(|| quantized.reconstruct());
});
let quantized2 = ScalarQuantized::quantize(&vector);
group.bench_with_input(BenchmarkId::new("distance", size), size, |bench, _| {
bench.iter(|| quantized.distance(black_box(&quantized2)));
});
}
group.finish();
}
fn bench_binary_quantization(c: &mut Criterion) {
let mut group = c.benchmark_group("binary_quantization");
for size in [128, 384, 768, 1536].iter() {
let vector: Vec<f32> = (0..*size)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
group.bench_with_input(BenchmarkId::new("encode", size), size, |bench, _| {
bench.iter(|| BinaryQuantized::quantize(black_box(&vector)));
});
let quantized = BinaryQuantized::quantize(&vector);
group.bench_with_input(BenchmarkId::new("decode", size), size, |bench, _| {
bench.iter(|| quantized.reconstruct());
});
let quantized2 = BinaryQuantized::quantize(&vector);
group.bench_with_input(
BenchmarkId::new("hamming_distance", size),
size,
|bench, _| {
bench.iter(|| quantized.distance(black_box(&quantized2)));
},
);
}
group.finish();
}
fn bench_quantization_compression_ratio(c: &mut Criterion) {
let dimensions = 384;
let vector: Vec<f32> = (0..dimensions).map(|i| i as f32 * 0.01).collect();
c.bench_function("scalar_vs_binary_encoding", |b| {
b.iter(|| {
let scalar = ScalarQuantized::quantize(black_box(&vector));
let binary = BinaryQuantized::quantize(black_box(&vector));
(scalar, binary)
});
});
}
criterion_group!(
benches,
bench_scalar_quantization,
bench_binary_quantization,
bench_quantization_compression_ratio
);
criterion_main!(benches);

View File

@@ -0,0 +1,217 @@
//! Real Benchmarks for RuVector Core
//!
//! These are ACTUAL performance measurements, not simulations.
//! Run with: cargo bench -p ruvector-core --bench real_benchmark
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ruvector_core::types::{DbOptions, HnswConfig};
use ruvector_core::{DistanceMetric, SearchQuery, VectorDB, VectorEntry};
use tempfile::tempdir;
/// Generate random vectors for benchmarking
fn generate_vectors(count: usize, dim: usize) -> Vec<Vec<f32>> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
(0..count)
.map(|i| {
(0..dim)
.map(|j| {
let mut hasher = DefaultHasher::new();
(i * dim + j).hash(&mut hasher);
let h = hasher.finish();
((h % 2000) as f32 / 1000.0) - 1.0 // Range [-1, 1]
})
.collect()
})
.collect()
}
/// Benchmark: Vector insertion (single)
fn bench_insert_single(c: &mut Criterion) {
let mut group = c.benchmark_group("insert_single");
for dim in [64, 128, 256, 512].iter() {
let vectors = generate_vectors(1000, *dim);
group.throughput(Throughput::Elements(1));
group.bench_with_input(BenchmarkId::new("dimensions", dim), dim, |b, &dim| {
let dir = tempdir().unwrap();
let options = DbOptions {
storage_path: dir.path().join("bench.db").to_string_lossy().to_string(),
dimensions: dim,
distance_metric: DistanceMetric::Cosine,
hnsw_config: Some(HnswConfig::default()),
quantization: None,
};
let db = VectorDB::new(options).unwrap();
let mut idx = 0;
b.iter(|| {
let entry = VectorEntry {
id: None,
vector: vectors[idx % vectors.len()].clone(),
metadata: None,
};
let _ = black_box(db.insert(entry));
idx += 1;
});
});
}
group.finish();
}
/// Benchmark: Vector insertion (batch)
fn bench_insert_batch(c: &mut Criterion) {
let mut group = c.benchmark_group("insert_batch");
for batch_size in [100, 500, 1000].iter() {
let vectors = generate_vectors(*batch_size, 128);
group.throughput(Throughput::Elements(*batch_size as u64));
group.bench_with_input(
BenchmarkId::new("batch_size", batch_size),
batch_size,
|b, &batch_size| {
b.iter(|| {
let dir = tempdir().unwrap();
let options = DbOptions {
storage_path: dir.path().join("bench.db").to_string_lossy().to_string(),
dimensions: 128,
distance_metric: DistanceMetric::Cosine,
hnsw_config: Some(HnswConfig::default()),
quantization: None,
};
let db = VectorDB::new(options).unwrap();
let entries: Vec<VectorEntry> = vectors
.iter()
.map(|v| VectorEntry {
id: None,
vector: v.clone(),
metadata: None,
})
.collect();
black_box(db.insert_batch(entries).unwrap())
});
},
);
}
group.finish();
}
/// Benchmark: Search (k-NN)
fn bench_search(c: &mut Criterion) {
let mut group = c.benchmark_group("search");
// Pre-populate database
let dir = tempdir().unwrap();
let options = DbOptions {
storage_path: dir.path().join("bench.db").to_string_lossy().to_string(),
dimensions: 128,
distance_metric: DistanceMetric::Cosine,
hnsw_config: Some(HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 50,
max_elements: 100000,
}),
quantization: None,
};
let db = VectorDB::new(options).unwrap();
// Insert 10k vectors
let vectors = generate_vectors(10000, 128);
let entries: Vec<VectorEntry> = vectors
.iter()
.map(|v| VectorEntry {
id: None,
vector: v.clone(),
metadata: None,
})
.collect();
db.insert_batch(entries).unwrap();
// Generate query vectors
let queries = generate_vectors(100, 128);
for k in [10, 50, 100].iter() {
group.throughput(Throughput::Elements(1));
group.bench_with_input(BenchmarkId::new("top_k", k), k, |b, &k| {
let mut query_idx = 0;
b.iter(|| {
let query = &queries[query_idx % queries.len()];
let search_query = SearchQuery {
vector: query.clone(),
k,
filter: None,
ef_search: None,
};
let results = black_box(db.search(search_query));
query_idx += 1;
results
});
});
}
group.finish();
}
/// Benchmark: Distance computation (raw)
fn bench_distance(c: &mut Criterion) {
use ruvector_core::distance::{cosine_distance, dot_product_distance, euclidean_distance};
let mut group = c.benchmark_group("distance");
for dim in [64, 128, 256, 512, 1024].iter() {
let v1: Vec<f32> = (0..*dim).map(|i| (i as f32 * 0.01).sin()).collect();
let v2: Vec<f32> = (0..*dim).map(|i| (i as f32 * 0.02).cos()).collect();
group.throughput(Throughput::Elements(1));
group.bench_with_input(BenchmarkId::new("cosine", dim), dim, |b, _| {
b.iter(|| black_box(cosine_distance(&v1, &v2)));
});
group.bench_with_input(BenchmarkId::new("euclidean", dim), dim, |b, _| {
b.iter(|| black_box(euclidean_distance(&v1, &v2)));
});
group.bench_with_input(BenchmarkId::new("dot_product", dim), dim, |b, _| {
b.iter(|| black_box(dot_product_distance(&v1, &v2)));
});
}
group.finish();
}
/// Benchmark: Quantization
fn bench_quantization(c: &mut Criterion) {
use ruvector_core::quantization::{QuantizedVector, ScalarQuantized};
let mut group = c.benchmark_group("quantization");
for dim in [128, 256, 512].iter() {
let vector: Vec<f32> = (0..*dim).map(|i| (i as f32 * 0.01).sin()).collect();
group.bench_with_input(BenchmarkId::new("scalar_quantize", dim), dim, |b, _| {
b.iter(|| black_box(ScalarQuantized::quantize(&vector)));
});
let quantized = ScalarQuantized::quantize(&vector);
group.bench_with_input(BenchmarkId::new("scalar_distance", dim), dim, |b, _| {
b.iter(|| black_box(quantized.distance(&quantized)));
});
}
group.finish();
}
criterion_group!(
benches,
bench_distance,
bench_quantization,
bench_insert_single,
bench_insert_batch,
bench_search,
);
criterion_main!(benches);

View File

@@ -0,0 +1,296 @@
# Text Embeddings for AgenticDB
This guide explains how to use real text embeddings with AgenticDB in ruvector-core.
## Quick Start
### Default (Hash-based - Testing Only)
```rust
use ruvector_core::{AgenticDB, types::DbOptions};
let mut options = DbOptions::default();
options.dimensions = 128;
options.storage_path = "agenticdb.db".to_string();
// Uses hash-based embeddings by default (fast but not semantic)
let db = AgenticDB::new(options)?;
// Store and retrieve episodes
let episode_id = db.store_episode(
"Solve math problem".to_string(),
vec!["read".to_string(), "calculate".to_string()],
vec!["got 42".to_string()],
"Should show work".to_string(),
)?;
```
⚠️ **Warning**: Hash-based embeddings don't understand semantic meaning!
- "dog" and "cat" will NOT be similar
- "dog" and "god" WILL be similar (same characters)
## Production: API-based Embeddings (Recommended)
### OpenAI
```rust
use ruvector_core::{AgenticDB, ApiEmbedding, types::DbOptions};
use std::sync::Arc;
let mut options = DbOptions::default();
options.dimensions = 1536; // text-embedding-3-small
options.storage_path = "agenticdb.db".to_string();
let api_key = std::env::var("OPENAI_API_KEY")?;
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
let db = AgenticDB::with_embedding_provider(options, provider)?;
// Now you have semantic embeddings!
let episodes = db.retrieve_similar_episodes("mathematics", 5)?;
```
**OpenAI Models:**
- `text-embedding-3-small` - 1536 dims, $0.02/1M tokens (recommended)
- `text-embedding-3-large` - 3072 dims, $0.13/1M tokens (best quality)
### Cohere
```rust
let api_key = std::env::var("COHERE_API_KEY")?;
let provider = Arc::new(ApiEmbedding::cohere(&api_key, "embed-english-v3.0"));
let mut options = DbOptions::default();
options.dimensions = 1024; // Cohere embedding size
let db = AgenticDB::with_embedding_provider(options, provider)?;
```
### Voyage AI
```rust
let api_key = std::env::var("VOYAGE_API_KEY")?;
let provider = Arc::new(ApiEmbedding::voyage(&api_key, "voyage-2"));
let mut options = DbOptions::default();
options.dimensions = 1024; // voyage-2 size
let db = AgenticDB::with_embedding_provider(options, provider)?;
```
## Custom Embedding Provider
Implement the `EmbeddingProvider` trait for any embedding system:
```rust
use ruvector_core::embeddings::EmbeddingProvider;
use ruvector_core::error::Result;
struct MyCustomEmbedding {
// Your model here
}
impl EmbeddingProvider for MyCustomEmbedding {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
// Your embedding logic
todo!()
}
fn dimensions(&self) -> usize {
384 // Your embedding dimensions
}
fn name(&self) -> &str {
"MyCustomEmbedding"
}
}
```
## ONNX Runtime (Local, No API Costs)
For production use without API costs, use ONNX Runtime with pre-exported models:
```rust
// See examples/onnx-embeddings for complete implementation
use ort::{Session, Environment, Value};
struct OnnxEmbedding {
session: Session,
dimensions: usize,
}
impl EmbeddingProvider for OnnxEmbedding {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
// Tokenize text
// Run ONNX inference
// Return embeddings
todo!()
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
"OnnxEmbedding"
}
}
```
### Exporting Models to ONNX
```python
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
model_id = "sentence-transformers/all-MiniLM-L6-v2"
model = ORTModelForFeatureExtraction.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.save_pretrained("./onnx-model")
tokenizer.save_pretrained("./onnx-model")
```
## Feature Flags
### `real-embeddings` (Optional)
This feature flag enables the `CandleEmbedding` type (currently a stub):
```toml
[dependencies]
ruvector-core = { version = "0.1", features = ["real-embeddings"] }
```
However, we recommend using API-based providers instead of implementing Candle integration yourself.
## Complete Example
```rust
use ruvector_core::{AgenticDB, ApiEmbedding, types::DbOptions};
use std::sync::Arc;
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Setup
let api_key = std::env::var("OPENAI_API_KEY")?;
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
let mut options = DbOptions::default();
options.dimensions = 1536;
options.storage_path = "agenticdb.db".to_string();
let db = AgenticDB::with_embedding_provider(options, provider)?;
println!("Using: {}", db.embedding_provider_name());
// Store reflexion episodes
let ep1 = db.store_episode(
"Debug memory leak in Rust".to_string(),
vec!["profile".to_string(), "find leak".to_string()],
vec!["fixed with Arc".to_string()],
"Should explain reference counting".to_string(),
)?;
let ep2 = db.store_episode(
"Optimize Python performance".to_string(),
vec!["profile".to_string(), "vectorize".to_string()],
vec!["10x speedup".to_string()],
"Should mention NumPy".to_string(),
)?;
// Semantic search - will find Rust episode for memory-related query
let episodes = db.retrieve_similar_episodes("memory management", 5)?;
for episode in episodes {
println!("Task: {}", episode.task);
println!("Critique: {}", episode.critique);
}
// Create skills
db.create_skill(
"Memory Profiling".to_string(),
"Profile application memory usage to find leaks".to_string(),
Default::default(),
vec!["valgrind".to_string(), "massif".to_string()],
)?;
// Search skills semantically
let skills = db.search_skills("finding memory leaks", 3)?;
for skill in skills {
println!("Skill: {} - {}", skill.name, skill.description);
}
Ok(())
}
```
## Performance Considerations
### API-based (OpenAI, Cohere, Voyage)
- **Pros**: Always up-to-date, no model storage, easy to use
- **Cons**: Network latency, API costs, requires internet
- **Best for**: Production apps with internet access
### ONNX Runtime (Local)
- **Pros**: No API costs, offline support, fast inference
- **Cons**: Model storage (~100MB), setup complexity
- **Best for**: Edge deployment, high-volume apps
### Hash-based (Default)
- **Pros**: Zero dependencies, instant, no setup
- **Cons**: Not semantic, only for testing
- **Best for**: Development, unit tests
## Recommendations
1. **Development/Testing**: Use hash-based (default)
2. **Production (Cloud)**: Use `ApiEmbedding::openai()`
3. **Production (Edge/Offline)**: Implement ONNX provider
4. **Custom Models**: Implement `EmbeddingProvider` trait
## Migration Path
```rust
// Start with hash for development
let db = AgenticDB::new(options)?;
// Switch to API for staging
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
let db = AgenticDB::with_embedding_provider(options, provider)?;
// Move to ONNX for production scale
let provider = Arc::new(OnnxEmbedding::from_file("model.onnx")?);
let db = AgenticDB::with_embedding_provider(options, provider)?;
```
The beauty is: **your AgenticDB code doesn't change**, just the provider!
## Error Handling
```rust
use ruvector_core::error::RuvectorError;
match AgenticDB::with_embedding_provider(options, provider) {
Ok(db) => {
// Use db
}
Err(RuvectorError::InvalidDimension(msg)) => {
eprintln!("Dimension mismatch: {}", msg);
}
Err(RuvectorError::ModelLoadError(msg)) => {
eprintln!("Failed to load model: {}", msg);
}
Err(RuvectorError::ModelInferenceError(msg)) => {
eprintln!("Inference failed: {}", msg);
}
Err(e) => {
eprintln!("Error: {}", e);
}
}
```
## See Also
- [AgenticDB API Documentation](../src/agenticdb.rs)
- [Embedding Provider Trait](../src/embeddings.rs)
- [ONNX Examples](../../examples/onnx-embeddings/)
- [Integration Tests](../tests/embeddings_test.rs)

View File

@@ -0,0 +1,184 @@
//! Example of using different embedding providers with AgenticDB
//!
//! Run with:
//! ```bash
//! # Default hash-based (testing only)
//! cargo run --example embeddings_example
//!
//! # With OpenAI API (requires OPENAI_API_KEY env var)
//! OPENAI_API_KEY=sk-... cargo run --example embeddings_example --features real-embeddings
//! ```
use ruvector_core::types::DbOptions;
use ruvector_core::{AgenticDB, ApiEmbedding, HashEmbedding};
use std::sync::Arc;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("=== AgenticDB Embeddings Example ===\n");
// Determine which provider to use
let use_api = std::env::var("OPENAI_API_KEY").is_ok();
let (db, provider_name) = if use_api {
println!("Using OpenAI API embeddings (real semantic search)");
let api_key = std::env::var("OPENAI_API_KEY")?;
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
let mut options = DbOptions::default();
options.dimensions = 1536; // OpenAI text-embedding-3-small
options.storage_path = "/tmp/agenticdb_api.db".to_string();
let db = AgenticDB::with_embedding_provider(options, provider)?;
(db, "OpenAI API")
} else {
println!("Using hash-based embeddings (testing only - not semantic)");
println!("Set OPENAI_API_KEY to use real embeddings\n");
let mut options = DbOptions::default();
options.dimensions = 128;
options.storage_path = "/tmp/agenticdb_hash.db".to_string();
let db = AgenticDB::new(options)?;
(db, "Hash-based")
};
println!("Provider: {}\n", db.embedding_provider_name());
// Store some reflexion episodes
println!("--- Storing Reflexion Episodes ---");
let ep1 = db.store_episode(
"Fix Rust borrow checker error".to_string(),
vec![
"Identified lifetime issue".to_string(),
"Added explicit lifetime annotations".to_string(),
"Refactored to use references".to_string(),
],
vec!["Code compiles now".to_string()],
"Should explain borrow checker rules better".to_string(),
)?;
println!(
"✓ Stored episode: Fix Rust borrow checker error (ID: {})",
ep1
);
let ep2 = db.store_episode(
"Optimize Python data processing".to_string(),
vec![
"Profiled with cProfile".to_string(),
"Vectorized with NumPy".to_string(),
"Parallelized with multiprocessing".to_string(),
],
vec!["10x performance improvement".to_string()],
"Could have used Pandas for better readability".to_string(),
)?;
println!(
"✓ Stored episode: Optimize Python data processing (ID: {})",
ep2
);
let ep3 = db.store_episode(
"Debug JavaScript async issue".to_string(),
vec![
"Added console.log statements".to_string(),
"Used Chrome DevTools debugger".to_string(),
"Fixed Promise chain".to_string(),
],
vec!["Race condition resolved".to_string()],
"Should use async/await instead of callbacks".to_string(),
)?;
println!(
"✓ Stored episode: Debug JavaScript async issue (ID: {})\n",
ep3
);
// Create some skills
println!("--- Creating Skills ---");
let skill1 = db.create_skill(
"Memory Profiling".to_string(),
"Profile application memory usage to detect leaks and optimize allocation".to_string(),
Default::default(),
vec![
"valgrind".to_string(),
"massif".to_string(),
"heaptrack".to_string(),
],
)?;
println!("✓ Created skill: Memory Profiling (ID: {})", skill1);
let skill2 = db.create_skill(
"Async Programming".to_string(),
"Write asynchronous code using promises, async/await, or futures".to_string(),
Default::default(),
vec![
"Promise.all()".to_string(),
"async/await".to_string(),
"tokio".to_string(),
],
)?;
println!("✓ Created skill: Async Programming (ID: {})", skill2);
let skill3 = db.create_skill(
"Performance Optimization".to_string(),
"Profile and optimize code performance using profilers and benchmarks".to_string(),
Default::default(),
vec![
"perf".to_string(),
"criterion".to_string(),
"flamegraph".to_string(),
],
)?;
println!(
"✓ Created skill: Performance Optimization (ID: {})\n",
skill3
);
// Search episodes
println!("--- Searching Episodes ---");
let query = "memory problems in programming";
println!("Query: \"{}\"", query);
let episodes = db.retrieve_similar_episodes(query, 3)?;
println!("Found {} similar episodes:\n", episodes.len());
for (i, episode) in episodes.iter().enumerate() {
println!("{}. Task: {}", i + 1, episode.task);
println!(" Critique: {}", episode.critique);
println!(" Actions: {}", episode.actions.join(""));
println!();
}
if use_api {
println!(" With OpenAI embeddings, results are semantically similar!");
println!(" 'memory problems' should match 'Rust borrow checker' and 'memory profiling'");
} else {
println!("⚠️ Hash-based embeddings are NOT semantic!");
println!(" Results are based on character overlap, not meaning.");
println!(" Set OPENAI_API_KEY to see real semantic search.");
}
// Search skills
println!("\n--- Searching Skills ---");
let query = "handling asynchronous operations";
println!("Query: \"{}\"", query);
let skills = db.search_skills(query, 3)?;
println!("Found {} similar skills:\n", skills.len());
for (i, skill) in skills.iter().enumerate() {
println!("{}. {}", i + 1, skill.name);
println!(" Description: {}", skill.description);
println!(" Examples: {}", skill.examples.join(", "));
println!();
}
println!("=== Example Complete ===");
println!("\nTips:");
println!("- Use hash-based embeddings for testing/development");
println!("- Use API embeddings (OpenAI, Cohere, Voyage) for production");
println!("- Implement ONNX provider for offline/edge deployment");
println!("- See docs/EMBEDDINGS.md for full guide");
Ok(())
}

View File

@@ -0,0 +1,264 @@
//! Quick benchmark to compare NEON SIMD vs scalar performance on Apple Silicon
//!
//! Run with: cargo run --example neon_benchmark --release -p ruvector-core
use std::time::Instant;
fn main() {
println!("╔════════════════════════════════════════════════════════════╗");
println!("║ NEON SIMD Benchmark for Apple Silicon (M4 Pro) ║");
println!("╚════════════════════════════════════════════════════════════╝\n");
// Test parameters
let dimensions = 128; // Common embedding dimension
let num_vectors = 10_000;
let num_queries = 1_000;
// Generate test data
let vectors: Vec<Vec<f32>> = (0..num_vectors)
.map(|i| {
(0..dimensions)
.map(|j| ((i * j) % 1000) as f32 / 1000.0)
.collect()
})
.collect();
let queries: Vec<Vec<f32>> = (0..num_queries)
.map(|i| {
(0..dimensions)
.map(|j| ((i * j + 500) % 1000) as f32 / 1000.0)
.collect()
})
.collect();
println!("Configuration:");
println!(" - Dimensions: {}", dimensions);
println!(" - Vectors: {}", num_vectors);
println!(" - Queries: {}", num_queries);
println!(
" - Total distance calculations: {}\n",
num_vectors * num_queries
);
#[cfg(target_arch = "aarch64")]
println!("Platform: ARM64 (Apple Silicon) - NEON enabled ✓\n");
#[cfg(target_arch = "x86_64")]
println!("Platform: x86_64 - AVX2 detection enabled\n");
// Benchmark Euclidean distance (SIMD)
println!("═══════════════════════════════════════════════════════════════");
println!("Euclidean Distance:");
println!("═══════════════════════════════════════════════════════════════");
let start = Instant::now();
let mut simd_sum = 0.0f32;
for query in &queries {
for vec in &vectors {
simd_sum += euclidean_simd(query, vec);
}
}
let simd_time = start.elapsed();
println!(
" SIMD: {:>8.2} ms (checksum: {:.4})",
simd_time.as_secs_f64() * 1000.0,
simd_sum
);
let start = Instant::now();
let mut scalar_sum = 0.0f32;
for query in &queries {
for vec in &vectors {
scalar_sum += euclidean_scalar(query, vec);
}
}
let scalar_time = start.elapsed();
println!(
" Scalar: {:>8.2} ms (checksum: {:.4})",
scalar_time.as_secs_f64() * 1000.0,
scalar_sum
);
let speedup = scalar_time.as_secs_f64() / simd_time.as_secs_f64();
println!(" Speedup: {:.2}x\n", speedup);
// Benchmark Dot Product (SIMD)
println!("═══════════════════════════════════════════════════════════════");
println!("Dot Product:");
println!("═══════════════════════════════════════════════════════════════");
let start = Instant::now();
let mut simd_sum = 0.0f32;
for query in &queries {
for vec in &vectors {
simd_sum += dot_simd(query, vec);
}
}
let simd_time = start.elapsed();
println!(
" SIMD: {:>8.2} ms (checksum: {:.4})",
simd_time.as_secs_f64() * 1000.0,
simd_sum
);
let start = Instant::now();
let mut scalar_sum = 0.0f32;
for query in &queries {
for vec in &vectors {
scalar_sum += dot_scalar(query, vec);
}
}
let scalar_time = start.elapsed();
println!(
" Scalar: {:>8.2} ms (checksum: {:.4})",
scalar_time.as_secs_f64() * 1000.0,
scalar_sum
);
let speedup = scalar_time.as_secs_f64() / simd_time.as_secs_f64();
println!(" Speedup: {:.2}x\n", speedup);
// Benchmark Cosine Similarity (SIMD)
println!("═══════════════════════════════════════════════════════════════");
println!("Cosine Similarity:");
println!("═══════════════════════════════════════════════════════════════");
let start = Instant::now();
let mut simd_sum = 0.0f32;
for query in &queries {
for vec in &vectors {
simd_sum += cosine_simd(query, vec);
}
}
let simd_time = start.elapsed();
println!(
" SIMD: {:>8.2} ms (checksum: {:.4})",
simd_time.as_secs_f64() * 1000.0,
simd_sum
);
let start = Instant::now();
let mut scalar_sum = 0.0f32;
for query in &queries {
for vec in &vectors {
scalar_sum += cosine_scalar(query, vec);
}
}
let scalar_time = start.elapsed();
println!(
" Scalar: {:>8.2} ms (checksum: {:.4})",
scalar_time.as_secs_f64() * 1000.0,
scalar_sum
);
let speedup = scalar_time.as_secs_f64() / simd_time.as_secs_f64();
println!(" Speedup: {:.2}x\n", speedup);
println!("═══════════════════════════════════════════════════════════════");
println!("Benchmark complete!");
}
// SIMD implementations (use the crate's SIMD functions)
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[inline]
fn euclidean_simd(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "aarch64")]
unsafe {
let len = a.len();
let mut sum = vdupq_n_f32(0.0);
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a.as_ptr().add(idx));
let vb = vld1q_f32(b.as_ptr().add(idx));
let diff = vsubq_f32(va, vb);
sum = vfmaq_f32(sum, diff, diff);
}
let mut total = vaddvq_f32(sum);
for i in (chunks * 4)..len {
let diff = a[i] - b[i];
total += diff * diff;
}
total.sqrt()
}
#[cfg(not(target_arch = "aarch64"))]
euclidean_scalar(a, b)
}
#[inline]
fn euclidean_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y) * (x - y))
.sum::<f32>()
.sqrt()
}
#[inline]
fn dot_simd(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "aarch64")]
unsafe {
let len = a.len();
let mut sum = vdupq_n_f32(0.0);
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a.as_ptr().add(idx));
let vb = vld1q_f32(b.as_ptr().add(idx));
sum = vfmaq_f32(sum, va, vb);
}
let mut total = vaddvq_f32(sum);
for i in (chunks * 4)..len {
total += a[i] * b[i];
}
total
}
#[cfg(not(target_arch = "aarch64"))]
dot_scalar(a, b)
}
#[inline]
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
#[inline]
fn cosine_simd(a: &[f32], b: &[f32]) -> f32 {
#[cfg(target_arch = "aarch64")]
unsafe {
let len = a.len();
let mut dot = vdupq_n_f32(0.0);
let mut norm_a = vdupq_n_f32(0.0);
let mut norm_b = vdupq_n_f32(0.0);
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a.as_ptr().add(idx));
let vb = vld1q_f32(b.as_ptr().add(idx));
dot = vfmaq_f32(dot, va, vb);
norm_a = vfmaq_f32(norm_a, va, va);
norm_b = vfmaq_f32(norm_b, vb, vb);
}
let mut dot_sum = vaddvq_f32(dot);
let mut norm_a_sum = vaddvq_f32(norm_a);
let mut norm_b_sum = vaddvq_f32(norm_b);
for i in (chunks * 4)..len {
dot_sum += a[i] * b[i];
norm_a_sum += a[i] * a[i];
norm_b_sum += b[i] * b[i];
}
dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
#[cfg(not(target_arch = "aarch64"))]
cosine_scalar(a, b)
}
#[inline]
fn cosine_scalar(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
dot / (norm_a * norm_b)
}

View File

@@ -0,0 +1,545 @@
//! # Hypergraph Support for N-ary Relationships
//!
//! Implements hypergraph structures for representing complex multi-entity relationships
//! beyond traditional pairwise similarity. Based on HyperGraphRAG (NeurIPS 2025) architecture.
use crate::error::{Result, RuvectorError};
use crate::types::{DistanceMetric, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
/// Hyperedge connecting multiple vectors with description and embedding
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperedge {
/// Unique identifier for the hyperedge
pub id: String,
/// Vector IDs connected by this hyperedge
pub nodes: Vec<VectorId>,
/// Natural language description of the relationship
pub description: String,
/// Embedding of the hyperedge description
pub embedding: Vec<f32>,
/// Confidence weight (0.0-1.0)
pub confidence: f32,
/// Optional metadata
pub metadata: HashMap<String, String>,
}
/// Temporal hyperedge with time attributes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalHyperedge {
/// Base hyperedge
pub hyperedge: Hyperedge,
/// Creation timestamp (Unix epoch seconds)
pub timestamp: u64,
/// Optional expiration timestamp
pub expires_at: Option<u64>,
/// Temporal context (hourly, daily, monthly)
pub granularity: TemporalGranularity,
}
/// Temporal granularity for indexing
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TemporalGranularity {
Hourly,
Daily,
Monthly,
Yearly,
}
impl Hyperedge {
/// Create a new hyperedge
pub fn new(
nodes: Vec<VectorId>,
description: String,
embedding: Vec<f32>,
confidence: f32,
) -> Self {
Self {
id: Uuid::new_v4().to_string(),
nodes,
description,
embedding,
confidence: confidence.clamp(0.0, 1.0),
metadata: HashMap::new(),
}
}
/// Get hyperedge order (number of nodes)
pub fn order(&self) -> usize {
self.nodes.len()
}
/// Check if hyperedge contains a specific node
pub fn contains_node(&self, node: &VectorId) -> bool {
self.nodes.contains(node)
}
}
impl TemporalHyperedge {
/// Create a new temporal hyperedge with current timestamp
pub fn new(hyperedge: Hyperedge, granularity: TemporalGranularity) -> Self {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
Self {
hyperedge,
timestamp,
expires_at: None,
granularity,
}
}
/// Check if hyperedge is expired
pub fn is_expired(&self) -> bool {
if let Some(expires_at) = self.expires_at {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
now > expires_at
} else {
false
}
}
/// Get time bucket for indexing
pub fn time_bucket(&self) -> u64 {
match self.granularity {
TemporalGranularity::Hourly => self.timestamp / 3600,
TemporalGranularity::Daily => self.timestamp / 86400,
TemporalGranularity::Monthly => self.timestamp / (86400 * 30),
TemporalGranularity::Yearly => self.timestamp / (86400 * 365),
}
}
}
/// Hypergraph index with bipartite graph storage
pub struct HypergraphIndex {
/// Entity nodes
entities: HashMap<VectorId, Vec<f32>>,
/// Hyperedges
hyperedges: HashMap<String, Hyperedge>,
/// Temporal hyperedges indexed by time bucket
temporal_index: HashMap<u64, Vec<String>>,
/// Bipartite graph: entity -> hyperedge IDs
entity_to_hyperedges: HashMap<VectorId, HashSet<String>>,
/// Bipartite graph: hyperedge -> entity IDs
hyperedge_to_entities: HashMap<String, HashSet<VectorId>>,
/// Distance metric for embeddings
distance_metric: DistanceMetric,
}
impl HypergraphIndex {
/// Create a new hypergraph index
pub fn new(distance_metric: DistanceMetric) -> Self {
Self {
entities: HashMap::new(),
hyperedges: HashMap::new(),
temporal_index: HashMap::new(),
entity_to_hyperedges: HashMap::new(),
hyperedge_to_entities: HashMap::new(),
distance_metric,
}
}
/// Add an entity node
pub fn add_entity(&mut self, id: VectorId, embedding: Vec<f32>) {
self.entities.insert(id.clone(), embedding);
self.entity_to_hyperedges.entry(id).or_default();
}
/// Add a hyperedge
pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) -> Result<()> {
let edge_id = hyperedge.id.clone();
// Verify all nodes exist
for node in &hyperedge.nodes {
if !self.entities.contains_key(node) {
return Err(RuvectorError::InvalidInput(format!(
"Entity {} not found in hypergraph",
node
)));
}
}
// Update bipartite graph
for node in &hyperedge.nodes {
self.entity_to_hyperedges
.entry(node.clone())
.or_default()
.insert(edge_id.clone());
}
let nodes_set: HashSet<VectorId> = hyperedge.nodes.iter().cloned().collect();
self.hyperedge_to_entities
.insert(edge_id.clone(), nodes_set);
self.hyperedges.insert(edge_id, hyperedge);
Ok(())
}
/// Add a temporal hyperedge
pub fn add_temporal_hyperedge(&mut self, temporal_edge: TemporalHyperedge) -> Result<()> {
let bucket = temporal_edge.time_bucket();
let edge_id = temporal_edge.hyperedge.id.clone();
self.add_hyperedge(temporal_edge.hyperedge)?;
self.temporal_index.entry(bucket).or_default().push(edge_id);
Ok(())
}
/// Search hyperedges by embedding similarity
pub fn search_hyperedges(&self, query_embedding: &[f32], k: usize) -> Vec<(String, f32)> {
let mut results: Vec<(String, f32)> = self
.hyperedges
.iter()
.map(|(id, edge)| {
let distance = self.compute_distance(query_embedding, &edge.embedding);
(id.clone(), distance)
})
.collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
results.truncate(k);
results
}
/// Get k-hop neighbors in hypergraph
/// Returns all nodes reachable within k hops from the start node
pub fn k_hop_neighbors(&self, start_node: VectorId, k: usize) -> HashSet<VectorId> {
let mut visited = HashSet::new();
let mut current_layer = HashSet::new();
current_layer.insert(start_node.clone());
visited.insert(start_node); // Start node is at distance 0
for _hop in 0..k {
let mut next_layer = HashSet::new();
for node in current_layer.iter() {
// Get all hyperedges containing this node
if let Some(hyperedges) = self.entity_to_hyperedges.get(node) {
for edge_id in hyperedges {
// Get all nodes in this hyperedge
if let Some(nodes) = self.hyperedge_to_entities.get(edge_id) {
for neighbor in nodes.iter() {
if !visited.contains(neighbor) {
visited.insert(neighbor.clone());
next_layer.insert(neighbor.clone());
}
}
}
}
}
}
if next_layer.is_empty() {
break;
}
current_layer = next_layer;
}
visited
}
/// Query temporal hyperedges in a time range
pub fn query_temporal_range(&self, start_bucket: u64, end_bucket: u64) -> Vec<String> {
let mut results = Vec::new();
for bucket in start_bucket..=end_bucket {
if let Some(edges) = self.temporal_index.get(&bucket) {
results.extend(edges.iter().cloned());
}
}
results
}
/// Get hyperedge by ID
pub fn get_hyperedge(&self, id: &str) -> Option<&Hyperedge> {
self.hyperedges.get(id)
}
/// Get statistics
pub fn stats(&self) -> HypergraphStats {
let total_edges = self.hyperedges.len();
let total_entities = self.entities.len();
let avg_degree = if total_entities > 0 {
self.entity_to_hyperedges
.values()
.map(|edges| edges.len())
.sum::<usize>() as f32
/ total_entities as f32
} else {
0.0
};
HypergraphStats {
total_entities,
total_hyperedges: total_edges,
avg_entity_degree: avg_degree,
}
}
fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
crate::distance::distance(a, b, self.distance_metric).unwrap_or(f32::MAX)
}
}
/// Hypergraph statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HypergraphStats {
pub total_entities: usize,
pub total_hyperedges: usize,
pub avg_entity_degree: f32,
}
/// Causal hypergraph memory for agent reasoning
pub struct CausalMemory {
/// Hypergraph index
index: HypergraphIndex,
/// Causal relationship tracking: (cause_id, effect_id) -> success_count
causal_counts: HashMap<(VectorId, VectorId), u32>,
/// Action latencies: action_id -> avg_latency_ms
latencies: HashMap<VectorId, f32>,
/// Utility function weights
alpha: f32, // similarity weight
beta: f32, // causal uplift weight
gamma: f32, // latency penalty weight
}
impl CausalMemory {
/// Create a new causal memory with default utility weights
pub fn new(distance_metric: DistanceMetric) -> Self {
Self {
index: HypergraphIndex::new(distance_metric),
causal_counts: HashMap::new(),
latencies: HashMap::new(),
alpha: 0.7,
beta: 0.2,
gamma: 0.1,
}
}
/// Set custom utility function weights
pub fn with_weights(mut self, alpha: f32, beta: f32, gamma: f32) -> Self {
self.alpha = alpha;
self.beta = beta;
self.gamma = gamma;
self
}
/// Add a causal relationship
pub fn add_causal_edge(
&mut self,
cause: VectorId,
effect: VectorId,
context: Vec<VectorId>,
description: String,
embedding: Vec<f32>,
latency_ms: f32,
) -> Result<()> {
// Create hyperedge connecting cause, effect, and context
let mut nodes = vec![cause.clone(), effect.clone()];
nodes.extend(context);
let hyperedge = Hyperedge::new(nodes, description, embedding, 1.0);
self.index.add_hyperedge(hyperedge)?;
// Update causal counts
*self
.causal_counts
.entry((cause.clone(), effect.clone()))
.or_insert(0) += 1;
// Update latency
let entry = self.latencies.entry(cause).or_insert(0.0);
*entry = (*entry + latency_ms) / 2.0; // Running average
Ok(())
}
/// Query with utility function: U = α·similarity + β·causal_uplift - γ·latency
pub fn query_with_utility(
&self,
query_embedding: &[f32],
action_id: VectorId,
k: usize,
) -> Vec<(String, f32)> {
let mut results: Vec<(String, f32)> = self
.index
.hyperedges
.iter()
.filter(|(_, edge)| edge.contains_node(&action_id))
.map(|(id, edge)| {
let similarity = 1.0
- self
.index
.compute_distance(query_embedding, &edge.embedding);
let causal_uplift = self.compute_causal_uplift(&edge.nodes);
let latency = self.latencies.get(&action_id).copied().unwrap_or(0.0);
let utility = self.alpha * similarity + self.beta * causal_uplift
- self.gamma * (latency / 1000.0); // Normalize latency to 0-1 range
(id.clone(), utility)
})
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); // Sort by utility descending
results.truncate(k);
results
}
fn compute_causal_uplift(&self, nodes: &[VectorId]) -> f32 {
if nodes.len() < 2 {
return 0.0;
}
// Compute average causal strength for pairs in this hyperedge
let mut total_uplift = 0.0;
let mut count = 0;
for i in 0..nodes.len() - 1 {
for j in i + 1..nodes.len() {
if let Some(&success_count) = self
.causal_counts
.get(&(nodes[i].clone(), nodes[j].clone()))
{
total_uplift += (success_count as f32).ln_1p(); // Log scale
count += 1;
}
}
}
if count > 0 {
total_uplift / count as f32
} else {
0.0
}
}
/// Get hypergraph index
pub fn index(&self) -> &HypergraphIndex {
&self.index
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hyperedge_creation() {
let nodes = vec!["1".to_string(), "2".to_string(), "3".to_string()];
let desc = "Test relationship".to_string();
let embedding = vec![0.1, 0.2, 0.3];
let edge = Hyperedge::new(nodes, desc, embedding, 0.95);
assert_eq!(edge.order(), 3);
assert!(edge.contains_node(&"1".to_string()));
assert!(!edge.contains_node(&"4".to_string()));
assert_eq!(edge.confidence, 0.95);
}
#[test]
fn test_temporal_hyperedge() {
let nodes = vec!["1".to_string(), "2".to_string()];
let desc = "Temporal relationship".to_string();
let embedding = vec![0.1, 0.2];
let edge = Hyperedge::new(nodes, desc, embedding, 1.0);
let temporal = TemporalHyperedge::new(edge, TemporalGranularity::Hourly);
assert!(!temporal.is_expired());
assert!(temporal.time_bucket() > 0);
}
#[test]
fn test_hypergraph_index() {
let mut index = HypergraphIndex::new(DistanceMetric::Cosine);
// Add entities
index.add_entity("1".to_string(), vec![1.0, 0.0, 0.0]);
index.add_entity("2".to_string(), vec![0.0, 1.0, 0.0]);
index.add_entity("3".to_string(), vec![0.0, 0.0, 1.0]);
// Add hyperedge
let edge = Hyperedge::new(
vec!["1".to_string(), "2".to_string(), "3".to_string()],
"Triple relationship".to_string(),
vec![0.5, 0.5, 0.5],
0.9,
);
index.add_hyperedge(edge).unwrap();
let stats = index.stats();
assert_eq!(stats.total_entities, 3);
assert_eq!(stats.total_hyperedges, 1);
}
#[test]
fn test_k_hop_neighbors() {
let mut index = HypergraphIndex::new(DistanceMetric::Cosine);
// Create a small hypergraph
index.add_entity("1".to_string(), vec![1.0]);
index.add_entity("2".to_string(), vec![1.0]);
index.add_entity("3".to_string(), vec![1.0]);
index.add_entity("4".to_string(), vec![1.0]);
let edge1 = Hyperedge::new(
vec!["1".to_string(), "2".to_string()],
"e1".to_string(),
vec![1.0],
1.0,
);
let edge2 = Hyperedge::new(
vec!["2".to_string(), "3".to_string()],
"e2".to_string(),
vec![1.0],
1.0,
);
let edge3 = Hyperedge::new(
vec!["3".to_string(), "4".to_string()],
"e3".to_string(),
vec![1.0],
1.0,
);
index.add_hyperedge(edge1).unwrap();
index.add_hyperedge(edge2).unwrap();
index.add_hyperedge(edge3).unwrap();
let neighbors = index.k_hop_neighbors("1".to_string(), 2);
assert!(neighbors.contains(&"1".to_string()));
assert!(neighbors.contains(&"2".to_string()));
assert!(neighbors.contains(&"3".to_string()));
}
#[test]
fn test_causal_memory() {
let mut memory = CausalMemory::new(DistanceMetric::Cosine);
memory.index.add_entity("1".to_string(), vec![1.0, 0.0]);
memory.index.add_entity("2".to_string(), vec![0.0, 1.0]);
memory
.add_causal_edge(
"1".to_string(),
"2".to_string(),
vec![],
"Action 1 causes effect 2".to_string(),
vec![0.5, 0.5],
100.0,
)
.unwrap();
let results = memory.query_with_utility(&[0.6, 0.4], "1".to_string(), 5);
assert!(!results.is_empty());
}
}

View File

@@ -0,0 +1,441 @@
//! # Learned Index Structures
//!
//! Experimental learned indexes using neural networks to approximate data distribution.
//! Based on Recursive Model Index (RMI) concept with bounded error correction.
use crate::error::{Result, RuvectorError};
use crate::types::VectorId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Trait for learned index structures
pub trait LearnedIndex {
/// Predict position for a key
fn predict(&self, key: &[f32]) -> Result<usize>;
/// Insert a key-value pair
fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()>;
/// Search for a key
fn search(&self, key: &[f32]) -> Result<Option<VectorId>>;
/// Get index statistics
fn stats(&self) -> IndexStats;
}
/// Statistics for learned indexes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
pub total_entries: usize,
pub model_size_bytes: usize,
pub avg_error: f32,
pub max_error: usize,
}
/// Simple linear model for CDF approximation
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
weights: Vec<f32>,
bias: f32,
}
impl LinearModel {
fn new(dimensions: usize) -> Self {
Self {
weights: vec![0.0; dimensions],
bias: 0.0,
}
}
fn predict(&self, input: &[f32]) -> f32 {
let mut result = self.bias;
for (w, x) in self.weights.iter().zip(input.iter()) {
result += w * x;
}
result.max(0.0)
}
fn train_simple(&mut self, data: &[(Vec<f32>, usize)]) {
if data.is_empty() {
return;
}
// Simple least squares approximation
let n = data.len() as f32;
let dim = self.weights.len();
// Reset weights
self.weights.fill(0.0);
self.bias = 0.0;
// Compute means
let mut mean_x = vec![0.0; dim];
let mut mean_y = 0.0;
for (x, y) in data {
for (i, &val) in x.iter().enumerate() {
mean_x[i] += val;
}
mean_y += *y as f32;
}
for val in mean_x.iter_mut() {
*val /= n;
}
mean_y /= n;
// Simple linear regression for first dimension
if dim > 0 {
let mut numerator = 0.0;
let mut denominator = 0.0;
for (x, y) in data {
let x_diff = x[0] - mean_x[0];
let y_diff = *y as f32 - mean_y;
numerator += x_diff * y_diff;
denominator += x_diff * x_diff;
}
if denominator.abs() > 1e-10 {
self.weights[0] = numerator / denominator;
}
self.bias = mean_y - self.weights[0] * mean_x[0];
}
}
}
/// Recursive Model Index (RMI)
/// Multi-stage neural models making coarse-then-fine predictions
pub struct RecursiveModelIndex {
/// Root model for coarse prediction
root_model: LinearModel,
/// Second-level models for fine prediction
leaf_models: Vec<LinearModel>,
/// Sorted data with error correction
data: Vec<(Vec<f32>, VectorId)>,
/// Error bounds for binary search fallback
max_error: usize,
/// Dimensions of vectors
dimensions: usize,
}
impl RecursiveModelIndex {
/// Create a new RMI with specified number of leaf models
pub fn new(dimensions: usize, num_leaf_models: usize) -> Self {
let leaf_models = (0..num_leaf_models)
.map(|_| LinearModel::new(dimensions))
.collect();
Self {
root_model: LinearModel::new(dimensions),
leaf_models,
data: Vec::new(),
max_error: 100,
dimensions,
}
}
/// Build the index from data
pub fn build(&mut self, mut data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
if data.is_empty() {
return Err(RuvectorError::InvalidInput(
"Cannot build index from empty data".into(),
));
}
if data[0].0.is_empty() {
return Err(RuvectorError::InvalidInput(
"Cannot build index from vectors with zero dimensions".into(),
));
}
if self.leaf_models.is_empty() {
return Err(RuvectorError::InvalidInput(
"Cannot build index with zero leaf models".into(),
));
}
// Sort data by first dimension (simple heuristic)
data.sort_by(|a, b| {
a.0[0]
.partial_cmp(&b.0[0])
.unwrap_or(std::cmp::Ordering::Equal)
});
let n = data.len();
// Train root model to predict leaf model index
let root_training_data: Vec<(Vec<f32>, usize)> = data
.iter()
.enumerate()
.map(|(i, (key, _))| {
let leaf_idx = (i * self.leaf_models.len()) / n;
(key.clone(), leaf_idx)
})
.collect();
self.root_model.train_simple(&root_training_data);
// Train each leaf model
let num_leaf_models = self.leaf_models.len();
let chunk_size = n / num_leaf_models;
for (i, model) in self.leaf_models.iter_mut().enumerate() {
let start = i * chunk_size;
let end = if i == num_leaf_models - 1 {
n
} else {
(i + 1) * chunk_size
};
if start < n {
let leaf_data: Vec<(Vec<f32>, usize)> = data[start..end.min(n)]
.iter()
.enumerate()
.map(|(j, (key, _))| (key.clone(), start + j))
.collect();
model.train_simple(&leaf_data);
}
}
self.data = data;
Ok(())
}
}
impl LearnedIndex for RecursiveModelIndex {
fn predict(&self, key: &[f32]) -> Result<usize> {
if key.len() != self.dimensions {
return Err(RuvectorError::InvalidInput(
"Key dimensions mismatch".into(),
));
}
if self.leaf_models.is_empty() {
return Err(RuvectorError::InvalidInput(
"Index not built: no leaf models available".into(),
));
}
if self.data.is_empty() {
return Err(RuvectorError::InvalidInput(
"Index not built: no data available".into(),
));
}
// Root model predicts leaf model
let leaf_idx = self.root_model.predict(key) as usize;
let leaf_idx = leaf_idx.min(self.leaf_models.len() - 1);
// Leaf model predicts position
let pos = self.leaf_models[leaf_idx].predict(key) as usize;
let pos = pos.min(self.data.len().saturating_sub(1));
Ok(pos)
}
fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
// For simplicity, append and mark for rebuild
// Production implementation would use incremental updates
self.data.push((key, value));
Ok(())
}
fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
if self.data.is_empty() {
return Ok(None);
}
let predicted_pos = self.predict(key)?;
// Binary search around predicted position with error bound
let start = predicted_pos.saturating_sub(self.max_error);
let end = (predicted_pos + self.max_error).min(self.data.len());
for i in start..end {
if self.data[i].0 == key {
return Ok(Some(self.data[i].1.clone()));
}
}
Ok(None)
}
fn stats(&self) -> IndexStats {
let model_size = std::mem::size_of_val(&self.root_model)
+ self.leaf_models.len() * std::mem::size_of::<LinearModel>();
// Compute average prediction error
let mut total_error = 0.0;
let mut max_error = 0;
for (i, (key, _)) in self.data.iter().enumerate() {
if let Ok(pred_pos) = self.predict(key) {
let error = i.abs_diff(pred_pos);
total_error += error as f32;
max_error = max_error.max(error);
}
}
let avg_error = if !self.data.is_empty() {
total_error / self.data.len() as f32
} else {
0.0
};
IndexStats {
total_entries: self.data.len(),
model_size_bytes: model_size,
avg_error,
max_error,
}
}
}
/// Hybrid index combining learned index for static data with HNSW for dynamic updates
pub struct HybridIndex {
/// Learned index for static segment
learned: RecursiveModelIndex,
/// Dynamic updates buffer
dynamic_buffer: HashMap<Vec<u8>, VectorId>,
/// Threshold for rebuilding learned index
rebuild_threshold: usize,
}
impl HybridIndex {
/// Create a new hybrid index
pub fn new(dimensions: usize, num_leaf_models: usize, rebuild_threshold: usize) -> Self {
Self {
learned: RecursiveModelIndex::new(dimensions, num_leaf_models),
dynamic_buffer: HashMap::new(),
rebuild_threshold,
}
}
/// Build the learned portion from static data
pub fn build_static(&mut self, data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
self.learned.build(data)
}
/// Check if rebuild is needed
pub fn needs_rebuild(&self) -> bool {
self.dynamic_buffer.len() >= self.rebuild_threshold
}
/// Rebuild learned index incorporating dynamic updates
pub fn rebuild(&mut self) -> Result<()> {
let mut all_data: Vec<(Vec<f32>, VectorId)> = self.learned.data.clone();
for (key_bytes, value) in &self.dynamic_buffer {
let (key, _): (Vec<f32>, usize) =
bincode::decode_from_slice(key_bytes, bincode::config::standard())
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
all_data.push((key, value.clone()));
}
self.learned.build(all_data)?;
self.dynamic_buffer.clear();
Ok(())
}
fn serialize_key(key: &[f32]) -> Vec<u8> {
bincode::encode_to_vec(key, bincode::config::standard()).unwrap_or_default()
}
}
impl LearnedIndex for HybridIndex {
fn predict(&self, key: &[f32]) -> Result<usize> {
self.learned.predict(key)
}
fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
let key_bytes = Self::serialize_key(&key);
self.dynamic_buffer.insert(key_bytes, value);
Ok(())
}
fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
// Check dynamic buffer first
let key_bytes = Self::serialize_key(key);
if let Some(value) = self.dynamic_buffer.get(&key_bytes) {
return Ok(Some(value.clone()));
}
// Fall back to learned index
self.learned.search(key)
}
fn stats(&self) -> IndexStats {
let mut stats = self.learned.stats();
stats.total_entries += self.dynamic_buffer.len();
stats
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linear_model() {
let mut model = LinearModel::new(2);
let data = vec![
(vec![0.0, 0.0], 0),
(vec![1.0, 1.0], 10),
(vec![2.0, 2.0], 20),
];
model.train_simple(&data);
let pred = model.predict(&[1.5, 1.5]);
assert!(pred >= 0.0 && pred <= 30.0);
}
#[test]
fn test_rmi_build() {
let mut rmi = RecursiveModelIndex::new(2, 4);
let data: Vec<(Vec<f32>, VectorId)> = (0..100)
.map(|i| {
let x = i as f32 / 100.0;
(vec![x, x * x], i.to_string())
})
.collect();
rmi.build(data).unwrap();
let stats = rmi.stats();
assert_eq!(stats.total_entries, 100);
assert!(stats.avg_error < 50.0); // Should have reasonable error
}
#[test]
fn test_rmi_search() {
let mut rmi = RecursiveModelIndex::new(1, 2);
let data = vec![
(vec![0.0], "0".to_string()),
(vec![0.5], "1".to_string()),
(vec![1.0], "2".to_string()),
];
rmi.build(data).unwrap();
let result = rmi.search(&[0.5]).unwrap();
assert_eq!(result, Some("1".to_string()));
}
#[test]
fn test_hybrid_index() {
let mut hybrid = HybridIndex::new(1, 2, 10);
let static_data = vec![(vec![0.0], "0".to_string()), (vec![1.0], "1".to_string())];
hybrid.build_static(static_data).unwrap();
// Add dynamic updates
hybrid.insert(vec![2.0], "2".to_string()).unwrap();
assert_eq!(hybrid.search(&[2.0]).unwrap(), Some("2".to_string()));
assert_eq!(hybrid.search(&[0.0]).unwrap(), Some("0".to_string()));
}
}

View File

@@ -0,0 +1,17 @@
//! # Advanced Techniques
//!
//! This module contains experimental and advanced features for next-generation vector search:
//! - **Hypergraphs**: n-ary relationships beyond pairwise similarity
//! - **Learned Indexes**: Neural network-based index structures
//! - **Neural Hashing**: Similarity-preserving binary projections
//! - **Topological Data Analysis**: Embedding quality assessment
pub mod hypergraph;
pub mod learned_index;
pub mod neural_hash;
pub mod tda;
pub use hypergraph::{CausalMemory, Hyperedge, HypergraphIndex, TemporalHyperedge};
pub use learned_index::{HybridIndex, LearnedIndex, RecursiveModelIndex};
pub use neural_hash::{DeepHashEmbedding, NeuralHash};
pub use tda::{EmbeddingQuality, TopologicalAnalyzer};

View File

@@ -0,0 +1,427 @@
//! # Neural Hash Functions
//!
//! Learn similarity-preserving binary projections for extreme compression.
//! Achieves 32-128x compression with 90-95% recall preservation.
use crate::types::VectorId;
use ndarray::{Array1, Array2};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Neural hash function for similarity-preserving binary codes
pub trait NeuralHash {
/// Encode a vector to binary code
fn encode(&self, vector: &[f32]) -> Vec<u8>;
/// Compute Hamming distance between two codes
fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32;
/// Estimate similarity from Hamming distance
fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32;
}
/// Deep hash embedding with learned projections
#[derive(Clone, Serialize, Deserialize)]
pub struct DeepHashEmbedding {
/// Projection matrices for each layer
projections: Vec<Array2<f32>>,
/// Biases for each layer
biases: Vec<Array1<f32>>,
/// Number of output bits
output_bits: usize,
/// Input dimensions
input_dims: usize,
}
impl DeepHashEmbedding {
/// Create a new deep hash embedding
pub fn new(input_dims: usize, hidden_dims: Vec<usize>, output_bits: usize) -> Self {
let mut rng = rand::thread_rng();
let mut projections = Vec::new();
let mut biases = Vec::new();
let mut layer_dims = vec![input_dims];
layer_dims.extend(&hidden_dims);
layer_dims.push(output_bits);
// Initialize random projections (Xavier initialization)
for i in 0..layer_dims.len() - 1 {
let in_dim = layer_dims[i];
let out_dim = layer_dims[i + 1];
let scale = (2.0 / (in_dim + out_dim) as f32).sqrt();
let proj = Array2::from_shape_fn((out_dim, in_dim), |_| {
rng.gen::<f32>() * 2.0 * scale - scale
});
let bias = Array1::zeros(out_dim);
projections.push(proj);
biases.push(bias);
}
Self {
projections,
biases,
output_bits,
input_dims,
}
}
/// Forward pass through the network
fn forward(&self, input: &[f32]) -> Vec<f32> {
let mut activations = Array1::from_vec(input.to_vec());
for (proj, bias) in self.projections.iter().zip(self.biases.iter()) {
// Linear layer: y = Wx + b
activations = proj.dot(&activations) + bias;
// ReLU activation (except last layer)
if proj.nrows() != self.output_bits {
activations.mapv_inplace(|x| x.max(0.0));
}
}
activations.to_vec()
}
/// Train on pairs of similar/dissimilar examples
pub fn train(
&mut self,
positive_pairs: &[(Vec<f32>, Vec<f32>)],
negative_pairs: &[(Vec<f32>, Vec<f32>)],
learning_rate: f32,
epochs: usize,
) {
// Simplified training with contrastive loss
// Production would use proper backpropagation
for _ in 0..epochs {
// Positive pairs should have small Hamming distance
for (a, b) in positive_pairs {
let code_a = self.encode(a);
let code_b = self.encode(b);
let dist = self.hamming_distance(&code_a, &code_b);
// If distance is too large, update towards similarity
if dist as f32 > self.output_bits as f32 * 0.3 {
self.update_weights(a, b, learning_rate, true);
}
}
// Negative pairs should have large Hamming distance
for (a, b) in negative_pairs {
let code_a = self.encode(a);
let code_b = self.encode(b);
let dist = self.hamming_distance(&code_a, &code_b);
// If distance is too small, update towards dissimilarity
if (dist as f32) < self.output_bits as f32 * 0.6 {
self.update_weights(a, b, learning_rate, false);
}
}
}
}
fn update_weights(&mut self, a: &[f32], b: &[f32], lr: f32, attract: bool) {
// Simplified gradient update (production would use proper autodiff)
let direction = if attract { 1.0 } else { -1.0 };
// Update only the last layer for simplicity
if let Some(last_proj) = self.projections.last_mut() {
let a_arr = Array1::from_vec(a.to_vec());
let b_arr = Array1::from_vec(b.to_vec());
for i in 0..last_proj.nrows() {
for j in 0..last_proj.ncols() {
let grad = direction * lr * (a_arr[j] - b_arr[j]);
last_proj[[i, j]] += grad * 0.001; // Small update
}
}
}
}
/// Get dimensions
pub fn dimensions(&self) -> (usize, usize) {
(self.input_dims, self.output_bits)
}
}
impl NeuralHash for DeepHashEmbedding {
fn encode(&self, vector: &[f32]) -> Vec<u8> {
if vector.len() != self.input_dims {
return vec![0; self.output_bits.div_ceil(8)];
}
let logits = self.forward(vector);
// Threshold at 0 to get binary codes
let mut bits = vec![0u8; self.output_bits.div_ceil(8)];
for (i, &logit) in logits.iter().enumerate() {
if logit > 0.0 {
let byte_idx = i / 8;
let bit_idx = i % 8;
bits[byte_idx] |= 1 << bit_idx;
}
}
bits
}
fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
code_a
.iter()
.zip(code_b.iter())
.map(|(a, b)| (a ^ b).count_ones())
.sum()
}
fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
// Convert Hamming distance to approximate cosine similarity
let normalized_dist = hamming_dist as f32 / code_bits as f32;
1.0 - 2.0 * normalized_dist
}
}
/// Simple LSH (Locality Sensitive Hashing) baseline
#[derive(Clone, Serialize, Deserialize)]
pub struct SimpleLSH {
/// Random projection vectors
projections: Array2<f32>,
/// Number of hash bits
num_bits: usize,
}
impl SimpleLSH {
/// Create a new LSH with random projections
pub fn new(input_dims: usize, num_bits: usize) -> Self {
let mut rng = rand::thread_rng();
// Random Gaussian projections
let projections =
Array2::from_shape_fn((num_bits, input_dims), |_| rng.gen::<f32>() * 2.0 - 1.0);
Self {
projections,
num_bits,
}
}
}
impl NeuralHash for SimpleLSH {
fn encode(&self, vector: &[f32]) -> Vec<u8> {
let input = Array1::from_vec(vector.to_vec());
let projections = self.projections.dot(&input);
let mut bits = vec![0u8; self.num_bits.div_ceil(8)];
for (i, &val) in projections.iter().enumerate() {
if val > 0.0 {
let byte_idx = i / 8;
let bit_idx = i % 8;
bits[byte_idx] |= 1 << bit_idx;
}
}
bits
}
fn hamming_distance(&self, code_a: &[u8], code_b: &[u8]) -> u32 {
code_a
.iter()
.zip(code_b.iter())
.map(|(a, b)| (a ^ b).count_ones())
.sum()
}
fn estimate_similarity(&self, hamming_dist: u32, code_bits: usize) -> f32 {
let normalized_dist = hamming_dist as f32 / code_bits as f32;
1.0 - 2.0 * normalized_dist
}
}
/// Hash index for fast approximate nearest neighbor search
pub struct HashIndex<H: NeuralHash + Clone> {
/// Hash function
hasher: H,
/// Hash tables: binary code -> list of vector IDs
tables: HashMap<Vec<u8>, Vec<VectorId>>,
/// Original vectors for verification
vectors: HashMap<VectorId, Vec<f32>>,
/// Code bits
code_bits: usize,
}
impl<H: NeuralHash + Clone> HashIndex<H> {
/// Create a new hash index
pub fn new(hasher: H, code_bits: usize) -> Self {
Self {
hasher,
tables: HashMap::new(),
vectors: HashMap::new(),
code_bits,
}
}
/// Insert a vector
pub fn insert(&mut self, id: VectorId, vector: Vec<f32>) {
let code = self.hasher.encode(&vector);
self.tables.entry(code).or_default().push(id.clone());
self.vectors.insert(id, vector);
}
/// Search for approximate nearest neighbors
pub fn search(&self, query: &[f32], k: usize, max_hamming: u32) -> Vec<(VectorId, f32)> {
let query_code = self.hasher.encode(query);
let mut candidates = Vec::new();
// Find all vectors within Hamming distance threshold
for (code, ids) in &self.tables {
let hamming = self.hasher.hamming_distance(&query_code, code);
if hamming <= max_hamming {
for id in ids {
if let Some(vec) = self.vectors.get(id) {
let similarity = cosine_similarity(query, vec);
candidates.push((id.clone(), similarity));
}
}
}
}
// Sort by similarity and return top-k
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
candidates.truncate(k);
candidates
}
/// Get compression ratio
pub fn compression_ratio(&self) -> f32 {
if self.vectors.is_empty() {
return 0.0;
}
let original_size: usize = self
.vectors
.values()
.map(|v| v.len() * std::mem::size_of::<f32>())
.sum();
let compressed_size = self.tables.len() * self.code_bits.div_ceil(8);
original_size as f32 / compressed_size as f32
}
/// Get statistics
pub fn stats(&self) -> HashIndexStats {
let buckets = self.tables.len();
let total_vectors = self.vectors.len();
let avg_bucket_size = if buckets > 0 {
total_vectors as f32 / buckets as f32
} else {
0.0
};
HashIndexStats {
total_vectors,
num_buckets: buckets,
avg_bucket_size,
compression_ratio: self.compression_ratio(),
}
}
}
/// Hash index statistics
#[derive(Debug, Clone)]
pub struct HashIndexStats {
pub total_vectors: usize,
pub num_buckets: usize,
pub avg_bucket_size: f32,
pub compression_ratio: f32,
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot / (norm_a * norm_b)
} else {
0.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deep_hash_encoding() {
let hash = DeepHashEmbedding::new(4, vec![8], 16);
let vector = vec![0.1, 0.2, 0.3, 0.4];
let code = hash.encode(&vector);
assert_eq!(code.len(), 2); // 16 bits = 2 bytes
}
#[test]
fn test_hamming_distance() {
let hash = DeepHashEmbedding::new(2, vec![], 8);
let code_a = vec![0b10101010];
let code_b = vec![0b11001100];
let dist = hash.hamming_distance(&code_a, &code_b);
assert_eq!(dist, 4); // 4 bits differ
}
#[test]
fn test_lsh_encoding() {
let lsh = SimpleLSH::new(4, 16);
let vector = vec![1.0, 2.0, 3.0, 4.0];
let code = lsh.encode(&vector);
assert_eq!(code.len(), 2);
// Same vector should produce same code
let code2 = lsh.encode(&vector);
assert_eq!(code, code2);
}
#[test]
fn test_hash_index() {
let lsh = SimpleLSH::new(3, 8);
let mut index = HashIndex::new(lsh, 8);
// Insert vectors
index.insert("0".to_string(), vec![1.0, 0.0, 0.0]);
index.insert("1".to_string(), vec![0.9, 0.1, 0.0]);
index.insert("2".to_string(), vec![0.0, 1.0, 0.0]);
// Search
let results = index.search(&[1.0, 0.0, 0.0], 2, 4);
assert!(!results.is_empty());
let stats = index.stats();
assert_eq!(stats.total_vectors, 3);
}
#[test]
fn test_compression_ratio() {
let lsh = SimpleLSH::new(128, 32); // 128D -> 32 bits
let mut index = HashIndex::new(lsh, 32);
for i in 0..10 {
let vec: Vec<f32> = (0..128).map(|j| (i + j) as f32 / 128.0).collect();
index.insert(i.to_string(), vec);
}
let ratio = index.compression_ratio();
assert!(ratio > 1.0); // Should have compression
}
}

View File

@@ -0,0 +1,496 @@
//! # Topological Data Analysis (TDA)
//!
//! Basic topological analysis for embedding quality assessment.
//! Detects mode collapse, degeneracy, and topological structure.
use crate::error::{Result, RuvectorError};
use ndarray::Array2;
use serde::{Deserialize, Serialize};
/// Topological analyzer for embeddings
pub struct TopologicalAnalyzer {
/// k for k-nearest neighbors graph
k_neighbors: usize,
/// Distance threshold for edge creation
epsilon: f32,
}
impl TopologicalAnalyzer {
/// Create a new topological analyzer
pub fn new(k_neighbors: usize, epsilon: f32) -> Self {
Self {
k_neighbors,
epsilon,
}
}
/// Analyze embedding quality
pub fn analyze(&self, embeddings: &[Vec<f32>]) -> Result<EmbeddingQuality> {
if embeddings.is_empty() {
return Err(RuvectorError::InvalidInput("Empty embeddings".into()));
}
let n = embeddings.len();
let dim = embeddings[0].len();
// Build k-NN graph
let graph = self.build_knn_graph(embeddings);
// Compute topological features
let connected_components = self.count_connected_components(&graph, n);
let clustering_coefficient = self.compute_clustering_coefficient(&graph);
let degree_stats = self.compute_degree_statistics(&graph, n);
// Detect mode collapse
let mode_collapse_score = self.detect_mode_collapse(embeddings);
// Compute embedding spread
let spread = self.compute_spread(embeddings);
// Detect degeneracy (vectors collapsing to a lower-dimensional manifold)
let degeneracy_score = self.detect_degeneracy(embeddings);
// Compute persistence features (simplified)
let persistence_score = self.compute_persistence(&graph, embeddings);
// Overall quality score (0-1, higher is better)
let quality_score = self.compute_quality_score(
mode_collapse_score,
degeneracy_score,
connected_components,
clustering_coefficient,
spread,
);
Ok(EmbeddingQuality {
dimensions: dim,
num_vectors: n,
connected_components,
clustering_coefficient,
avg_degree: degree_stats.0,
degree_std: degree_stats.1,
mode_collapse_score,
degeneracy_score,
spread,
persistence_score,
quality_score,
})
}
fn build_knn_graph(&self, embeddings: &[Vec<f32>]) -> Vec<Vec<usize>> {
let n = embeddings.len();
let mut graph = vec![Vec::new(); n];
for i in 0..n {
let mut distances: Vec<(usize, f32)> = (0..n)
.filter(|&j| i != j)
.map(|j| {
let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
(j, dist)
})
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
// Add k nearest neighbors
for (j, dist) in distances.iter().take(self.k_neighbors) {
if *dist <= self.epsilon {
graph[i].push(*j);
}
}
}
graph
}
fn count_connected_components(&self, graph: &[Vec<usize>], n: usize) -> usize {
let mut visited = vec![false; n];
let mut components = 0;
for i in 0..n {
if !visited[i] {
components += 1;
self.dfs(i, graph, &mut visited);
}
}
components
}
#[allow(clippy::only_used_in_recursion)]
fn dfs(&self, node: usize, graph: &[Vec<usize>], visited: &mut [bool]) {
visited[node] = true;
for &neighbor in &graph[node] {
if !visited[neighbor] {
self.dfs(neighbor, graph, visited);
}
}
}
fn compute_clustering_coefficient(&self, graph: &[Vec<usize>]) -> f32 {
let mut total_coeff = 0.0;
let mut count = 0;
for neighbors in graph {
if neighbors.len() < 2 {
continue;
}
let k = neighbors.len();
let mut triangles = 0;
// Count triangles
for i in 0..k {
for j in i + 1..k {
let ni = neighbors[i];
let nj = neighbors[j];
if graph[ni].contains(&nj) {
triangles += 1;
}
}
}
let possible_triangles = k * (k - 1) / 2;
if possible_triangles > 0 {
total_coeff += triangles as f32 / possible_triangles as f32;
count += 1;
}
}
if count > 0 {
total_coeff / count as f32
} else {
0.0
}
}
fn compute_degree_statistics(&self, graph: &[Vec<usize>], n: usize) -> (f32, f32) {
let degrees: Vec<f32> = graph
.iter()
.map(|neighbors| neighbors.len() as f32)
.collect();
let avg = degrees.iter().sum::<f32>() / n as f32;
let variance = degrees.iter().map(|&d| (d - avg).powi(2)).sum::<f32>() / n as f32;
let std = variance.sqrt();
(avg, std)
}
fn detect_mode_collapse(&self, embeddings: &[Vec<f32>]) -> f32 {
// Compute pairwise distances
let n = embeddings.len();
let mut distances = Vec::new();
for i in 0..n {
for j in i + 1..n {
let dist = euclidean_distance(&embeddings[i], &embeddings[j]);
distances.push(dist);
}
}
if distances.is_empty() {
return 0.0;
}
// Compute coefficient of variation
let mean = distances.iter().sum::<f32>() / distances.len() as f32;
let variance =
distances.iter().map(|&d| (d - mean).powi(2)).sum::<f32>() / distances.len() as f32;
let std = variance.sqrt();
// High CV indicates good separation, low CV indicates collapse
let cv = if mean > 0.0 { std / mean } else { 0.0 };
// Normalize to 0-1, where 0 is collapsed, 1 is good
(cv * 2.0).min(1.0)
}
fn compute_spread(&self, embeddings: &[Vec<f32>]) -> f32 {
if embeddings.is_empty() {
return 0.0;
}
let dim = embeddings[0].len();
// Compute mean
let mut mean = vec![0.0; dim];
for emb in embeddings {
for (i, &val) in emb.iter().enumerate() {
mean[i] += val;
}
}
for val in mean.iter_mut() {
*val /= embeddings.len() as f32;
}
// Compute average distance from mean
let mut total_dist = 0.0;
for emb in embeddings {
let dist = euclidean_distance(emb, &mean);
total_dist += dist;
}
total_dist / embeddings.len() as f32
}
fn detect_degeneracy(&self, embeddings: &[Vec<f32>]) -> f32 {
if embeddings.is_empty() || embeddings[0].is_empty() {
return 1.0; // Fully degenerate
}
let n = embeddings.len();
let dim = embeddings[0].len();
if n < dim {
return 0.0; // Cannot determine
}
// Compute covariance matrix
let cov = self.compute_covariance_matrix(embeddings);
// Estimate rank by counting significant singular values
let singular_values = self.approximate_singular_values(&cov);
let significant = singular_values.iter().filter(|&&sv| sv > 1e-6).count();
// Degeneracy score: 0 = full rank, 1 = rank-1 (collapsed)
1.0 - (significant as f32 / dim as f32)
}
fn compute_covariance_matrix(&self, embeddings: &[Vec<f32>]) -> Array2<f32> {
let n = embeddings.len();
let dim = embeddings[0].len();
// Compute mean
let mut mean = vec![0.0; dim];
for emb in embeddings {
for (i, &val) in emb.iter().enumerate() {
mean[i] += val;
}
}
for val in mean.iter_mut() {
*val /= n as f32;
}
// Compute covariance
let mut cov = Array2::zeros((dim, dim));
for emb in embeddings {
for i in 0..dim {
for j in 0..dim {
cov[[i, j]] += (emb[i] - mean[i]) * (emb[j] - mean[j]);
}
}
}
cov.mapv(|x| x / (n - 1) as f32);
cov
}
fn approximate_singular_values(&self, matrix: &Array2<f32>) -> Vec<f32> {
// Power iteration for largest singular values (simplified)
let dim = matrix.nrows();
let mut values = Vec::new();
// Just return diagonal for approximation
for i in 0..dim {
values.push(matrix[[i, i]].abs());
}
values.sort_by(|a, b| b.partial_cmp(a).unwrap());
values
}
fn compute_persistence(&self, _graph: &[Vec<usize>], embeddings: &[Vec<f32>]) -> f32 {
// Simplified persistence: measure how graph structure changes with distance threshold
let scales = vec![0.1, 0.5, 1.0, 2.0, 5.0];
let mut component_counts = Vec::new();
for &scale in &scales {
let scaled_analyzer = TopologicalAnalyzer::new(self.k_neighbors, scale);
let scaled_graph = scaled_analyzer.build_knn_graph(embeddings);
let components =
scaled_analyzer.count_connected_components(&scaled_graph, embeddings.len());
component_counts.push(components);
}
// Persistence is the variation in component count across scales
let max_components = *component_counts.iter().max().unwrap_or(&1);
let min_components = *component_counts.iter().min().unwrap_or(&1);
(max_components - min_components) as f32 / max_components as f32
}
fn compute_quality_score(
&self,
mode_collapse: f32,
degeneracy: f32,
components: usize,
clustering: f32,
spread: f32,
) -> f32 {
// Weighted combination of metrics
let collapse_score = mode_collapse; // Higher is better
let degeneracy_score = 1.0 - degeneracy; // Lower degeneracy is better
let component_score = if components == 1 { 1.0 } else { 0.5 }; // Single component is good
let clustering_score = clustering; // Higher clustering is good
let spread_score = (spread / 10.0).min(1.0); // Reasonable spread
(collapse_score * 0.3
+ degeneracy_score * 0.3
+ component_score * 0.2
+ clustering_score * 0.1
+ spread_score * 0.1)
.clamp(0.0, 1.0)
}
}
/// Embedding quality metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingQuality {
/// Embedding dimensions
pub dimensions: usize,
/// Number of vectors
pub num_vectors: usize,
/// Number of connected components
pub connected_components: usize,
/// Clustering coefficient (0-1)
pub clustering_coefficient: f32,
/// Average node degree
pub avg_degree: f32,
/// Degree standard deviation
pub degree_std: f32,
/// Mode collapse score (0=collapsed, 1=good)
pub mode_collapse_score: f32,
/// Degeneracy score (0=full rank, 1=degenerate)
pub degeneracy_score: f32,
/// Average spread from centroid
pub spread: f32,
/// Topological persistence score
pub persistence_score: f32,
/// Overall quality (0-1, higher is better)
pub quality_score: f32,
}
impl EmbeddingQuality {
/// Check if embeddings show signs of mode collapse
pub fn has_mode_collapse(&self) -> bool {
self.mode_collapse_score < 0.3
}
/// Check if embeddings are degenerate
pub fn is_degenerate(&self) -> bool {
self.degeneracy_score > 0.7
}
/// Check if embeddings are well-structured
pub fn is_good_quality(&self) -> bool {
self.quality_score > 0.7
}
/// Get quality assessment
pub fn assessment(&self) -> &str {
if self.quality_score > 0.8 {
"Excellent"
} else if self.quality_score > 0.6 {
"Good"
} else if self.quality_score > 0.4 {
"Fair"
} else {
"Poor"
}
}
}
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_embedding_analysis() {
let analyzer = TopologicalAnalyzer::new(3, 5.0);
// Create well-separated embeddings
let embeddings = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![0.2, 0.2],
vec![5.0, 5.0],
vec![5.1, 5.1],
];
let quality = analyzer.analyze(&embeddings).unwrap();
assert_eq!(quality.dimensions, 2);
assert_eq!(quality.num_vectors, 5);
assert!(quality.quality_score > 0.0);
}
#[test]
fn test_mode_collapse_detection() {
let analyzer = TopologicalAnalyzer::new(2, 10.0);
// Well-separated embeddings (high CV should give high score)
let good = vec![vec![0.0, 0.0], vec![5.0, 5.0], vec![10.0, 10.0]];
let score_good = analyzer.detect_mode_collapse(&good);
// Collapsed embeddings (all identical, CV = 0)
let collapsed = vec![vec![1.0, 1.0], vec![1.0, 1.0], vec![1.0, 1.0]];
let score_collapsed = analyzer.detect_mode_collapse(&collapsed);
// Identical vectors should have score 0 (distances all same = CV 0)
assert_eq!(score_collapsed, 0.0);
// Well-separated should have higher score
assert!(score_good > score_collapsed);
}
#[test]
fn test_connected_components() {
let analyzer = TopologicalAnalyzer::new(1, 1.0);
// Two separate clusters
let embeddings = vec![
vec![0.0, 0.0],
vec![0.5, 0.5],
vec![10.0, 10.0],
vec![10.5, 10.5],
];
let graph = analyzer.build_knn_graph(&embeddings);
let components = analyzer.count_connected_components(&graph, embeddings.len());
assert!(components >= 2); // Should have at least 2 components
}
#[test]
fn test_quality_assessment() {
let quality = EmbeddingQuality {
dimensions: 128,
num_vectors: 1000,
connected_components: 1,
clustering_coefficient: 0.6,
avg_degree: 5.0,
degree_std: 1.2,
mode_collapse_score: 0.8,
degeneracy_score: 0.2,
spread: 3.5,
persistence_score: 0.4,
quality_score: 0.75,
};
assert!(!quality.has_mode_collapse());
assert!(!quality.is_degenerate());
assert!(quality.is_good_quality());
assert_eq!(quality.assessment(), "Good");
}
}

View File

@@ -0,0 +1,23 @@
//! Advanced Features for Ruvector
//!
//! This module provides advanced vector database capabilities:
//! - Enhanced Product Quantization with precomputed lookup tables
//! - Filtered Search with automatic strategy selection
//! - MMR (Maximal Marginal Relevance) for diversity
//! - Hybrid Search combining vector and keyword matching
//! - Conformal Prediction for uncertainty quantification
pub mod conformal_prediction;
pub mod filtered_search;
pub mod hybrid_search;
pub mod mmr;
pub mod product_quantization;
// Re-exports
pub use conformal_prediction::{
ConformalConfig, ConformalPredictor, NonconformityMeasure, PredictionSet,
};
pub use filtered_search::{FilterExpression, FilterStrategy, FilteredSearch};
pub use hybrid_search::{HybridConfig, HybridSearch, NormalizationStrategy, BM25};
pub use mmr::{MMRConfig, MMRSearch};
pub use product_quantization::{EnhancedPQ, LookupTable, PQConfig};

View File

@@ -0,0 +1,503 @@
//! Conformal Prediction for Uncertainty Quantification
//!
//! Implements conformal prediction to provide statistically valid uncertainty estimates
//! and prediction sets with guaranteed coverage (1-α).
use crate::error::{Result, RuvectorError};
use crate::types::{SearchResult, VectorId};
use serde::{Deserialize, Serialize};
/// Configuration for conformal prediction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConformalConfig {
/// Significance level (alpha) - typically 0.05 or 0.10
pub alpha: f32,
/// Size of calibration set (as fraction of total data)
pub calibration_fraction: f32,
/// Non-conformity measure type
pub nonconformity_measure: NonconformityMeasure,
}
impl Default for ConformalConfig {
fn default() -> Self {
Self {
alpha: 0.1, // 90% coverage
calibration_fraction: 0.2,
nonconformity_measure: NonconformityMeasure::Distance,
}
}
}
/// Type of non-conformity measure
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NonconformityMeasure {
/// Use distance score as non-conformity
Distance,
/// Use inverse rank as non-conformity
InverseRank,
/// Use normalized distance (distance / avg_distance)
NormalizedDistance,
}
/// Prediction set with conformal guarantees
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionSet {
/// Results in the prediction set
pub results: Vec<SearchResult>,
/// Conformal threshold used
pub threshold: f32,
/// Confidence level (1 - alpha)
pub confidence: f32,
/// Coverage guarantee
pub coverage_guarantee: f32,
}
/// Conformal predictor for vector search
#[derive(Debug, Clone)]
pub struct ConformalPredictor {
/// Configuration
pub config: ConformalConfig,
/// Calibration set: non-conformity scores
pub calibration_scores: Vec<f32>,
/// Conformal threshold (quantile of calibration scores)
pub threshold: Option<f32>,
}
impl ConformalPredictor {
/// Create a new conformal predictor
pub fn new(config: ConformalConfig) -> Result<Self> {
if !(0.0..=1.0).contains(&config.alpha) {
return Err(RuvectorError::InvalidParameter(format!(
"Alpha must be in [0, 1], got {}",
config.alpha
)));
}
if !(0.0..=1.0).contains(&config.calibration_fraction) {
return Err(RuvectorError::InvalidParameter(format!(
"Calibration fraction must be in [0, 1], got {}",
config.calibration_fraction
)));
}
Ok(Self {
config,
calibration_scores: Vec::new(),
threshold: None,
})
}
/// Calibrate on a set of validation examples
///
/// # Arguments
/// * `validation_queries` - Query vectors for calibration
/// * `true_neighbors` - Ground truth neighbors for each query
/// * `search_fn` - Function to perform search
pub fn calibrate<F>(
&mut self,
validation_queries: &[Vec<f32>],
true_neighbors: &[Vec<VectorId>],
search_fn: F,
) -> Result<()>
where
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
{
if validation_queries.len() != true_neighbors.len() {
return Err(RuvectorError::InvalidParameter(
"Number of queries must match number of true neighbor sets".to_string(),
));
}
if validation_queries.is_empty() {
return Err(RuvectorError::InvalidParameter(
"Calibration set cannot be empty".to_string(),
));
}
let mut all_scores = Vec::new();
// Compute non-conformity scores for calibration set
for (query, true_ids) in validation_queries.iter().zip(true_neighbors) {
// Search for neighbors
let results = search_fn(query, 100)?; // Fetch more results
// Compute non-conformity scores for true neighbors
for true_id in true_ids {
let score = self.compute_nonconformity_score(&results, true_id)?;
all_scores.push(score);
}
}
self.calibration_scores = all_scores;
// Compute threshold as (1 - alpha) quantile
self.compute_threshold()?;
Ok(())
}
/// Compute conformal threshold from calibration scores
fn compute_threshold(&mut self) -> Result<()> {
if self.calibration_scores.is_empty() {
return Err(RuvectorError::InvalidParameter(
"No calibration scores available".to_string(),
));
}
let mut sorted_scores = self.calibration_scores.clone();
sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
// Compute (1 - alpha) quantile
let n = sorted_scores.len();
let quantile_index = ((1.0 - self.config.alpha) * (n as f32 + 1.0)).ceil() as usize;
let quantile_index = quantile_index.min(n - 1);
self.threshold = Some(sorted_scores[quantile_index]);
Ok(())
}
/// Compute non-conformity score for a specific result
fn compute_nonconformity_score(
&self,
results: &[SearchResult],
target_id: &VectorId,
) -> Result<f32> {
match self.config.nonconformity_measure {
NonconformityMeasure::Distance => {
// Use distance score directly
results
.iter()
.find(|r| &r.id == target_id)
.map(|r| r.score)
.ok_or_else(|| {
RuvectorError::VectorNotFound(format!(
"Target {} not in results",
target_id
))
})
}
NonconformityMeasure::InverseRank => {
// Use inverse rank: 1 / (rank + 1)
let rank = results
.iter()
.position(|r| &r.id == target_id)
.ok_or_else(|| {
RuvectorError::VectorNotFound(format!(
"Target {} not in results",
target_id
))
})?;
Ok(1.0 / (rank as f32 + 1.0))
}
NonconformityMeasure::NormalizedDistance => {
// Normalize by average distance
let target_score = results
.iter()
.find(|r| &r.id == target_id)
.map(|r| r.score)
.ok_or_else(|| {
RuvectorError::VectorNotFound(format!(
"Target {} not in results",
target_id
))
})?;
// Guard against empty results
if results.is_empty() {
return Ok(target_score);
}
let avg_score = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
Ok(if avg_score > 0.0 {
target_score / avg_score
} else {
target_score
})
}
}
}
/// Make prediction with conformal guarantee
///
/// # Arguments
/// * `query` - Query vector
/// * `search_fn` - Function to perform search
///
/// # Returns
/// Prediction set with coverage guarantee
pub fn predict<F>(&self, query: &[f32], search_fn: F) -> Result<PredictionSet>
where
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
{
let threshold = self.threshold.ok_or_else(|| {
RuvectorError::InvalidParameter("Predictor not calibrated yet".to_string())
})?;
// Perform search with large k
let results = search_fn(query, 1000)?;
// Select results based on non-conformity threshold
let prediction_set: Vec<SearchResult> = match self.config.nonconformity_measure {
NonconformityMeasure::Distance => {
// Include all results with distance <= threshold
results
.into_iter()
.filter(|r| r.score <= threshold)
.collect()
}
NonconformityMeasure::InverseRank => {
// Include top-k results where k is determined by threshold
let k = (1.0 / threshold).ceil() as usize;
results.into_iter().take(k).collect()
}
NonconformityMeasure::NormalizedDistance => {
// Guard against empty results
if results.is_empty() {
return Ok(PredictionSet {
results: vec![],
threshold,
confidence: 1.0 - self.config.alpha,
coverage_guarantee: 1.0 - self.config.alpha,
});
}
let avg_score = results.iter().map(|r| r.score).sum::<f32>() / results.len() as f32;
let adjusted_threshold = threshold * avg_score;
results
.into_iter()
.filter(|r| r.score <= adjusted_threshold)
.collect()
}
};
Ok(PredictionSet {
results: prediction_set,
threshold,
confidence: 1.0 - self.config.alpha,
coverage_guarantee: 1.0 - self.config.alpha,
})
}
/// Compute adaptive top-k based on uncertainty
///
/// # Arguments
/// * `query` - Query vector
/// * `search_fn` - Function to perform search
///
/// # Returns
/// Number of results to return based on uncertainty
pub fn adaptive_top_k<F>(&self, query: &[f32], search_fn: F) -> Result<usize>
where
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
{
let prediction_set = self.predict(query, search_fn)?;
Ok(prediction_set.results.len())
}
/// Get calibration statistics
pub fn get_statistics(&self) -> Option<CalibrationStats> {
if self.calibration_scores.is_empty() {
return None;
}
let n = self.calibration_scores.len() as f32;
let mean = self.calibration_scores.iter().sum::<f32>() / n;
let variance = self
.calibration_scores
.iter()
.map(|&s| (s - mean).powi(2))
.sum::<f32>()
/ n;
let std = variance.sqrt();
let mut sorted = self.calibration_scores.clone();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
Some(CalibrationStats {
num_samples: self.calibration_scores.len(),
mean,
std,
min: sorted.first().copied().unwrap(),
max: sorted.last().copied().unwrap(),
median: sorted[sorted.len() / 2],
threshold: self.threshold,
})
}
}
/// Calibration statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationStats {
/// Number of calibration samples
pub num_samples: usize,
/// Mean non-conformity score
pub mean: f32,
/// Standard deviation
pub std: f32,
/// Minimum score
pub min: f32,
/// Maximum score
pub max: f32,
/// Median score
pub median: f32,
/// Conformal threshold
pub threshold: Option<f32>,
}
#[cfg(test)]
mod tests {
use super::*;
fn create_search_result(id: &str, score: f32) -> SearchResult {
SearchResult {
id: id.to_string(),
score,
vector: Some(vec![0.0; 10]),
metadata: None,
}
}
fn mock_search_fn(_query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
Ok((0..k)
.map(|i| create_search_result(&format!("doc_{}", i), i as f32 * 0.1))
.collect())
}
#[test]
fn test_conformal_config_validation() {
let config = ConformalConfig {
alpha: 0.1,
..Default::default()
};
assert!(ConformalPredictor::new(config).is_ok());
let invalid_config = ConformalConfig {
alpha: 1.5,
..Default::default()
};
assert!(ConformalPredictor::new(invalid_config).is_err());
}
#[test]
fn test_conformal_calibration() {
let config = ConformalConfig::default();
let mut predictor = ConformalPredictor::new(config).unwrap();
// Create calibration data
let queries = vec![vec![1.0; 10], vec![2.0; 10], vec![3.0; 10]];
let true_neighbors = vec![
vec!["doc_0".to_string(), "doc_1".to_string()],
vec!["doc_0".to_string()],
vec!["doc_1".to_string(), "doc_2".to_string()],
];
predictor
.calibrate(&queries, &true_neighbors, mock_search_fn)
.unwrap();
assert!(!predictor.calibration_scores.is_empty());
assert!(predictor.threshold.is_some());
}
#[test]
fn test_conformal_prediction() {
let config = ConformalConfig {
alpha: 0.1,
calibration_fraction: 0.2,
nonconformity_measure: NonconformityMeasure::Distance,
};
let mut predictor = ConformalPredictor::new(config).unwrap();
// Calibrate
let queries = vec![vec![1.0; 10], vec![2.0; 10]];
let true_neighbors = vec![vec!["doc_0".to_string()], vec!["doc_1".to_string()]];
predictor
.calibrate(&queries, &true_neighbors, mock_search_fn)
.unwrap();
// Make prediction
let query = vec![1.5; 10];
let prediction_set = predictor.predict(&query, mock_search_fn).unwrap();
assert!(!prediction_set.results.is_empty());
assert_eq!(prediction_set.confidence, 0.9);
assert!(prediction_set.threshold > 0.0);
}
#[test]
fn test_nonconformity_distance() {
let config = ConformalConfig {
nonconformity_measure: NonconformityMeasure::Distance,
..Default::default()
};
let predictor = ConformalPredictor::new(config).unwrap();
let results = vec![
create_search_result("doc_0", 0.1),
create_search_result("doc_1", 0.3),
create_search_result("doc_2", 0.5),
];
let score = predictor
.compute_nonconformity_score(&results, &"doc_1".to_string())
.unwrap();
assert!((score - 0.3).abs() < 0.01);
}
#[test]
fn test_nonconformity_inverse_rank() {
let config = ConformalConfig {
nonconformity_measure: NonconformityMeasure::InverseRank,
..Default::default()
};
let predictor = ConformalPredictor::new(config).unwrap();
let results = vec![
create_search_result("doc_0", 0.1),
create_search_result("doc_1", 0.3),
create_search_result("doc_2", 0.5),
];
let score = predictor
.compute_nonconformity_score(&results, &"doc_1".to_string())
.unwrap();
assert!((score - 0.5).abs() < 0.01); // 1 / (1 + 1) = 0.5
}
#[test]
fn test_calibration_stats() {
let config = ConformalConfig::default();
let mut predictor = ConformalPredictor::new(config).unwrap();
predictor.calibration_scores = vec![0.1, 0.2, 0.3, 0.4, 0.5];
predictor.threshold = Some(0.4);
let stats = predictor.get_statistics().unwrap();
assert_eq!(stats.num_samples, 5);
assert!((stats.mean - 0.3).abs() < 0.01);
assert!((stats.min - 0.1).abs() < 0.01);
assert!((stats.max - 0.5).abs() < 0.01);
}
#[test]
fn test_adaptive_top_k() {
let config = ConformalConfig::default();
let mut predictor = ConformalPredictor::new(config).unwrap();
// Calibrate
let queries = vec![vec![1.0; 10], vec![2.0; 10]];
let true_neighbors = vec![vec!["doc_0".to_string()], vec!["doc_1".to_string()]];
predictor
.calibrate(&queries, &true_neighbors, mock_search_fn)
.unwrap();
// Test adaptive k
let query = vec![1.5; 10];
let k = predictor.adaptive_top_k(&query, mock_search_fn).unwrap();
assert!(k > 0);
}
}

View File

@@ -0,0 +1,363 @@
//! Filtered Search with Automatic Strategy Selection
//!
//! Supports two filtering strategies:
//! - Pre-filtering: Apply metadata filters before graph traversal
//! - Post-filtering: Traverse graph then apply filters
//! - Automatic strategy selection based on filter selectivity
use crate::error::Result;
use crate::types::{SearchResult, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Filter strategy selection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FilterStrategy {
/// Apply filters before search (efficient for highly selective filters)
PreFilter,
/// Apply filters after search (efficient for low selectivity)
PostFilter,
/// Automatically select strategy based on estimated selectivity
Auto,
}
/// Filter expression for metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterExpression {
/// Equality check: field == value
Eq(String, serde_json::Value),
/// Not equal: field != value
Ne(String, serde_json::Value),
/// Greater than: field > value
Gt(String, serde_json::Value),
/// Greater than or equal: field >= value
Gte(String, serde_json::Value),
/// Less than: field < value
Lt(String, serde_json::Value),
/// Less than or equal: field <= value
Lte(String, serde_json::Value),
/// In list: field in [values]
In(String, Vec<serde_json::Value>),
/// Not in list: field not in [values]
NotIn(String, Vec<serde_json::Value>),
/// Range check: min <= field <= max
Range(String, serde_json::Value, serde_json::Value),
/// Logical AND
And(Vec<FilterExpression>),
/// Logical OR
Or(Vec<FilterExpression>),
/// Logical NOT
Not(Box<FilterExpression>),
}
impl FilterExpression {
/// Evaluate filter against metadata
pub fn evaluate(&self, metadata: &HashMap<String, serde_json::Value>) -> bool {
match self {
FilterExpression::Eq(field, value) => metadata.get(field) == Some(value),
FilterExpression::Ne(field, value) => metadata.get(field) != Some(value),
FilterExpression::Gt(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) > 0
} else {
false
}
}
FilterExpression::Gte(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) >= 0
} else {
false
}
}
FilterExpression::Lt(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) < 0
} else {
false
}
}
FilterExpression::Lte(field, value) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, value) <= 0
} else {
false
}
}
FilterExpression::In(field, values) => {
if let Some(field_value) = metadata.get(field) {
values.contains(field_value)
} else {
false
}
}
FilterExpression::NotIn(field, values) => {
if let Some(field_value) = metadata.get(field) {
!values.contains(field_value)
} else {
true
}
}
FilterExpression::Range(field, min, max) => {
if let Some(field_value) = metadata.get(field) {
compare_values(field_value, min) >= 0 && compare_values(field_value, max) <= 0
} else {
false
}
}
FilterExpression::And(exprs) => exprs.iter().all(|e| e.evaluate(metadata)),
FilterExpression::Or(exprs) => exprs.iter().any(|e| e.evaluate(metadata)),
FilterExpression::Not(expr) => !expr.evaluate(metadata),
}
}
/// Estimate selectivity of filter (0.0 = very selective, 1.0 = not selective)
#[allow(clippy::only_used_in_recursion)]
pub fn estimate_selectivity(&self, total_vectors: usize) -> f32 {
match self {
FilterExpression::Eq(_, _) => 0.1, // Equality is typically selective
FilterExpression::Ne(_, _) => 0.9, // Not equal is less selective
FilterExpression::In(_, values) => (values.len() as f32) / 100.0,
FilterExpression::NotIn(_, values) => 1.0 - (values.len() as f32) / 100.0,
FilterExpression::Range(_, _, _) => 0.3, // Ranges are moderately selective
FilterExpression::Gt(_, _) | FilterExpression::Gte(_, _) => 0.5,
FilterExpression::Lt(_, _) | FilterExpression::Lte(_, _) => 0.5,
FilterExpression::And(exprs) => {
// AND is more selective (multiply selectivities)
exprs
.iter()
.map(|e| e.estimate_selectivity(total_vectors))
.product()
}
FilterExpression::Or(exprs) => {
// OR is less selective (sum selectivities, capped at 1.0)
exprs
.iter()
.map(|e| e.estimate_selectivity(total_vectors))
.sum::<f32>()
.min(1.0)
}
FilterExpression::Not(expr) => 1.0 - expr.estimate_selectivity(total_vectors),
}
}
}
/// Filtered search implementation
#[derive(Debug, Clone)]
pub struct FilteredSearch {
/// Filter expression
pub filter: FilterExpression,
/// Strategy for applying filter
pub strategy: FilterStrategy,
/// Metadata store: id -> metadata
pub metadata_store: HashMap<VectorId, HashMap<String, serde_json::Value>>,
}
impl FilteredSearch {
/// Create a new filtered search instance
pub fn new(
filter: FilterExpression,
strategy: FilterStrategy,
metadata_store: HashMap<VectorId, HashMap<String, serde_json::Value>>,
) -> Self {
Self {
filter,
strategy,
metadata_store,
}
}
/// Automatically select strategy based on filter selectivity
pub fn auto_select_strategy(&self) -> FilterStrategy {
let selectivity = self.filter.estimate_selectivity(self.metadata_store.len());
// If filter is highly selective (< 20%), use pre-filtering
// Otherwise use post-filtering
if selectivity < 0.2 {
FilterStrategy::PreFilter
} else {
FilterStrategy::PostFilter
}
}
/// Get list of vector IDs that pass the filter (for pre-filtering)
pub fn get_filtered_ids(&self) -> Vec<VectorId> {
self.metadata_store
.iter()
.filter(|(_, metadata)| self.filter.evaluate(metadata))
.map(|(id, _)| id.clone())
.collect()
}
/// Apply filter to search results (for post-filtering)
pub fn filter_results(&self, results: Vec<SearchResult>) -> Vec<SearchResult> {
results
.into_iter()
.filter(|result| {
if let Some(metadata) = result.metadata.as_ref() {
self.filter.evaluate(metadata)
} else {
false
}
})
.collect()
}
/// Apply filtered search with automatic strategy selection
pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
where
F: Fn(&[f32], usize, Option<&[VectorId]>) -> Result<Vec<SearchResult>>,
{
let strategy = match self.strategy {
FilterStrategy::Auto => self.auto_select_strategy(),
other => other,
};
match strategy {
FilterStrategy::PreFilter => {
// Get filtered IDs first
let filtered_ids = self.get_filtered_ids();
if filtered_ids.is_empty() {
return Ok(Vec::new());
}
// Search only within filtered IDs
// We may need to fetch more results to get k after filtering
let fetch_k = (k as f32 * 1.5).ceil() as usize;
search_fn(query, fetch_k, Some(&filtered_ids))
}
FilterStrategy::PostFilter => {
// Search first, then filter
// Fetch more results to ensure we get k after filtering
let fetch_k = (k as f32 * 2.0).ceil() as usize;
let results = search_fn(query, fetch_k, None)?;
// Apply filter
let filtered = self.filter_results(results);
// Return top-k
Ok(filtered.into_iter().take(k).collect())
}
FilterStrategy::Auto => unreachable!(),
}
}
}
// Helper function to compare JSON values
fn compare_values(a: &serde_json::Value, b: &serde_json::Value) -> i32 {
use serde_json::Value;
match (a, b) {
(Value::Number(a), Value::Number(b)) => {
let a_f64 = a.as_f64().unwrap_or(0.0);
let b_f64 = b.as_f64().unwrap_or(0.0);
if a_f64 < b_f64 {
-1
} else if a_f64 > b_f64 {
1
} else {
0
}
}
(Value::String(a), Value::String(b)) => a.cmp(b) as i32,
(Value::Bool(a), Value::Bool(b)) => a.cmp(b) as i32,
_ => 0,
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_filter_eq() {
let mut metadata = HashMap::new();
metadata.insert("category".to_string(), json!("electronics"));
let filter = FilterExpression::Eq("category".to_string(), json!("electronics"));
assert!(filter.evaluate(&metadata));
let filter = FilterExpression::Eq("category".to_string(), json!("books"));
assert!(!filter.evaluate(&metadata));
}
#[test]
fn test_filter_range() {
let mut metadata = HashMap::new();
metadata.insert("price".to_string(), json!(50.0));
let filter = FilterExpression::Range("price".to_string(), json!(10.0), json!(100.0));
assert!(filter.evaluate(&metadata));
let filter = FilterExpression::Range("price".to_string(), json!(60.0), json!(100.0));
assert!(!filter.evaluate(&metadata));
}
#[test]
fn test_filter_and() {
let mut metadata = HashMap::new();
metadata.insert("category".to_string(), json!("electronics"));
metadata.insert("price".to_string(), json!(50.0));
let filter = FilterExpression::And(vec![
FilterExpression::Eq("category".to_string(), json!("electronics")),
FilterExpression::Lt("price".to_string(), json!(100.0)),
]);
assert!(filter.evaluate(&metadata));
}
#[test]
fn test_filter_or() {
let mut metadata = HashMap::new();
metadata.insert("category".to_string(), json!("electronics"));
let filter = FilterExpression::Or(vec![
FilterExpression::Eq("category".to_string(), json!("books")),
FilterExpression::Eq("category".to_string(), json!("electronics")),
]);
assert!(filter.evaluate(&metadata));
}
#[test]
fn test_filter_in() {
let mut metadata = HashMap::new();
metadata.insert("tag".to_string(), json!("popular"));
let filter = FilterExpression::In(
"tag".to_string(),
vec![json!("popular"), json!("trending"), json!("new")],
);
assert!(filter.evaluate(&metadata));
}
#[test]
fn test_selectivity_estimation() {
let filter_eq = FilterExpression::Eq("field".to_string(), json!("value"));
assert!(filter_eq.estimate_selectivity(1000) < 0.5);
let filter_ne = FilterExpression::Ne("field".to_string(), json!("value"));
assert!(filter_ne.estimate_selectivity(1000) > 0.5);
}
#[test]
fn test_auto_strategy_selection() {
let mut metadata_store = HashMap::new();
for i in 0..100 {
let mut metadata = HashMap::new();
metadata.insert("id".to_string(), json!(i));
metadata_store.insert(format!("vec_{}", i), metadata);
}
// Highly selective filter should choose pre-filter
let filter = FilterExpression::Eq("id".to_string(), json!(42));
let search = FilteredSearch::new(filter, FilterStrategy::Auto, metadata_store.clone());
assert_eq!(search.auto_select_strategy(), FilterStrategy::PreFilter);
// Less selective filter should choose post-filter
let filter = FilterExpression::Gte("id".to_string(), json!(0));
let search = FilteredSearch::new(filter, FilterStrategy::Auto, metadata_store);
assert_eq!(search.auto_select_strategy(), FilterStrategy::PostFilter);
}
}

View File

@@ -0,0 +1,444 @@
//! Hybrid Search: Combining Vector Similarity and Keyword Matching
//!
//! Implements hybrid search by combining:
//! - Vector similarity search (semantic)
//! - BM25 keyword matching (lexical)
//! - Weighted combination of scores
use crate::error::Result;
use crate::types::{SearchResult, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Configuration for hybrid search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfig {
/// Weight for vector similarity (alpha)
pub vector_weight: f32,
/// Weight for keyword matching (beta)
pub keyword_weight: f32,
/// Normalization strategy
pub normalization: NormalizationStrategy,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
vector_weight: 0.7,
keyword_weight: 0.3,
normalization: NormalizationStrategy::MinMax,
}
}
}
/// Score normalization strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormalizationStrategy {
/// Min-max normalization: (x - min) / (max - min)
MinMax,
/// Z-score normalization: (x - mean) / std
ZScore,
/// No normalization
None,
}
/// Simple BM25 implementation for keyword matching
#[derive(Debug, Clone)]
pub struct BM25 {
/// IDF scores for terms
pub idf: HashMap<String, f32>,
/// Average document length
pub avg_doc_len: f32,
/// Document lengths
pub doc_lengths: HashMap<VectorId, usize>,
/// Inverted index: term -> set of doc IDs
pub inverted_index: HashMap<String, HashSet<VectorId>>,
/// BM25 parameters
pub k1: f32,
pub b: f32,
}
impl BM25 {
/// Create a new BM25 instance
pub fn new(k1: f32, b: f32) -> Self {
Self {
idf: HashMap::new(),
avg_doc_len: 0.0,
doc_lengths: HashMap::new(),
inverted_index: HashMap::new(),
k1,
b,
}
}
/// Index a document
pub fn index_document(&mut self, doc_id: VectorId, text: &str) {
let terms = tokenize(text);
self.doc_lengths.insert(doc_id.clone(), terms.len());
// Update inverted index
for term in terms {
self.inverted_index
.entry(term)
.or_default()
.insert(doc_id.clone());
}
// Update average document length
let total_len: usize = self.doc_lengths.values().sum();
self.avg_doc_len = total_len as f32 / self.doc_lengths.len() as f32;
}
/// Build IDF scores after indexing all documents
pub fn build_idf(&mut self) {
let num_docs = self.doc_lengths.len() as f32;
for (term, doc_set) in &self.inverted_index {
let doc_freq = doc_set.len() as f32;
let idf = ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
self.idf.insert(term.clone(), idf);
}
}
/// Compute BM25 score for a query against a document
pub fn score(&self, query: &str, doc_id: &VectorId, doc_text: &str) -> f32 {
let query_terms = tokenize(query);
let doc_terms = tokenize(doc_text);
let doc_len = self.doc_lengths.get(doc_id).copied().unwrap_or(0) as f32;
// Count term frequencies in document
let mut term_freq: HashMap<String, f32> = HashMap::new();
for term in doc_terms {
*term_freq.entry(term).or_insert(0.0) += 1.0;
}
// Calculate BM25 score
let mut score = 0.0;
for term in query_terms {
let idf = self.idf.get(&term).copied().unwrap_or(0.0);
let tf = term_freq.get(&term).copied().unwrap_or(0.0);
let numerator = tf * (self.k1 + 1.0);
let denominator = tf + self.k1 * (1.0 - self.b + self.b * (doc_len / self.avg_doc_len));
score += idf * (numerator / denominator);
}
score
}
/// Get all documents containing at least one query term
pub fn get_candidate_docs(&self, query: &str) -> HashSet<VectorId> {
let query_terms = tokenize(query);
let mut candidates = HashSet::new();
for term in query_terms {
if let Some(doc_set) = self.inverted_index.get(&term) {
candidates.extend(doc_set.iter().cloned());
}
}
candidates
}
}
/// Hybrid search combining vector and keyword matching
#[derive(Debug, Clone)]
pub struct HybridSearch {
/// Configuration
pub config: HybridConfig,
/// BM25 index for keyword matching
pub bm25: BM25,
/// Document texts for BM25 scoring
pub doc_texts: HashMap<VectorId, String>,
}
impl HybridSearch {
/// Create a new hybrid search instance
pub fn new(config: HybridConfig) -> Self {
Self {
config,
bm25: BM25::new(1.5, 0.75), // Standard BM25 parameters
doc_texts: HashMap::new(),
}
}
/// Index a document with both vector and text
pub fn index_document(&mut self, doc_id: VectorId, text: String) {
self.bm25.index_document(doc_id.clone(), &text);
self.doc_texts.insert(doc_id, text);
}
/// Finalize indexing (build IDF scores)
pub fn finalize_indexing(&mut self) {
self.bm25.build_idf();
}
/// Perform hybrid search
///
/// # Arguments
/// * `query_vector` - Query vector for semantic search
/// * `query_text` - Query text for keyword matching
/// * `k` - Number of results to return
/// * `vector_search_fn` - Function to perform vector similarity search
///
/// # Returns
/// Combined and reranked search results
pub fn search<F>(
&self,
query_vector: &[f32],
query_text: &str,
k: usize,
vector_search_fn: F,
) -> Result<Vec<SearchResult>>
where
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
{
// Get vector similarity results
let vector_results = vector_search_fn(query_vector, k * 2)?;
// Get keyword matching candidates
let keyword_candidates = self.bm25.get_candidate_docs(query_text);
// Compute BM25 scores for all candidates
let mut bm25_scores: HashMap<VectorId, f32> = HashMap::new();
for doc_id in &keyword_candidates {
if let Some(doc_text) = self.doc_texts.get(doc_id) {
let score = self.bm25.score(query_text, doc_id, doc_text);
bm25_scores.insert(doc_id.clone(), score);
}
}
// Combine results
let mut combined_results: HashMap<VectorId, CombinedScore> = HashMap::new();
// Add vector results
for result in vector_results {
combined_results.insert(
result.id.clone(),
CombinedScore {
id: result.id.clone(),
vector_score: Some(result.score),
keyword_score: bm25_scores.get(&result.id).copied(),
vector: result.vector,
metadata: result.metadata,
},
);
}
// Add keyword-only results
for (doc_id, bm25_score) in bm25_scores {
combined_results
.entry(doc_id.clone())
.or_insert(CombinedScore {
id: doc_id,
vector_score: None,
keyword_score: Some(bm25_score),
vector: None,
metadata: None,
});
}
// Normalize and combine scores
let normalized_results =
self.normalize_and_combine(combined_results.into_values().collect())?;
// Sort by combined score (descending)
let mut sorted_results = normalized_results;
sorted_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
// Return top-k
Ok(sorted_results.into_iter().take(k).collect())
}
/// Normalize and combine scores
fn normalize_and_combine(&self, results: Vec<CombinedScore>) -> Result<Vec<SearchResult>> {
let mut vector_scores: Vec<f32> = results.iter().filter_map(|r| r.vector_score).collect();
let mut keyword_scores: Vec<f32> = results.iter().filter_map(|r| r.keyword_score).collect();
// Normalize scores
normalize_scores(&mut vector_scores, self.config.normalization);
normalize_scores(&mut keyword_scores, self.config.normalization);
// Create lookup maps
let mut vector_map: HashMap<VectorId, f32> = HashMap::new();
let mut keyword_map: HashMap<VectorId, f32> = HashMap::new();
for (result, &norm_score) in results.iter().zip(&vector_scores) {
if result.vector_score.is_some() {
vector_map.insert(result.id.clone(), norm_score);
}
}
for (result, &norm_score) in results.iter().zip(&keyword_scores) {
if result.keyword_score.is_some() {
keyword_map.insert(result.id.clone(), norm_score);
}
}
// Combine scores
let combined: Vec<SearchResult> = results
.into_iter()
.map(|result| {
let vector_norm = vector_map.get(&result.id).copied().unwrap_or(0.0);
let keyword_norm = keyword_map.get(&result.id).copied().unwrap_or(0.0);
let combined_score = self.config.vector_weight * vector_norm
+ self.config.keyword_weight * keyword_norm;
SearchResult {
id: result.id,
score: combined_score,
vector: result.vector,
metadata: result.metadata,
}
})
.collect();
Ok(combined)
}
}
/// Combined score holder
#[derive(Debug, Clone)]
struct CombinedScore {
id: VectorId,
vector_score: Option<f32>,
keyword_score: Option<f32>,
vector: Option<Vec<f32>>,
metadata: Option<HashMap<String, serde_json::Value>>,
}
// Helper functions
fn tokenize(text: &str) -> Vec<String> {
text.to_lowercase()
.split_whitespace()
.filter(|s| s.len() > 2) // Remove very short tokens
.map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()).to_string())
.filter(|s| !s.is_empty())
.collect()
}
fn normalize_scores(scores: &mut [f32], strategy: NormalizationStrategy) {
if scores.is_empty() {
return;
}
match strategy {
NormalizationStrategy::MinMax => {
let min = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let range = max - min;
if range > 0.0 {
for score in scores.iter_mut() {
*score = (*score - min) / range;
}
}
}
NormalizationStrategy::ZScore => {
let mean = scores.iter().sum::<f32>() / scores.len() as f32;
let variance =
scores.iter().map(|&s| (s - mean).powi(2)).sum::<f32>() / scores.len() as f32;
let std = variance.sqrt();
if std > 0.0 {
for score in scores.iter_mut() {
*score = (*score - mean) / std;
}
}
}
NormalizationStrategy::None => {}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tokenize() {
let text = "The quick brown fox jumps over the lazy dog!";
let tokens = tokenize(text);
assert!(tokens.contains(&"quick".to_string()));
assert!(tokens.contains(&"brown".to_string()));
assert!(tokens.contains(&"the".to_string())); // "the" is 3 chars, passes > 2 filter
assert!(!tokens.contains(&"a".to_string())); // 1 char, too short
}
#[test]
fn test_bm25_indexing() {
let mut bm25 = BM25::new(1.5, 0.75);
bm25.index_document("doc1".to_string(), "rust programming language");
bm25.index_document("doc2".to_string(), "python programming tutorial");
bm25.build_idf();
assert_eq!(bm25.doc_lengths.len(), 2);
assert!(bm25.idf.contains_key("rust"));
assert!(bm25.idf.contains_key("programming"));
}
#[test]
fn test_bm25_scoring() {
let mut bm25 = BM25::new(1.5, 0.75);
bm25.index_document("doc1".to_string(), "rust programming language");
bm25.index_document("doc2".to_string(), "python programming tutorial");
bm25.index_document("doc3".to_string(), "rust systems programming");
bm25.build_idf();
let score1 = bm25.score(
"rust programming",
&"doc1".to_string(),
"rust programming language",
);
let score2 = bm25.score(
"rust programming",
&"doc2".to_string(),
"python programming tutorial",
);
// doc1 should score higher (contains both terms)
assert!(score1 > score2);
}
#[test]
fn test_hybrid_search_initialization() {
let config = HybridConfig::default();
let mut hybrid = HybridSearch::new(config);
hybrid.index_document("doc1".to_string(), "rust vector database".to_string());
hybrid.index_document("doc2".to_string(), "python machine learning".to_string());
hybrid.finalize_indexing();
assert_eq!(hybrid.doc_texts.len(), 2);
assert_eq!(hybrid.bm25.doc_lengths.len(), 2);
}
#[test]
fn test_normalize_minmax() {
let mut scores = vec![1.0, 2.0, 3.0, 4.0, 5.0];
normalize_scores(&mut scores, NormalizationStrategy::MinMax);
assert!((scores[0] - 0.0).abs() < 0.01);
assert!((scores[4] - 1.0).abs() < 0.01);
assert!((scores[2] - 0.5).abs() < 0.01);
}
#[test]
fn test_bm25_candidate_retrieval() {
let mut bm25 = BM25::new(1.5, 0.75);
bm25.index_document("doc1".to_string(), "rust programming");
bm25.index_document("doc2".to_string(), "python programming");
bm25.index_document("doc3".to_string(), "java development");
bm25.build_idf();
let candidates = bm25.get_candidate_docs("rust programming");
assert!(candidates.contains(&"doc1".to_string()));
assert!(candidates.contains(&"doc2".to_string())); // Contains "programming"
assert!(!candidates.contains(&"doc3".to_string()));
}
}

View File

@@ -0,0 +1,336 @@
//! Maximal Marginal Relevance (MMR) for Diversity-Aware Search
//!
//! Implements MMR algorithm to balance relevance and diversity in search results:
//! MMR = λ × Similarity(query, doc) - (1-λ) × max Similarity(doc, selected_docs)
use crate::error::{Result, RuvectorError};
use crate::types::{DistanceMetric, SearchResult};
use serde::{Deserialize, Serialize};
/// Configuration for MMR search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MMRConfig {
/// Lambda parameter: balance between relevance (1.0) and diversity (0.0)
/// - λ = 1.0: Pure relevance (standard similarity search)
/// - λ = 0.5: Equal balance
/// - λ = 0.0: Pure diversity
pub lambda: f32,
/// Distance metric for similarity computation
pub metric: DistanceMetric,
/// Fetch multiplier for initial candidates (fetch k * multiplier results)
pub fetch_multiplier: f32,
}
impl Default for MMRConfig {
fn default() -> Self {
Self {
lambda: 0.5,
metric: DistanceMetric::Cosine,
fetch_multiplier: 2.0,
}
}
}
/// MMR search implementation
#[derive(Debug, Clone)]
pub struct MMRSearch {
/// Configuration
pub config: MMRConfig,
}
impl MMRSearch {
/// Create a new MMR search instance
pub fn new(config: MMRConfig) -> Result<Self> {
if !(0.0..=1.0).contains(&config.lambda) {
return Err(RuvectorError::InvalidParameter(format!(
"Lambda must be in [0, 1], got {}",
config.lambda
)));
}
Ok(Self { config })
}
/// Perform MMR-based reranking of search results
///
/// # Arguments
/// * `query` - Query vector
/// * `candidates` - Initial search results (sorted by relevance)
/// * `k` - Number of diverse results to return
///
/// # Returns
/// Reranked results optimizing for both relevance and diversity
pub fn rerank(
&self,
query: &[f32],
candidates: Vec<SearchResult>,
k: usize,
) -> Result<Vec<SearchResult>> {
if candidates.is_empty() {
return Ok(Vec::new());
}
if k == 0 {
return Ok(Vec::new());
}
if k >= candidates.len() {
return Ok(candidates);
}
let mut selected: Vec<SearchResult> = Vec::with_capacity(k);
let mut remaining = candidates;
// Iteratively select documents maximizing MMR
for _ in 0..k {
if remaining.is_empty() {
break;
}
// Compute MMR score for each remaining candidate
let mut best_idx = 0;
let mut best_mmr = f32::NEG_INFINITY;
for (idx, candidate) in remaining.iter().enumerate() {
let mmr_score = self.compute_mmr_score(query, candidate, &selected)?;
if mmr_score > best_mmr {
best_mmr = mmr_score;
best_idx = idx;
}
}
// Move best candidate to selected set
let best = remaining.remove(best_idx);
selected.push(best);
}
Ok(selected)
}
/// Compute MMR score for a candidate
fn compute_mmr_score(
&self,
_query: &[f32],
candidate: &SearchResult,
selected: &[SearchResult],
) -> Result<f32> {
let candidate_vec = candidate.vector.as_ref().ok_or_else(|| {
RuvectorError::InvalidParameter("Candidate vector not available".to_string())
})?;
// Relevance: similarity to query (convert distance to similarity)
let relevance = self.distance_to_similarity(candidate.score);
// Diversity: max similarity to already selected documents
let max_similarity = if selected.is_empty() {
0.0
} else {
selected
.iter()
.filter_map(|s| s.vector.as_ref())
.map(|selected_vec| {
let dist = compute_distance(candidate_vec, selected_vec, self.config.metric);
self.distance_to_similarity(dist)
})
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap_or(0.0)
};
// MMR = λ × relevance - (1-λ) × max_similarity
let mmr = self.config.lambda * relevance - (1.0 - self.config.lambda) * max_similarity;
Ok(mmr)
}
/// Convert distance to similarity (higher is better)
fn distance_to_similarity(&self, distance: f32) -> f32 {
match self.config.metric {
DistanceMetric::Cosine => 1.0 - distance,
DistanceMetric::Euclidean => 1.0 / (1.0 + distance),
DistanceMetric::Manhattan => 1.0 / (1.0 + distance),
DistanceMetric::DotProduct => -distance, // Dot product is already similarity-like
}
}
/// Perform end-to-end MMR search
///
/// # Arguments
/// * `query` - Query vector
/// * `k` - Number of diverse results to return
/// * `search_fn` - Function to perform initial similarity search
///
/// # Returns
/// Diverse search results
pub fn search<F>(&self, query: &[f32], k: usize, search_fn: F) -> Result<Vec<SearchResult>>
where
F: Fn(&[f32], usize) -> Result<Vec<SearchResult>>,
{
// Fetch more candidates than needed
let fetch_k = (k as f32 * self.config.fetch_multiplier).ceil() as usize;
let candidates = search_fn(query, fetch_k)?;
// Rerank using MMR
self.rerank(query, candidates, k)
}
}
// Helper function
fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
match metric {
DistanceMetric::Euclidean => euclidean_distance(a, b),
DistanceMetric::Cosine => cosine_distance(a, b),
DistanceMetric::Manhattan => manhattan_distance(a, b),
DistanceMetric::DotProduct => dot_product_distance(a, b),
}
}
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b)
.map(|(x, y)| {
let diff = x - y;
diff * diff
})
.sum::<f32>()
.sqrt()
}
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
1.0
} else {
1.0 - (dot / (norm_a * norm_b))
}
}
fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum()
}
fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
-dot
}
#[cfg(test)]
mod tests {
use super::*;
fn create_search_result(id: &str, score: f32, vector: Vec<f32>) -> SearchResult {
SearchResult {
id: id.to_string(),
score,
vector: Some(vector),
metadata: None,
}
}
#[test]
fn test_mmr_config_validation() {
let config = MMRConfig {
lambda: 0.5,
..Default::default()
};
assert!(MMRSearch::new(config).is_ok());
let invalid_config = MMRConfig {
lambda: 1.5,
..Default::default()
};
assert!(MMRSearch::new(invalid_config).is_err());
}
#[test]
fn test_mmr_reranking() {
let config = MMRConfig {
lambda: 0.5,
metric: DistanceMetric::Euclidean,
fetch_multiplier: 2.0,
};
let mmr = MMRSearch::new(config).unwrap();
let query = vec![1.0, 0.0, 0.0];
// Create candidates with varying similarity
let candidates = vec![
create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]), // Very similar to query
create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // Similar to doc1 and query
create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.5]), // Different from doc1
create_search_result("doc4", 0.6, vec![0.0, 1.0, 0.0]), // Very different
];
let results = mmr.rerank(&query, candidates, 3).unwrap();
assert_eq!(results.len(), 3);
// First result should be most relevant
assert_eq!(results[0].id, "doc1");
// MMR should promote diversity, so doc3 or doc4 should appear
assert!(results.iter().any(|r| r.id == "doc3" || r.id == "doc4"));
}
#[test]
fn test_mmr_pure_relevance() {
let config = MMRConfig {
lambda: 1.0, // Pure relevance
metric: DistanceMetric::Euclidean,
fetch_multiplier: 2.0,
};
let mmr = MMRSearch::new(config).unwrap();
let query = vec![1.0, 0.0, 0.0];
let candidates = vec![
create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
create_search_result("doc2", 0.15, vec![0.85, 0.1, 0.05]),
create_search_result("doc3", 0.5, vec![0.5, 0.5, 0.0]),
];
let results = mmr.rerank(&query, candidates, 2).unwrap();
// With lambda=1.0, should just preserve relevance order
assert_eq!(results[0].id, "doc1");
assert_eq!(results[1].id, "doc2");
}
#[test]
fn test_mmr_pure_diversity() {
let config = MMRConfig {
lambda: 0.0, // Pure diversity
metric: DistanceMetric::Euclidean,
fetch_multiplier: 2.0,
};
let mmr = MMRSearch::new(config).unwrap();
let query = vec![1.0, 0.0, 0.0];
let candidates = vec![
create_search_result("doc1", 0.1, vec![0.9, 0.1, 0.0]),
create_search_result("doc2", 0.15, vec![0.9, 0.0, 0.1]), // Very similar to doc1
create_search_result("doc3", 0.5, vec![0.0, 1.0, 0.0]), // Very different
];
let results = mmr.rerank(&query, candidates, 2).unwrap();
// With lambda=0.0, should maximize diversity
assert_eq!(results.len(), 2);
// Should not select both doc1 and doc2 (they're too similar)
let has_both_similar =
results.iter().any(|r| r.id == "doc1") && results.iter().any(|r| r.id == "doc2");
assert!(!has_both_similar);
}
#[test]
fn test_mmr_empty_candidates() {
let config = MMRConfig::default();
let mmr = MMRSearch::new(config).unwrap();
let query = vec![1.0, 0.0, 0.0];
let results = mmr.rerank(&query, Vec::new(), 5).unwrap();
assert!(results.is_empty());
}
}

View File

@@ -0,0 +1,549 @@
//! Enhanced Product Quantization with Precomputed Lookup Tables
//!
//! Provides 8-16x compression with 90-95% recall through:
//! - K-means clustering for codebook training
//! - Precomputed lookup tables for fast distance calculation
//! - Asymmetric distance computation (ADC)
use crate::error::{Result, RuvectorError};
use crate::types::DistanceMetric;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Configuration for Enhanced Product Quantization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PQConfig {
/// Number of subspaces to split vector into
pub num_subspaces: usize,
/// Codebook size per subspace (typically 256)
pub codebook_size: usize,
/// Number of k-means iterations for training
pub num_iterations: usize,
/// Distance metric for codebook training
pub metric: DistanceMetric,
}
impl Default for PQConfig {
fn default() -> Self {
Self {
num_subspaces: 8,
codebook_size: 256,
num_iterations: 20,
metric: DistanceMetric::Euclidean,
}
}
}
impl PQConfig {
/// Validate the configuration
pub fn validate(&self) -> Result<()> {
if self.codebook_size > 256 {
return Err(RuvectorError::InvalidParameter(format!(
"Codebook size {} exceeds u8 maximum of 256",
self.codebook_size
)));
}
if self.num_subspaces == 0 {
return Err(RuvectorError::InvalidParameter(
"Number of subspaces must be greater than 0".to_string(),
));
}
Ok(())
}
}
/// Precomputed lookup table for fast distance computation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LookupTable {
/// Table: [subspace][centroid] -> distance to query subvector
pub tables: Vec<Vec<f32>>,
}
impl LookupTable {
/// Create a new lookup table for a query vector
pub fn new(query: &[f32], codebooks: &[Vec<Vec<f32>>], metric: DistanceMetric) -> Self {
let num_subspaces = codebooks.len();
let mut tables = Vec::with_capacity(num_subspaces);
for (subspace_idx, codebook) in codebooks.iter().enumerate() {
let subspace_dim = query.len() / num_subspaces;
let start = subspace_idx * subspace_dim;
let end = start + subspace_dim;
let query_subvector = &query[start..end];
// Compute distance from query subvector to each centroid
let distances: Vec<f32> = codebook
.iter()
.map(|centroid| compute_distance(query_subvector, centroid, metric))
.collect();
tables.push(distances);
}
Self { tables }
}
/// Compute distance to a quantized vector using the lookup table
#[inline]
pub fn distance(&self, codes: &[u8]) -> f32 {
codes
.iter()
.enumerate()
.map(|(subspace_idx, &code)| self.tables[subspace_idx][code as usize])
.sum()
}
}
/// Enhanced Product Quantization with lookup tables
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedPQ {
/// Configuration
pub config: PQConfig,
/// Trained codebooks: [subspace][centroid_id][dimensions]
pub codebooks: Vec<Vec<Vec<f32>>>,
/// Dimensions of original vectors
pub dimensions: usize,
/// Quantized vectors storage: id -> codes
pub quantized_vectors: HashMap<String, Vec<u8>>,
}
impl EnhancedPQ {
/// Create a new Enhanced PQ instance
pub fn new(dimensions: usize, config: PQConfig) -> Result<Self> {
config.validate()?;
if dimensions == 0 {
return Err(RuvectorError::InvalidParameter(
"Dimensions must be greater than 0".to_string(),
));
}
if dimensions % config.num_subspaces != 0 {
return Err(RuvectorError::InvalidParameter(format!(
"Dimensions {} must be divisible by num_subspaces {}",
dimensions, config.num_subspaces
)));
}
Ok(Self {
config,
codebooks: Vec::new(),
dimensions,
quantized_vectors: HashMap::new(),
})
}
/// Train codebooks on a set of vectors using k-means clustering
pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
if training_vectors.is_empty() {
return Err(RuvectorError::InvalidParameter(
"Training set cannot be empty".to_string(),
));
}
if training_vectors[0].is_empty() {
return Err(RuvectorError::InvalidParameter(
"Training vectors cannot have zero dimensions".to_string(),
));
}
// Validate dimensions
for vec in training_vectors {
if vec.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: vec.len(),
});
}
}
let subspace_dim = self.dimensions / self.config.num_subspaces;
let mut codebooks = Vec::with_capacity(self.config.num_subspaces);
// Train a codebook for each subspace
for subspace_idx in 0..self.config.num_subspaces {
let start = subspace_idx * subspace_dim;
let end = start + subspace_dim;
// Extract subspace vectors
let subspace_vectors: Vec<Vec<f32>> = training_vectors
.iter()
.map(|v| v[start..end].to_vec())
.collect();
// Run k-means clustering
let codebook = kmeans_clustering(
&subspace_vectors,
self.config.codebook_size,
self.config.num_iterations,
self.config.metric,
)?;
codebooks.push(codebook);
}
self.codebooks = codebooks;
Ok(())
}
/// Encode a vector into PQ codes
pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
if vector.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: vector.len(),
});
}
if self.codebooks.is_empty() {
return Err(RuvectorError::InvalidParameter(
"Codebooks not trained yet".to_string(),
));
}
let subspace_dim = self.dimensions / self.config.num_subspaces;
let mut codes = Vec::with_capacity(self.config.num_subspaces);
for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
let start = subspace_idx * subspace_dim;
let end = start + subspace_dim;
let subvector = &vector[start..end];
// Find nearest centroid (quantization)
let code = find_nearest_centroid(subvector, codebook, self.config.metric)?;
codes.push(code);
}
Ok(codes)
}
/// Add a quantized vector
pub fn add_quantized(&mut self, id: String, vector: &[f32]) -> Result<()> {
let codes = self.encode(vector)?;
self.quantized_vectors.insert(id, codes);
Ok(())
}
/// Create a lookup table for fast distance computation
pub fn create_lookup_table(&self, query: &[f32]) -> Result<LookupTable> {
if query.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: query.len(),
});
}
if self.codebooks.is_empty() {
return Err(RuvectorError::InvalidParameter(
"Codebooks not trained yet".to_string(),
));
}
Ok(LookupTable::new(query, &self.codebooks, self.config.metric))
}
/// Search for nearest neighbors using ADC (Asymmetric Distance Computation)
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
let lookup_table = self.create_lookup_table(query)?;
// Compute distances using lookup table
let mut distances: Vec<(String, f32)> = self
.quantized_vectors
.iter()
.map(|(id, codes)| (id.clone(), lookup_table.distance(codes)))
.collect();
// Sort by distance (ascending)
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
// Return top-k
Ok(distances.into_iter().take(k).collect())
}
/// Reconstruct approximate vector from codes
pub fn reconstruct(&self, codes: &[u8]) -> Result<Vec<f32>> {
if codes.len() != self.config.num_subspaces {
return Err(RuvectorError::InvalidParameter(format!(
"Expected {} codes, got {}",
self.config.num_subspaces,
codes.len()
)));
}
let mut result = Vec::with_capacity(self.dimensions);
for (subspace_idx, &code) in codes.iter().enumerate() {
let centroid = &self.codebooks[subspace_idx][code as usize];
result.extend_from_slice(centroid);
}
Ok(result)
}
/// Get compression ratio
pub fn compression_ratio(&self) -> f32 {
let original_bytes = self.dimensions * 4; // f32 = 4 bytes
let compressed_bytes = self.config.num_subspaces; // 1 byte per subspace
original_bytes as f32 / compressed_bytes as f32
}
}
// Helper functions
fn compute_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
match metric {
DistanceMetric::Euclidean => euclidean_squared(a, b).sqrt(),
DistanceMetric::Cosine => {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
1.0
} else {
1.0 - (dot / (norm_a * norm_b))
}
}
DistanceMetric::DotProduct => {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
-dot // Negative for minimization
}
DistanceMetric::Manhattan => a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum(),
}
}
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b)
.map(|(x, y)| {
let diff = x - y;
diff * diff
})
.sum()
}
fn find_nearest_centroid(
vector: &[f32],
codebook: &[Vec<f32>],
metric: DistanceMetric,
) -> Result<u8> {
codebook
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let dist_a = compute_distance(vector, a, metric);
let dist_b = compute_distance(vector, b, metric);
dist_a.partial_cmp(&dist_b).unwrap()
})
.map(|(idx, _)| idx as u8)
.ok_or_else(|| RuvectorError::Internal("Empty codebook".to_string()))
}
fn kmeans_clustering(
vectors: &[Vec<f32>],
k: usize,
iterations: usize,
metric: DistanceMetric,
) -> Result<Vec<Vec<f32>>> {
use rand::seq::SliceRandom;
use rand::thread_rng;
if vectors.is_empty() {
return Err(RuvectorError::InvalidParameter(
"Cannot cluster empty vector set".to_string(),
));
}
if vectors[0].is_empty() {
return Err(RuvectorError::InvalidParameter(
"Cannot cluster vectors with zero dimensions".to_string(),
));
}
if k > vectors.len() {
return Err(RuvectorError::InvalidParameter(format!(
"k ({}) cannot be larger than number of vectors ({})",
k,
vectors.len()
)));
}
if k > 256 {
return Err(RuvectorError::InvalidParameter(format!(
"k ({}) exceeds u8 maximum of 256 for codebook size",
k
)));
}
let mut rng = thread_rng();
let dim = vectors[0].len();
// Initialize centroids using k-means++
let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(k);
centroids.push(vectors.choose(&mut rng).unwrap().clone());
while centroids.len() < k {
let distances: Vec<f32> = vectors
.iter()
.map(|v| {
centroids
.iter()
.map(|c| compute_distance(v, c, metric))
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap_or(f32::MAX)
})
.collect();
let total: f32 = distances.iter().sum();
let mut rand_val = rand::random::<f32>() * total;
for (i, &dist) in distances.iter().enumerate() {
rand_val -= dist;
if rand_val <= 0.0 {
centroids.push(vectors[i].clone());
break;
}
}
// Fallback if we didn't select anything
if centroids.len() < k && centroids.len() == centroids.len() {
centroids.push(vectors.choose(&mut rng).unwrap().clone());
}
}
// Lloyd's algorithm
for _ in 0..iterations {
let mut assignments: Vec<Vec<Vec<f32>>> = vec![Vec::new(); k];
// Assignment step
for vector in vectors {
let nearest = centroids
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let dist_a = compute_distance(vector, a, metric);
let dist_b = compute_distance(vector, b, metric);
dist_a.partial_cmp(&dist_b).unwrap()
})
.map(|(idx, _)| idx)
.unwrap_or(0);
assignments[nearest].push(vector.clone());
}
// Update step
for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
if !assigned.is_empty() {
*centroid = vec![0.0; dim];
for vector in assigned {
for (i, &v) in vector.iter().enumerate() {
centroid[i] += v;
}
}
let count = assigned.len() as f32;
for v in centroid.iter_mut() {
*v /= count;
}
}
}
}
Ok(centroids)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pq_config_default() {
let config = PQConfig::default();
assert_eq!(config.num_subspaces, 8);
assert_eq!(config.codebook_size, 256);
}
#[test]
fn test_enhanced_pq_creation() {
let config = PQConfig {
num_subspaces: 4,
codebook_size: 16,
num_iterations: 10,
metric: DistanceMetric::Euclidean,
};
let pq = EnhancedPQ::new(128, config).unwrap();
assert_eq!(pq.dimensions, 128);
assert_eq!(pq.config.num_subspaces, 4);
}
#[test]
fn test_pq_training_and_encoding() {
let config = PQConfig {
num_subspaces: 2,
codebook_size: 4,
num_iterations: 5,
metric: DistanceMetric::Euclidean,
};
let mut pq = EnhancedPQ::new(4, config).unwrap();
// Generate training data
let training_data = vec![
vec![1.0, 2.0, 3.0, 4.0],
vec![2.0, 3.0, 4.0, 5.0],
vec![3.0, 4.0, 5.0, 6.0],
vec![4.0, 5.0, 6.0, 7.0],
vec![5.0, 6.0, 7.0, 8.0],
];
pq.train(&training_data).unwrap();
assert_eq!(pq.codebooks.len(), 2);
// Test encoding
let vector = vec![2.5, 3.5, 4.5, 5.5];
let codes = pq.encode(&vector).unwrap();
assert_eq!(codes.len(), 2);
}
#[test]
fn test_lookup_table_creation() {
let config = PQConfig {
num_subspaces: 2,
codebook_size: 4,
num_iterations: 5,
metric: DistanceMetric::Euclidean,
};
let mut pq = EnhancedPQ::new(4, config).unwrap();
let training_data = vec![
vec![1.0, 2.0, 3.0, 4.0],
vec![2.0, 3.0, 4.0, 5.0],
vec![3.0, 4.0, 5.0, 6.0],
vec![4.0, 5.0, 6.0, 7.0],
];
pq.train(&training_data).unwrap();
let query = vec![2.5, 3.5, 4.5, 5.5];
let lookup_table = pq.create_lookup_table(&query).unwrap();
assert_eq!(lookup_table.tables.len(), 2);
assert_eq!(lookup_table.tables[0].len(), 4);
}
#[test]
fn test_compression_ratio() {
let config = PQConfig {
num_subspaces: 8,
codebook_size: 256,
num_iterations: 10,
metric: DistanceMetric::Euclidean,
};
let pq = EnhancedPQ::new(128, config).unwrap();
let ratio = pq.compression_ratio();
assert_eq!(ratio, 64.0); // 128 * 4 / 8 = 64
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,704 @@
//! Arena allocator for batch operations
//!
//! This module provides arena-based memory allocation to reduce allocation
//! overhead in hot paths and improve memory locality.
//!
//! ## Features (ADR-001)
//!
//! - **Cache-aligned allocations**: All allocations are aligned to cache line boundaries (64 bytes)
//! - **Bump allocation**: O(1) allocation with minimal overhead
//! - **Batch deallocation**: Free all allocations at once via `reset()`
//! - **Thread-local arenas**: Per-thread allocation without synchronization
use std::alloc::{alloc, dealloc, Layout};
use std::cell::RefCell;
use std::ptr;
/// Cache line size (typically 64 bytes on modern CPUs)
pub const CACHE_LINE_SIZE: usize = 64;
/// Arena allocator for temporary allocations
///
/// Use this for batch operations where many temporary allocations
/// are needed and can be freed all at once.
pub struct Arena {
chunks: RefCell<Vec<Chunk>>,
chunk_size: usize,
}
struct Chunk {
data: *mut u8,
capacity: usize,
used: usize,
}
impl Arena {
/// Create a new arena with the specified chunk size
pub fn new(chunk_size: usize) -> Self {
Self {
chunks: RefCell::new(Vec::new()),
chunk_size,
}
}
/// Create an arena with a default 1MB chunk size
pub fn with_default_chunk_size() -> Self {
Self::new(1024 * 1024) // 1MB
}
/// Allocate a buffer of the specified size
pub fn alloc_vec<T>(&self, count: usize) -> ArenaVec<T> {
let size = count * std::mem::size_of::<T>();
let align = std::mem::align_of::<T>();
let ptr = self.alloc_raw(size, align);
ArenaVec {
ptr: ptr as *mut T,
len: 0,
capacity: count,
_phantom: std::marker::PhantomData,
}
}
/// Allocate raw bytes with specified alignment
fn alloc_raw(&self, size: usize, align: usize) -> *mut u8 {
// SECURITY: Validate alignment is a power of 2 and size is reasonable
assert!(
align > 0 && align.is_power_of_two(),
"Alignment must be a power of 2"
);
assert!(size > 0, "Cannot allocate zero bytes");
assert!(size <= isize::MAX as usize, "Allocation size too large");
let mut chunks = self.chunks.borrow_mut();
// Try to allocate from the last chunk
if let Some(chunk) = chunks.last_mut() {
// Align the current position
let current = chunk.used;
let aligned = (current + align - 1) & !(align - 1);
// SECURITY: Check for overflow in alignment calculation
if aligned < current {
panic!("Alignment calculation overflow");
}
let needed = aligned
.checked_add(size)
.expect("Arena allocation size overflow");
if needed <= chunk.capacity {
chunk.used = needed;
return unsafe {
// SECURITY: Verify pointer arithmetic doesn't overflow
let ptr = chunk.data.add(aligned);
debug_assert!(ptr as usize >= chunk.data as usize, "Pointer underflow");
ptr
};
}
}
// Need a new chunk
let chunk_size = self.chunk_size.max(size + align);
let layout = Layout::from_size_align(chunk_size, 64).unwrap();
let data = unsafe { alloc(layout) };
let aligned = align;
let chunk = Chunk {
data,
capacity: chunk_size,
used: aligned + size,
};
let ptr = unsafe { data.add(aligned) };
chunks.push(chunk);
ptr
}
/// Reset the arena, allowing reuse of allocated memory
pub fn reset(&self) {
let mut chunks = self.chunks.borrow_mut();
for chunk in chunks.iter_mut() {
chunk.used = 0;
}
}
/// Get total allocated bytes
pub fn allocated_bytes(&self) -> usize {
let chunks = self.chunks.borrow();
chunks.iter().map(|c| c.capacity).sum()
}
/// Get used bytes
pub fn used_bytes(&self) -> usize {
let chunks = self.chunks.borrow();
chunks.iter().map(|c| c.used).sum()
}
}
impl Drop for Arena {
fn drop(&mut self) {
let chunks = self.chunks.borrow();
for chunk in chunks.iter() {
let layout = Layout::from_size_align(chunk.capacity, 64).unwrap();
unsafe {
dealloc(chunk.data, layout);
}
}
}
}
/// Vector allocated from an arena
pub struct ArenaVec<T> {
ptr: *mut T,
len: usize,
capacity: usize,
_phantom: std::marker::PhantomData<T>,
}
impl<T> ArenaVec<T> {
/// Push an element (panics if capacity exceeded)
pub fn push(&mut self, value: T) {
// SECURITY: Bounds check before pointer arithmetic
assert!(self.len < self.capacity, "ArenaVec capacity exceeded");
assert!(!self.ptr.is_null(), "ArenaVec pointer is null");
unsafe {
// Additional safety: verify the pointer offset is within bounds
let offset_ptr = self.ptr.add(self.len);
debug_assert!(
offset_ptr as usize >= self.ptr as usize,
"Pointer arithmetic overflow"
);
ptr::write(offset_ptr, value);
}
self.len += 1;
}
/// Get length
pub fn len(&self) -> usize {
self.len
}
/// Check if empty
pub fn is_empty(&self) -> bool {
self.len == 0
}
/// Get capacity
pub fn capacity(&self) -> usize {
self.capacity
}
/// Get as slice
pub fn as_slice(&self) -> &[T] {
// SECURITY: Bounds check before creating slice
assert!(self.len <= self.capacity, "Length exceeds capacity");
assert!(!self.ptr.is_null(), "Cannot create slice from null pointer");
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
/// Get as mutable slice
pub fn as_mut_slice(&mut self) -> &mut [T] {
// SECURITY: Bounds check before creating slice
assert!(self.len <= self.capacity, "Length exceeds capacity");
assert!(!self.ptr.is_null(), "Cannot create slice from null pointer");
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
}
}
impl<T> std::ops::Deref for ArenaVec<T> {
type Target = [T];
fn deref(&self) -> &[T] {
self.as_slice()
}
}
impl<T> std::ops::DerefMut for ArenaVec<T> {
fn deref_mut(&mut self) -> &mut [T] {
self.as_mut_slice()
}
}
// Thread-local arena for per-thread allocations
thread_local! {
static THREAD_ARENA: RefCell<Arena> = RefCell::new(Arena::with_default_chunk_size());
}
// Get the thread-local arena
// Note: Commented out due to lifetime issues with RefCell::borrow() escaping closure
// Use THREAD_ARENA.with(|arena| { ... }) directly instead
/*
pub fn thread_arena() -> impl std::ops::Deref<Target = Arena> {
THREAD_ARENA.with(|arena| {
arena.borrow()
})
}
*/
/// Cache-aligned vector storage for SIMD operations (ADR-001)
///
/// Ensures vectors are aligned to cache line boundaries (64 bytes) for
/// optimal SIMD operations and minimal cache misses.
#[repr(C, align(64))]
pub struct CacheAlignedVec {
data: *mut f32,
len: usize,
capacity: usize,
}
impl CacheAlignedVec {
/// Create a new cache-aligned vector with the given capacity
///
/// # Panics
///
/// Panics if memory allocation fails. For fallible allocation,
/// use `try_with_capacity`.
pub fn with_capacity(capacity: usize) -> Self {
Self::try_with_capacity(capacity).expect("Failed to allocate cache-aligned memory")
}
/// Try to create a new cache-aligned vector with the given capacity
///
/// Returns `None` if memory allocation fails.
pub fn try_with_capacity(capacity: usize) -> Option<Self> {
// Handle zero capacity case
if capacity == 0 {
return Some(Self {
data: std::ptr::null_mut(),
len: 0,
capacity: 0,
});
}
// Allocate cache-line aligned memory
let layout =
Layout::from_size_align(capacity * std::mem::size_of::<f32>(), CACHE_LINE_SIZE).ok()?;
let data = unsafe { alloc(layout) as *mut f32 };
// SECURITY: Check for allocation failure
if data.is_null() {
return None;
}
Some(Self {
data,
len: 0,
capacity,
})
}
/// Create from an existing slice, copying data to cache-aligned storage
///
/// # Panics
///
/// Panics if memory allocation fails. For fallible allocation,
/// use `try_from_slice`.
pub fn from_slice(slice: &[f32]) -> Self {
Self::try_from_slice(slice).expect("Failed to allocate cache-aligned memory for slice")
}
/// Try to create from an existing slice, copying data to cache-aligned storage
///
/// Returns `None` if memory allocation fails.
pub fn try_from_slice(slice: &[f32]) -> Option<Self> {
let mut vec = Self::try_with_capacity(slice.len())?;
if !slice.is_empty() {
unsafe {
ptr::copy_nonoverlapping(slice.as_ptr(), vec.data, slice.len());
}
}
vec.len = slice.len();
Some(vec)
}
/// Push an element
///
/// # Panics
///
/// Panics if capacity is exceeded or if the vector has zero capacity.
pub fn push(&mut self, value: f32) {
assert!(
self.len < self.capacity,
"CacheAlignedVec capacity exceeded"
);
assert!(
!self.data.is_null(),
"Cannot push to zero-capacity CacheAlignedVec"
);
unsafe {
*self.data.add(self.len) = value;
}
self.len += 1;
}
/// Get length
#[inline]
pub fn len(&self) -> usize {
self.len
}
/// Check if empty
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
/// Get capacity
#[inline]
pub fn capacity(&self) -> usize {
self.capacity
}
/// Get as slice
#[inline]
pub fn as_slice(&self) -> &[f32] {
if self.len == 0 {
// SAFETY: Empty slice doesn't require valid pointer
return &[];
}
// SAFETY: data is valid for len elements when len > 0
unsafe { std::slice::from_raw_parts(self.data, self.len) }
}
/// Get as mutable slice
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [f32] {
if self.len == 0 {
// SAFETY: Empty slice doesn't require valid pointer
return &mut [];
}
// SAFETY: data is valid for len elements when len > 0
unsafe { std::slice::from_raw_parts_mut(self.data, self.len) }
}
/// Get raw pointer (for SIMD operations)
#[inline]
pub fn as_ptr(&self) -> *const f32 {
self.data
}
/// Get mutable raw pointer (for SIMD operations)
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut f32 {
self.data
}
/// Check if properly aligned for SIMD
///
/// Returns `true` for zero-capacity vectors (considered trivially aligned).
#[inline]
pub fn is_aligned(&self) -> bool {
if self.data.is_null() {
// Zero-capacity vectors are considered aligned
return self.capacity == 0;
}
(self.data as usize) % CACHE_LINE_SIZE == 0
}
/// Clear the vector (sets len to 0, doesn't deallocate)
pub fn clear(&mut self) {
self.len = 0;
}
}
impl Drop for CacheAlignedVec {
fn drop(&mut self) {
if !self.data.is_null() && self.capacity > 0 {
let layout = Layout::from_size_align(
self.capacity * std::mem::size_of::<f32>(),
CACHE_LINE_SIZE,
)
.expect("Invalid layout");
unsafe {
dealloc(self.data as *mut u8, layout);
}
}
}
}
impl std::ops::Deref for CacheAlignedVec {
type Target = [f32];
fn deref(&self) -> &[f32] {
self.as_slice()
}
}
impl std::ops::DerefMut for CacheAlignedVec {
fn deref_mut(&mut self) -> &mut [f32] {
self.as_mut_slice()
}
}
// Safety: The raw pointer is owned and not shared
unsafe impl Send for CacheAlignedVec {}
unsafe impl Sync for CacheAlignedVec {}
/// Batch vector allocator for processing multiple vectors (ADR-001)
///
/// Allocates contiguous, cache-aligned storage for a batch of vectors,
/// enabling efficient SIMD processing and minimal cache misses.
pub struct BatchVectorAllocator {
data: *mut f32,
dimensions: usize,
capacity: usize,
count: usize,
}
impl BatchVectorAllocator {
/// Create allocator for vectors of given dimensions
///
/// # Panics
///
/// Panics if memory allocation fails. For fallible allocation,
/// use `try_new`.
pub fn new(dimensions: usize, initial_capacity: usize) -> Self {
Self::try_new(dimensions, initial_capacity)
.expect("Failed to allocate batch vector storage")
}
/// Try to create allocator for vectors of given dimensions
///
/// Returns `None` if memory allocation fails.
pub fn try_new(dimensions: usize, initial_capacity: usize) -> Option<Self> {
// Handle zero capacity case
if dimensions == 0 || initial_capacity == 0 {
return Some(Self {
data: std::ptr::null_mut(),
dimensions,
capacity: initial_capacity,
count: 0,
});
}
let total_floats = dimensions * initial_capacity;
let layout =
Layout::from_size_align(total_floats * std::mem::size_of::<f32>(), CACHE_LINE_SIZE)
.ok()?;
let data = unsafe { alloc(layout) as *mut f32 };
// SECURITY: Check for allocation failure
if data.is_null() {
return None;
}
Some(Self {
data,
dimensions,
capacity: initial_capacity,
count: 0,
})
}
/// Add a vector, returns its index
///
/// # Panics
///
/// Panics if the allocator is full, dimensions mismatch, or allocator has zero capacity.
pub fn add(&mut self, vector: &[f32]) -> usize {
assert_eq!(vector.len(), self.dimensions, "Vector dimension mismatch");
assert!(self.count < self.capacity, "Batch allocator full");
assert!(
!self.data.is_null(),
"Cannot add to zero-capacity BatchVectorAllocator"
);
let offset = self.count * self.dimensions;
unsafe {
ptr::copy_nonoverlapping(vector.as_ptr(), self.data.add(offset), self.dimensions);
}
let index = self.count;
self.count += 1;
index
}
/// Get a vector by index
pub fn get(&self, index: usize) -> &[f32] {
assert!(index < self.count, "Index out of bounds");
let offset = index * self.dimensions;
unsafe { std::slice::from_raw_parts(self.data.add(offset), self.dimensions) }
}
/// Get mutable vector by index
pub fn get_mut(&mut self, index: usize) -> &mut [f32] {
assert!(index < self.count, "Index out of bounds");
let offset = index * self.dimensions;
unsafe { std::slice::from_raw_parts_mut(self.data.add(offset), self.dimensions) }
}
/// Get raw pointer to vector at index (for SIMD)
#[inline]
pub fn ptr_at(&self, index: usize) -> *const f32 {
assert!(index < self.count, "Index out of bounds");
let offset = index * self.dimensions;
unsafe { self.data.add(offset) }
}
/// Number of vectors stored
#[inline]
pub fn len(&self) -> usize {
self.count
}
/// Check if empty
#[inline]
pub fn is_empty(&self) -> bool {
self.count == 0
}
/// Dimensions per vector
#[inline]
pub fn dimensions(&self) -> usize {
self.dimensions
}
/// Reset allocator (keeps memory)
pub fn clear(&mut self) {
self.count = 0;
}
}
impl Drop for BatchVectorAllocator {
fn drop(&mut self) {
if !self.data.is_null() {
let layout = Layout::from_size_align(
self.dimensions * self.capacity * std::mem::size_of::<f32>(),
CACHE_LINE_SIZE,
)
.expect("Invalid layout");
unsafe {
dealloc(self.data as *mut u8, layout);
}
}
}
}
// Safety: The raw pointer is owned and not shared
unsafe impl Send for BatchVectorAllocator {}
unsafe impl Sync for BatchVectorAllocator {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_arena_alloc() {
let arena = Arena::new(1024);
let mut vec1 = arena.alloc_vec::<f32>(10);
vec1.push(1.0);
vec1.push(2.0);
vec1.push(3.0);
assert_eq!(vec1.len(), 3);
assert_eq!(vec1[0], 1.0);
assert_eq!(vec1[1], 2.0);
assert_eq!(vec1[2], 3.0);
}
#[test]
fn test_arena_multiple_allocs() {
let arena = Arena::new(1024);
let vec1 = arena.alloc_vec::<u32>(100);
let vec2 = arena.alloc_vec::<u64>(50);
let vec3 = arena.alloc_vec::<f32>(200);
assert_eq!(vec1.capacity(), 100);
assert_eq!(vec2.capacity(), 50);
assert_eq!(vec3.capacity(), 200);
}
#[test]
fn test_arena_reset() {
let arena = Arena::new(1024);
{
let _vec1 = arena.alloc_vec::<f32>(100);
let _vec2 = arena.alloc_vec::<f32>(100);
}
let used_before = arena.used_bytes();
arena.reset();
let used_after = arena.used_bytes();
assert!(used_after < used_before);
}
#[test]
fn test_cache_aligned_vec() {
let mut vec = CacheAlignedVec::with_capacity(100);
// Check alignment
assert!(vec.is_aligned(), "Vector should be cache-aligned");
// Test push
for i in 0..50 {
vec.push(i as f32);
}
assert_eq!(vec.len(), 50);
// Test slice access
let slice = vec.as_slice();
assert_eq!(slice[0], 0.0);
assert_eq!(slice[49], 49.0);
}
#[test]
fn test_cache_aligned_vec_from_slice() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let aligned = CacheAlignedVec::from_slice(&data);
assert!(aligned.is_aligned());
assert_eq!(aligned.len(), 5);
assert_eq!(aligned.as_slice(), &data[..]);
}
#[test]
fn test_batch_vector_allocator() {
let mut allocator = BatchVectorAllocator::new(4, 10);
let v1 = vec![1.0, 2.0, 3.0, 4.0];
let v2 = vec![5.0, 6.0, 7.0, 8.0];
let idx1 = allocator.add(&v1);
let idx2 = allocator.add(&v2);
assert_eq!(idx1, 0);
assert_eq!(idx2, 1);
assert_eq!(allocator.len(), 2);
// Test retrieval
assert_eq!(allocator.get(0), &v1[..]);
assert_eq!(allocator.get(1), &v2[..]);
}
#[test]
fn test_batch_allocator_clear() {
let mut allocator = BatchVectorAllocator::new(3, 5);
allocator.add(&[1.0, 2.0, 3.0]);
allocator.add(&[4.0, 5.0, 6.0]);
assert_eq!(allocator.len(), 2);
allocator.clear();
assert_eq!(allocator.len(), 0);
// Should be able to add again
allocator.add(&[7.0, 8.0, 9.0]);
assert_eq!(allocator.len(), 1);
}
}

View File

@@ -0,0 +1,436 @@
//! Cache-optimized data structures using Structure-of-Arrays (SoA) layout
//!
//! This module provides cache-friendly layouts for vector storage to minimize
//! cache misses and improve memory access patterns.
use std::alloc::{alloc, dealloc, Layout};
use std::ptr;
/// Cache line size (typically 64 bytes on modern CPUs)
const CACHE_LINE_SIZE: usize = 64;
/// Structure-of-Arrays layout for vectors
///
/// Instead of storing vectors as Vec<Vec<f32>>, we store all components
/// separately to improve cache locality during SIMD operations.
#[repr(align(64))] // Align to cache line boundary
pub struct SoAVectorStorage {
/// Number of vectors
count: usize,
/// Dimensions per vector
dimensions: usize,
/// Capacity (allocated vectors)
capacity: usize,
/// Storage for each dimension separately
/// Layout: [dim0_vec0, dim0_vec1, ..., dim0_vecN, dim1_vec0, ...]
data: *mut f32,
}
impl SoAVectorStorage {
/// Maximum allowed dimensions to prevent overflow
const MAX_DIMENSIONS: usize = 65536;
/// Maximum allowed capacity to prevent overflow
const MAX_CAPACITY: usize = 1 << 24; // ~16M vectors
/// Create a new SoA vector storage
///
/// # Panics
/// Panics if dimensions or capacity exceed safe limits or would cause overflow.
pub fn new(dimensions: usize, initial_capacity: usize) -> Self {
// Security: Validate inputs to prevent integer overflow
assert!(
dimensions > 0 && dimensions <= Self::MAX_DIMENSIONS,
"dimensions must be between 1 and {}",
Self::MAX_DIMENSIONS
);
assert!(
initial_capacity <= Self::MAX_CAPACITY,
"initial_capacity exceeds maximum of {}",
Self::MAX_CAPACITY
);
let capacity = initial_capacity.next_power_of_two();
// Security: Use checked arithmetic to prevent overflow
let total_elements = dimensions
.checked_mul(capacity)
.expect("dimensions * capacity overflow");
let total_bytes = total_elements
.checked_mul(std::mem::size_of::<f32>())
.expect("total size overflow");
let layout =
Layout::from_size_align(total_bytes, CACHE_LINE_SIZE).expect("invalid memory layout");
let data = unsafe { alloc(layout) as *mut f32 };
// Zero initialize
unsafe {
ptr::write_bytes(data, 0, total_elements);
}
Self {
count: 0,
dimensions,
capacity,
data,
}
}
/// Add a vector to the storage
pub fn push(&mut self, vector: &[f32]) {
assert_eq!(vector.len(), self.dimensions);
if self.count >= self.capacity {
self.grow();
}
// Store each dimension separately
for (dim_idx, &value) in vector.iter().enumerate() {
let offset = dim_idx * self.capacity + self.count;
unsafe {
*self.data.add(offset) = value;
}
}
self.count += 1;
}
/// Get a vector by index (copies to output buffer)
pub fn get(&self, index: usize, output: &mut [f32]) {
assert!(index < self.count);
assert_eq!(output.len(), self.dimensions);
for (dim_idx, out) in output.iter_mut().enumerate().take(self.dimensions) {
let offset = dim_idx * self.capacity + index;
*out = unsafe { *self.data.add(offset) };
}
}
/// Get a slice of a specific dimension across all vectors
/// This allows efficient SIMD operations on a single dimension
pub fn dimension_slice(&self, dim_idx: usize) -> &[f32] {
assert!(dim_idx < self.dimensions);
let offset = dim_idx * self.capacity;
unsafe { std::slice::from_raw_parts(self.data.add(offset), self.count) }
}
/// Get a mutable slice of a specific dimension
pub fn dimension_slice_mut(&mut self, dim_idx: usize) -> &mut [f32] {
assert!(dim_idx < self.dimensions);
let offset = dim_idx * self.capacity;
unsafe { std::slice::from_raw_parts_mut(self.data.add(offset), self.count) }
}
/// Number of vectors stored
pub fn len(&self) -> usize {
self.count
}
/// Check if empty
pub fn is_empty(&self) -> bool {
self.count == 0
}
/// Dimensions per vector
pub fn dimensions(&self) -> usize {
self.dimensions
}
/// Grow the storage capacity
fn grow(&mut self) {
let new_capacity = self.capacity * 2;
// Security: Use checked arithmetic to prevent overflow
let new_total_elements = self
.dimensions
.checked_mul(new_capacity)
.expect("dimensions * new_capacity overflow");
let new_total_bytes = new_total_elements
.checked_mul(std::mem::size_of::<f32>())
.expect("total size overflow in grow");
let new_layout = Layout::from_size_align(new_total_bytes, CACHE_LINE_SIZE)
.expect("invalid memory layout in grow");
let new_data = unsafe { alloc(new_layout) as *mut f32 };
// Copy old data dimension by dimension
for dim_idx in 0..self.dimensions {
let old_offset = dim_idx * self.capacity;
let new_offset = dim_idx * new_capacity;
unsafe {
ptr::copy_nonoverlapping(
self.data.add(old_offset),
new_data.add(new_offset),
self.count,
);
}
}
// Deallocate old data
let old_layout = Layout::from_size_align(
self.dimensions * self.capacity * std::mem::size_of::<f32>(),
CACHE_LINE_SIZE,
)
.unwrap();
unsafe {
dealloc(self.data as *mut u8, old_layout);
}
self.data = new_data;
self.capacity = new_capacity;
}
/// Compute distance from query to all stored vectors using dimension-wise operations
/// This takes advantage of the SoA layout for better cache utilization
#[inline(always)]
pub fn batch_euclidean_distances(&self, query: &[f32], output: &mut [f32]) {
assert_eq!(query.len(), self.dimensions);
assert_eq!(output.len(), self.count);
// Use SIMD-optimized version for larger batches
#[cfg(target_arch = "aarch64")]
{
if self.count >= 16 {
unsafe { self.batch_euclidean_distances_neon(query, output) };
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if self.count >= 32 && is_x86_feature_detected!("avx2") {
unsafe { self.batch_euclidean_distances_avx2(query, output) };
return;
}
}
// Scalar fallback
self.batch_euclidean_distances_scalar(query, output);
}
/// Scalar implementation of batch euclidean distances
#[inline(always)]
fn batch_euclidean_distances_scalar(&self, query: &[f32], output: &mut [f32]) {
// Initialize output with zeros
output.fill(0.0);
// Process dimension by dimension for cache-friendly access
for dim_idx in 0..self.dimensions {
let dim_slice = self.dimension_slice(dim_idx);
// Safety: dim_idx is bounded by self.dimensions which is validated in constructor
let query_val = unsafe { *query.get_unchecked(dim_idx) };
// Compute squared differences for this dimension
// Use unchecked access since vec_idx is bounded by self.count
for vec_idx in 0..self.count {
let diff = unsafe { *dim_slice.get_unchecked(vec_idx) } - query_val;
unsafe { *output.get_unchecked_mut(vec_idx) += diff * diff };
}
}
// Take square root
for distance in output.iter_mut() {
*distance = distance.sqrt();
}
}
/// NEON-optimized batch euclidean distances
///
/// # Safety
/// Caller must ensure query.len() == self.dimensions and output.len() == self.count
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn batch_euclidean_distances_neon(&self, query: &[f32], output: &mut [f32]) {
use std::arch::aarch64::*;
let out_ptr = output.as_mut_ptr();
let query_ptr = query.as_ptr();
// Initialize output with zeros
let chunks = self.count / 4;
// Zero initialize using SIMD
let zero = vdupq_n_f32(0.0);
for i in 0..chunks {
let idx = i * 4;
vst1q_f32(out_ptr.add(idx), zero);
}
for i in (chunks * 4)..self.count {
*output.get_unchecked_mut(i) = 0.0;
}
// Process dimension by dimension for cache-friendly access
for dim_idx in 0..self.dimensions {
let dim_slice = self.dimension_slice(dim_idx);
let dim_ptr = dim_slice.as_ptr();
let query_val = vdupq_n_f32(*query_ptr.add(dim_idx));
// SIMD processing of 4 vectors at a time
for i in 0..chunks {
let idx = i * 4;
let dim_vals = vld1q_f32(dim_ptr.add(idx));
let out_vals = vld1q_f32(out_ptr.add(idx));
let diff = vsubq_f32(dim_vals, query_val);
let result = vfmaq_f32(out_vals, diff, diff);
vst1q_f32(out_ptr.add(idx), result);
}
// Handle remainder with bounds-check elimination
let query_val_scalar = *query_ptr.add(dim_idx);
for i in (chunks * 4)..self.count {
let diff = *dim_slice.get_unchecked(i) - query_val_scalar;
*output.get_unchecked_mut(i) += diff * diff;
}
}
// Take square root using SIMD vsqrtq_f32
for i in 0..chunks {
let idx = i * 4;
let vals = vld1q_f32(out_ptr.add(idx));
let sqrt_vals = vsqrtq_f32(vals);
vst1q_f32(out_ptr.add(idx), sqrt_vals);
}
for i in (chunks * 4)..self.count {
*output.get_unchecked_mut(i) = output.get_unchecked(i).sqrt();
}
}
/// AVX2-optimized batch euclidean distances
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn batch_euclidean_distances_avx2(&self, query: &[f32], output: &mut [f32]) {
use std::arch::x86_64::*;
let chunks = self.count / 8;
// Zero initialize using SIMD
let zero = _mm256_setzero_ps();
for i in 0..chunks {
let idx = i * 8;
_mm256_storeu_ps(output.as_mut_ptr().add(idx), zero);
}
for out in output.iter_mut().take(self.count).skip(chunks * 8) {
*out = 0.0;
}
// Process dimension by dimension
for (dim_idx, &q_val) in query.iter().enumerate().take(self.dimensions) {
let dim_slice = self.dimension_slice(dim_idx);
let query_val = _mm256_set1_ps(q_val);
// SIMD processing of 8 vectors at a time
for i in 0..chunks {
let idx = i * 8;
let dim_vals = _mm256_loadu_ps(dim_slice.as_ptr().add(idx));
let out_vals = _mm256_loadu_ps(output.as_ptr().add(idx));
let diff = _mm256_sub_ps(dim_vals, query_val);
let sq = _mm256_mul_ps(diff, diff);
let result = _mm256_add_ps(out_vals, sq);
_mm256_storeu_ps(output.as_mut_ptr().add(idx), result);
}
// Handle remainder
for i in (chunks * 8)..self.count {
let diff = dim_slice[i] - query[dim_idx];
output[i] += diff * diff;
}
}
// Take square root (no SIMD sqrt in basic AVX2, use scalar)
for distance in output.iter_mut() {
*distance = distance.sqrt();
}
}
}
// Feature detection helper for x86_64
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn is_x86_feature_detected_helper(feature: &str) -> bool {
match feature {
"avx2" => is_x86_feature_detected!("avx2"),
_ => false,
}
}
impl Drop for SoAVectorStorage {
fn drop(&mut self) {
let layout = Layout::from_size_align(
self.dimensions * self.capacity * std::mem::size_of::<f32>(),
CACHE_LINE_SIZE,
)
.unwrap();
unsafe {
dealloc(self.data as *mut u8, layout);
}
}
}
unsafe impl Send for SoAVectorStorage {}
unsafe impl Sync for SoAVectorStorage {}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_soa_storage() {
let mut storage = SoAVectorStorage::new(3, 4);
storage.push(&[1.0, 2.0, 3.0]);
storage.push(&[4.0, 5.0, 6.0]);
assert_eq!(storage.len(), 2);
let mut output = vec![0.0; 3];
storage.get(0, &mut output);
assert_eq!(output, vec![1.0, 2.0, 3.0]);
storage.get(1, &mut output);
assert_eq!(output, vec![4.0, 5.0, 6.0]);
}
#[test]
fn test_dimension_slice() {
let mut storage = SoAVectorStorage::new(3, 4);
storage.push(&[1.0, 2.0, 3.0]);
storage.push(&[4.0, 5.0, 6.0]);
storage.push(&[7.0, 8.0, 9.0]);
// Get all values for dimension 0
let dim0 = storage.dimension_slice(0);
assert_eq!(dim0, &[1.0, 4.0, 7.0]);
// Get all values for dimension 1
let dim1 = storage.dimension_slice(1);
assert_eq!(dim1, &[2.0, 5.0, 8.0]);
}
#[test]
fn test_batch_distances() {
let mut storage = SoAVectorStorage::new(3, 4);
storage.push(&[1.0, 0.0, 0.0]);
storage.push(&[0.0, 1.0, 0.0]);
storage.push(&[0.0, 0.0, 1.0]);
let query = vec![1.0, 0.0, 0.0];
let mut distances = vec![0.0; 3];
storage.batch_euclidean_distances(&query, &mut distances);
assert!((distances[0] - 0.0).abs() < 0.001);
assert!((distances[1] - 1.414).abs() < 0.01);
assert!((distances[2] - 1.414).abs() < 0.01);
}
}

View File

@@ -0,0 +1,167 @@
//! SIMD-optimized distance metrics
//! Uses SimSIMD when available (native), falls back to pure Rust for WASM
use crate::error::{Result, RuvectorError};
use crate::types::DistanceMetric;
/// Calculate distance between two vectors using the specified metric
#[inline]
pub fn distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> Result<f32> {
if a.len() != b.len() {
return Err(RuvectorError::DimensionMismatch {
expected: a.len(),
actual: b.len(),
});
}
match metric {
DistanceMetric::Euclidean => Ok(euclidean_distance(a, b)),
DistanceMetric::Cosine => Ok(cosine_distance(a, b)),
DistanceMetric::DotProduct => Ok(dot_product_distance(a, b)),
DistanceMetric::Manhattan => Ok(manhattan_distance(a, b)),
}
}
/// Euclidean (L2) distance
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
#[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
{
(simsimd::SpatialSimilarity::sqeuclidean(a, b)
.expect("SimSIMD euclidean failed")
.sqrt()) as f32
}
#[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
{
// Pure Rust fallback for WASM
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y) * (x - y))
.sum::<f32>()
.sqrt()
}
}
/// Cosine distance (1 - cosine_similarity)
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
#[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
{
simsimd::SpatialSimilarity::cosine(a, b).expect("SimSIMD cosine failed") as f32
}
#[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
{
// Pure Rust fallback for WASM
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a > 1e-8 && norm_b > 1e-8 {
1.0 - (dot / (norm_a * norm_b))
} else {
1.0
}
}
}
/// Dot product distance (negative for maximization)
#[inline]
pub fn dot_product_distance(a: &[f32], b: &[f32]) -> f32 {
#[cfg(all(feature = "simd", not(target_arch = "wasm32")))]
{
let dot = simsimd::SpatialSimilarity::dot(a, b).expect("SimSIMD dot product failed");
(-dot) as f32
}
#[cfg(any(not(feature = "simd"), target_arch = "wasm32"))]
{
// Pure Rust fallback for WASM
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
-dot
}
}
/// Manhattan (L1) distance
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
}
/// Batch distance calculation optimized with Rayon (native) or sequential (WASM)
pub fn batch_distances(
query: &[f32],
vectors: &[Vec<f32>],
metric: DistanceMetric,
) -> Result<Vec<f32>> {
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
{
use rayon::prelude::*;
vectors
.par_iter()
.map(|v| distance(query, v, metric))
.collect()
}
#[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
{
// Sequential fallback for WASM
vectors.iter().map(|v| distance(query, v, metric)).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_euclidean_distance() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
let dist = euclidean_distance(&a, &b);
assert!((dist - 5.196).abs() < 0.01);
}
#[test]
fn test_cosine_distance() {
// Test with identical vectors (should have distance ~0)
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0, 3.0];
let dist = cosine_distance(&a, &b);
assert!(
dist < 0.01,
"Identical vectors should have ~0 distance, got {}",
dist
);
// Test with opposite vectors (should have high distance)
let a = vec![1.0, 0.0, 0.0];
let b = vec![-1.0, 0.0, 0.0];
let dist = cosine_distance(&a, &b);
assert!(
dist > 1.5,
"Opposite vectors should have high distance, got {}",
dist
);
}
#[test]
fn test_dot_product_distance() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
let dist = dot_product_distance(&a, &b);
assert!((dist + 32.0).abs() < 0.01); // -(4 + 10 + 18) = -32
}
#[test]
fn test_manhattan_distance() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
let dist = manhattan_distance(&a, &b);
assert!((dist - 9.0).abs() < 0.01); // |1-4| + |2-5| + |3-6| = 9
}
#[test]
fn test_dimension_mismatch() {
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0];
let result = distance(&a, &b, DistanceMetric::Euclidean);
assert!(result.is_err());
}
}

View File

@@ -0,0 +1,415 @@
//! Text Embedding Providers
//!
//! This module provides a pluggable embedding system for AgenticDB.
//!
//! ## Available Providers
//!
//! - **HashEmbedding**: Fast hash-based placeholder (default, not semantic)
//! - **CandleEmbedding**: Real embeddings using candle-transformers (feature: `real-embeddings`)
//! - **ApiEmbedding**: External API calls (OpenAI, Anthropic, Cohere, etc.)
//!
//! ## Usage
//!
//! ```rust,no_run
//! use ruvector_core::embeddings::{EmbeddingProvider, HashEmbedding, ApiEmbedding};
//! use ruvector_core::AgenticDB;
//!
//! // Default: Hash-based (fast, but not semantic)
//! let hash_provider = HashEmbedding::new(384);
//! let embedding = hash_provider.embed("hello world")?;
//!
//! // API-based (requires API key)
//! let api_provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
//! let embedding = api_provider.embed("hello world")?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
use crate::error::Result;
#[cfg(any(feature = "real-embeddings", feature = "api-embeddings"))]
use crate::error::RuvectorError;
use std::sync::Arc;
/// Trait for text embedding providers
pub trait EmbeddingProvider: Send + Sync {
/// Generate embedding vector for the given text
fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Get the dimensionality of embeddings produced by this provider
fn dimensions(&self) -> usize;
/// Get a description of this provider (for logging/debugging)
fn name(&self) -> &str;
}
/// Hash-based embedding provider (placeholder, not semantic)
///
/// ⚠️ **WARNING**: This does NOT produce semantic embeddings!
/// - "dog" and "cat" will NOT be similar
/// - "dog" and "god" WILL be similar (same characters)
///
/// Use this only for:
/// - Testing
/// - Prototyping
/// - When semantic similarity is not required
#[derive(Debug, Clone)]
pub struct HashEmbedding {
dimensions: usize,
}
impl HashEmbedding {
/// Create a new hash-based embedding provider
pub fn new(dimensions: usize) -> Self {
Self { dimensions }
}
}
impl EmbeddingProvider for HashEmbedding {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let mut embedding = vec![0.0; self.dimensions];
let bytes = text.as_bytes();
for (i, byte) in bytes.iter().enumerate() {
embedding[i % self.dimensions] += (*byte as f32) / 255.0;
}
// Normalize
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut embedding {
*val /= norm;
}
}
Ok(embedding)
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
"HashEmbedding (placeholder)"
}
}
/// Real embeddings using candle-transformers
///
/// Requires feature flag: `real-embeddings`
///
/// ⚠️ **Note**: Full candle integration is complex and model-specific.
/// For production use, we recommend:
/// 1. Using the API-based providers (simpler, always up-to-date)
/// 2. Using ONNX Runtime with pre-exported models
/// 3. Implementing your own candle wrapper for your specific model
///
/// This is a stub implementation showing the structure.
/// Users should implement `EmbeddingProvider` trait for their specific models.
#[cfg(feature = "real-embeddings")]
pub mod candle {
use super::*;
/// Candle-based embedding provider stub
///
/// This is a placeholder. For real implementation:
/// 1. Add candle dependencies for your specific model type
/// 2. Implement model loading and inference
/// 3. Handle tokenization appropriately
///
/// Example structure:
/// ```rust,ignore
/// pub struct CandleEmbedding {
/// model: YourModelType,
/// tokenizer: Tokenizer,
/// device: Device,
/// dimensions: usize,
/// }
/// ```
pub struct CandleEmbedding {
dimensions: usize,
model_id: String,
}
impl CandleEmbedding {
/// Create a stub candle embedding provider
///
/// **This is not a real implementation!**
/// For production, implement with actual model loading.
///
/// # Example
/// ```rust,no_run
/// # #[cfg(feature = "real-embeddings")]
/// # {
/// use ruvector_core::embeddings::candle::CandleEmbedding;
///
/// // This returns an error - real implementation required
/// let result = CandleEmbedding::from_pretrained(
/// "sentence-transformers/all-MiniLM-L6-v2",
/// false
/// );
/// assert!(result.is_err());
/// # }
/// ```
pub fn from_pretrained(model_id: &str, _use_gpu: bool) -> Result<Self> {
Err(RuvectorError::ModelLoadError(format!(
"Candle embedding support is a stub. Please:\n\
1. Use ApiEmbedding for production (recommended)\n\
2. Or implement CandleEmbedding for model: {}\n\
3. See docs for ONNX Runtime integration examples",
model_id
)))
}
}
impl EmbeddingProvider for CandleEmbedding {
fn embed(&self, _text: &str) -> Result<Vec<f32>> {
Err(RuvectorError::ModelInferenceError(
"Candle embedding not implemented - use ApiEmbedding instead".to_string(),
))
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
"CandleEmbedding (stub - not implemented)"
}
}
}
#[cfg(feature = "real-embeddings")]
pub use candle::CandleEmbedding;
/// API-based embedding provider (OpenAI, Anthropic, Cohere, etc.)
///
/// Supports any API that accepts JSON and returns embeddings in a standard format.
///
/// # Example (OpenAI)
/// ```rust,no_run
/// use ruvector_core::embeddings::{EmbeddingProvider, ApiEmbedding};
///
/// let provider = ApiEmbedding::openai("sk-...", "text-embedding-3-small");
/// let embedding = provider.embed("hello world")?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[cfg(feature = "api-embeddings")]
#[derive(Clone)]
pub struct ApiEmbedding {
api_key: String,
endpoint: String,
model: String,
dimensions: usize,
client: reqwest::blocking::Client,
}
#[cfg(feature = "api-embeddings")]
impl ApiEmbedding {
/// Create a new API embedding provider
///
/// # Arguments
/// * `api_key` - API key for authentication
/// * `endpoint` - API endpoint URL
/// * `model` - Model identifier
/// * `dimensions` - Expected embedding dimensions
pub fn new(api_key: String, endpoint: String, model: String, dimensions: usize) -> Self {
Self {
api_key,
endpoint,
model,
dimensions,
client: reqwest::blocking::Client::new(),
}
}
/// Create OpenAI embedding provider
///
/// # Models
/// - `text-embedding-3-small` - 1536 dimensions, $0.02/1M tokens
/// - `text-embedding-3-large` - 3072 dimensions, $0.13/1M tokens
/// - `text-embedding-ada-002` - 1536 dimensions (legacy)
pub fn openai(api_key: &str, model: &str) -> Self {
let dimensions = match model {
"text-embedding-3-large" => 3072,
_ => 1536, // text-embedding-3-small and ada-002
};
Self::new(
api_key.to_string(),
"https://api.openai.com/v1/embeddings".to_string(),
model.to_string(),
dimensions,
)
}
/// Create Cohere embedding provider
///
/// # Models
/// - `embed-english-v3.0` - 1024 dimensions
/// - `embed-multilingual-v3.0` - 1024 dimensions
pub fn cohere(api_key: &str, model: &str) -> Self {
Self::new(
api_key.to_string(),
"https://api.cohere.ai/v1/embed".to_string(),
model.to_string(),
1024,
)
}
/// Create Voyage AI embedding provider
///
/// # Models
/// - `voyage-2` - 1024 dimensions
/// - `voyage-large-2` - 1536 dimensions
pub fn voyage(api_key: &str, model: &str) -> Self {
let dimensions = if model.contains("large") { 1536 } else { 1024 };
Self::new(
api_key.to_string(),
"https://api.voyageai.com/v1/embeddings".to_string(),
model.to_string(),
dimensions,
)
}
}
#[cfg(feature = "api-embeddings")]
impl EmbeddingProvider for ApiEmbedding {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let request_body = serde_json::json!({
"input": text,
"model": self.model,
});
let response = self
.client
.post(&self.endpoint)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.map_err(|e| {
RuvectorError::ModelInferenceError(format!("API request failed: {}", e))
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(RuvectorError::ModelInferenceError(format!(
"API returned error {}: {}",
status, error_text
)));
}
let response_json: serde_json::Value = response.json().map_err(|e| {
RuvectorError::ModelInferenceError(format!("Failed to parse response: {}", e))
})?;
// Handle different API response formats
let embedding = if let Some(data) = response_json.get("data") {
// OpenAI format: {"data": [{"embedding": [...]}]}
data.as_array()
.and_then(|arr| arr.first())
.and_then(|obj| obj.get("embedding"))
.and_then(|emb| emb.as_array())
.ok_or_else(|| {
RuvectorError::ModelInferenceError("Invalid OpenAI response format".to_string())
})?
} else if let Some(embeddings) = response_json.get("embeddings") {
// Cohere format: {"embeddings": [[...]]}
embeddings
.as_array()
.and_then(|arr| arr.first())
.and_then(|emb| emb.as_array())
.ok_or_else(|| {
RuvectorError::ModelInferenceError("Invalid Cohere response format".to_string())
})?
} else {
return Err(RuvectorError::ModelInferenceError(
"Unknown API response format".to_string(),
));
};
let embedding_vec: Result<Vec<f32>> = embedding
.iter()
.map(|v| {
v.as_f64().map(|f| f as f32).ok_or_else(|| {
RuvectorError::ModelInferenceError("Invalid embedding value".to_string())
})
})
.collect();
embedding_vec
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn name(&self) -> &str {
"ApiEmbedding"
}
}
/// Type-erased embedding provider for dynamic dispatch
pub type BoxedEmbeddingProvider = Arc<dyn EmbeddingProvider>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hash_embedding() {
let provider = HashEmbedding::new(128);
let emb1 = provider.embed("hello world").unwrap();
let emb2 = provider.embed("hello world").unwrap();
assert_eq!(emb1.len(), 128);
assert_eq!(emb1, emb2, "Same text should produce same embedding");
// Check normalization
let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
}
#[test]
fn test_hash_embedding_different_text() {
let provider = HashEmbedding::new(128);
let emb1 = provider.embed("hello").unwrap();
let emb2 = provider.embed("world").unwrap();
assert_ne!(
emb1, emb2,
"Different text should produce different embeddings"
);
}
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_candle_embedding() {
let provider =
CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false)
.unwrap();
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 384);
// Check normalization
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-5, "Embedding should be normalized");
}
#[test]
#[ignore] // Requires API key
fn test_api_embedding_openai() {
let api_key = std::env::var("OPENAI_API_KEY").unwrap();
let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 1536);
}
}

View File

@@ -0,0 +1,113 @@
//! Error types for Ruvector
use thiserror::Error;
/// Result type alias for Ruvector operations
pub type Result<T> = std::result::Result<T, RuvectorError>;
/// Main error type for Ruvector
#[derive(Error, Debug)]
pub enum RuvectorError {
/// Vector dimension mismatch
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch {
/// Expected dimension
expected: usize,
/// Actual dimension
actual: usize,
},
/// Vector not found
#[error("Vector not found: {0}")]
VectorNotFound(String),
/// Invalid parameter
#[error("Invalid parameter: {0}")]
InvalidParameter(String),
/// Invalid input
#[error("Invalid input: {0}")]
InvalidInput(String),
/// Invalid dimension
#[error("Invalid dimension: {0}")]
InvalidDimension(String),
/// Storage error
#[error("Storage error: {0}")]
StorageError(String),
/// Model loading error
#[error("Model loading error: {0}")]
ModelLoadError(String),
/// Model inference error
#[error("Model inference error: {0}")]
ModelInferenceError(String),
/// Index error
#[error("Index error: {0}")]
IndexError(String),
/// Serialization error
#[error("Serialization error: {0}")]
SerializationError(String),
/// IO error
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
/// Database error
#[error("Database error: {0}")]
DatabaseError(String),
/// Invalid path error
#[error("Invalid path: {0}")]
InvalidPath(String),
/// Other errors
#[error("Internal error: {0}")]
Internal(String),
}
#[cfg(feature = "storage")]
impl From<redb::Error> for RuvectorError {
fn from(err: redb::Error) -> Self {
RuvectorError::DatabaseError(err.to_string())
}
}
#[cfg(feature = "storage")]
impl From<redb::DatabaseError> for RuvectorError {
fn from(err: redb::DatabaseError) -> Self {
RuvectorError::DatabaseError(err.to_string())
}
}
#[cfg(feature = "storage")]
impl From<redb::StorageError> for RuvectorError {
fn from(err: redb::StorageError) -> Self {
RuvectorError::DatabaseError(err.to_string())
}
}
#[cfg(feature = "storage")]
impl From<redb::TableError> for RuvectorError {
fn from(err: redb::TableError) -> Self {
RuvectorError::DatabaseError(err.to_string())
}
}
#[cfg(feature = "storage")]
impl From<redb::TransactionError> for RuvectorError {
fn from(err: redb::TransactionError) -> Self {
RuvectorError::DatabaseError(err.to_string())
}
}
#[cfg(feature = "storage")]
impl From<redb::CommitError> for RuvectorError {
fn from(err: redb::CommitError) -> Self {
RuvectorError::DatabaseError(err.to_string())
}
}

View File

@@ -0,0 +1,36 @@
//! Index structures for efficient vector search
pub mod flat;
#[cfg(feature = "hnsw")]
pub mod hnsw;
use crate::error::Result;
use crate::types::{SearchResult, VectorId};
/// Trait for vector index implementations
pub trait VectorIndex: Send + Sync {
/// Add a vector to the index
fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()>;
/// Add multiple vectors in batch
fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
for (id, vector) in entries {
self.add(id, vector)?;
}
Ok(())
}
/// Search for k nearest neighbors
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
/// Remove a vector from the index
fn remove(&mut self, id: &VectorId) -> Result<bool>;
/// Get the number of vectors in the index
fn len(&self) -> usize;
/// Check if the index is empty
fn is_empty(&self) -> bool {
self.len() == 0
}
}

View File

@@ -0,0 +1,108 @@
//! Flat (brute-force) index for baseline and small datasets
use crate::distance::distance;
use crate::error::Result;
use crate::index::VectorIndex;
use crate::types::{DistanceMetric, SearchResult, VectorId};
use dashmap::DashMap;
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
use rayon::prelude::*;
/// Flat index using brute-force search
pub struct FlatIndex {
vectors: DashMap<VectorId, Vec<f32>>,
metric: DistanceMetric,
_dimensions: usize,
}
impl FlatIndex {
/// Create a new flat index
pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
Self {
vectors: DashMap::new(),
metric,
_dimensions: dimensions,
}
}
}
impl VectorIndex for FlatIndex {
fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
self.vectors.insert(id, vector);
Ok(())
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
// Distance calculation - parallel on native, sequential on WASM
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
let mut results: Vec<_> = self
.vectors
.iter()
.par_bridge()
.map(|entry| {
let id = entry.key().clone();
let vector = entry.value();
let dist = distance(query, vector, self.metric)?;
Ok((id, dist))
})
.collect::<Result<Vec<_>>>()?;
#[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
let mut results: Vec<_> = self
.vectors
.iter()
.map(|entry| {
let id = entry.key().clone();
let vector = entry.value();
let dist = distance(query, vector, self.metric)?;
Ok((id, dist))
})
.collect::<Result<Vec<_>>>()?;
// Sort by distance and take top k
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
results.truncate(k);
Ok(results
.into_iter()
.map(|(id, score)| SearchResult {
id,
score,
vector: None,
metadata: None,
})
.collect())
}
fn remove(&mut self, id: &VectorId) -> Result<bool> {
Ok(self.vectors.remove(id).is_some())
}
fn len(&self) -> usize {
self.vectors.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_flat_index() -> Result<()> {
let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);
index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
index.add("v2".to_string(), vec![0.0, 1.0, 0.0])?;
index.add("v3".to_string(), vec![0.0, 0.0, 1.0])?;
let query = vec![1.0, 0.0, 0.0];
let results = index.search(&query, 2)?;
assert_eq!(results.len(), 2);
assert_eq!(results[0].id, "v1");
assert!(results[0].score < 0.01);
Ok(())
}
}

View File

@@ -0,0 +1,481 @@
//! HNSW (Hierarchical Navigable Small World) index implementation
use crate::distance::distance;
use crate::error::{Result, RuvectorError};
use crate::index::VectorIndex;
use crate::types::{DistanceMetric, HnswConfig, SearchResult, VectorId};
use bincode::{Decode, Encode};
use dashmap::DashMap;
use hnsw_rs::prelude::*;
use parking_lot::RwLock;
use std::sync::Arc;
/// Distance function wrapper for hnsw_rs
struct DistanceFn {
metric: DistanceMetric,
}
impl DistanceFn {
fn new(metric: DistanceMetric) -> Self {
Self { metric }
}
}
impl Distance<f32> for DistanceFn {
fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
distance(a, b, self.metric).unwrap_or(f32::MAX)
}
}
/// HNSW index wrapper
pub struct HnswIndex {
inner: Arc<RwLock<HnswInner>>,
config: HnswConfig,
metric: DistanceMetric,
dimensions: usize,
}
struct HnswInner {
hnsw: Hnsw<'static, f32, DistanceFn>,
vectors: DashMap<VectorId, Vec<f32>>,
id_to_idx: DashMap<VectorId, usize>,
idx_to_id: DashMap<usize, VectorId>,
next_idx: usize,
}
/// Serializable HNSW index state
#[derive(Encode, Decode, Clone)]
pub struct HnswState {
vectors: Vec<(String, Vec<f32>)>,
id_to_idx: Vec<(String, usize)>,
idx_to_id: Vec<(usize, String)>,
next_idx: usize,
config: SerializableHnswConfig,
dimensions: usize,
metric: SerializableDistanceMetric,
}
#[derive(Encode, Decode, Clone)]
struct SerializableHnswConfig {
m: usize,
ef_construction: usize,
ef_search: usize,
max_elements: usize,
}
#[derive(Encode, Decode, Clone, Copy)]
enum SerializableDistanceMetric {
Euclidean,
Cosine,
DotProduct,
Manhattan,
}
impl From<DistanceMetric> for SerializableDistanceMetric {
fn from(metric: DistanceMetric) -> Self {
match metric {
DistanceMetric::Euclidean => SerializableDistanceMetric::Euclidean,
DistanceMetric::Cosine => SerializableDistanceMetric::Cosine,
DistanceMetric::DotProduct => SerializableDistanceMetric::DotProduct,
DistanceMetric::Manhattan => SerializableDistanceMetric::Manhattan,
}
}
}
impl From<SerializableDistanceMetric> for DistanceMetric {
fn from(metric: SerializableDistanceMetric) -> Self {
match metric {
SerializableDistanceMetric::Euclidean => DistanceMetric::Euclidean,
SerializableDistanceMetric::Cosine => DistanceMetric::Cosine,
SerializableDistanceMetric::DotProduct => DistanceMetric::DotProduct,
SerializableDistanceMetric::Manhattan => DistanceMetric::Manhattan,
}
}
}
impl HnswIndex {
/// Create a new HNSW index
pub fn new(dimensions: usize, metric: DistanceMetric, config: HnswConfig) -> Result<Self> {
let distance_fn = DistanceFn::new(metric);
// Create HNSW with configured parameters
let hnsw = Hnsw::<f32, DistanceFn>::new(
config.m,
config.max_elements,
dimensions,
config.ef_construction,
distance_fn,
);
Ok(Self {
inner: Arc::new(RwLock::new(HnswInner {
hnsw,
vectors: DashMap::new(),
id_to_idx: DashMap::new(),
idx_to_id: DashMap::new(),
next_idx: 0,
})),
config,
metric,
dimensions,
})
}
/// Get configuration
pub fn config(&self) -> &HnswConfig {
&self.config
}
/// Set efSearch parameter for query-time accuracy tuning
pub fn set_ef_search(&mut self, _ef_search: usize) {
// Note: hnsw_rs controls ef_search via the search method's knbn parameter
// We store it in config and use it in search_with_ef
}
/// Serialize the index to bytes using bincode
pub fn serialize(&self) -> Result<Vec<u8>> {
let inner = self.inner.read();
let state = HnswState {
vectors: inner
.vectors
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect(),
id_to_idx: inner
.id_to_idx
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect(),
idx_to_id: inner
.idx_to_id
.iter()
.map(|entry| (*entry.key(), entry.value().clone()))
.collect(),
next_idx: inner.next_idx,
config: SerializableHnswConfig {
m: self.config.m,
ef_construction: self.config.ef_construction,
ef_search: self.config.ef_search,
max_elements: self.config.max_elements,
},
dimensions: self.dimensions,
metric: self.metric.into(),
};
bincode::encode_to_vec(&state, bincode::config::standard()).map_err(|e| {
RuvectorError::SerializationError(format!("Failed to serialize HNSW index: {}", e))
})
}
/// Deserialize the index from bytes using bincode
pub fn deserialize(bytes: &[u8]) -> Result<Self> {
let (state, _): (HnswState, usize) =
bincode::decode_from_slice(bytes, bincode::config::standard()).map_err(|e| {
RuvectorError::SerializationError(format!(
"Failed to deserialize HNSW index: {}",
e
))
})?;
let config = HnswConfig {
m: state.config.m,
ef_construction: state.config.ef_construction,
ef_search: state.config.ef_search,
max_elements: state.config.max_elements,
};
let dimensions = state.dimensions;
let metric: DistanceMetric = state.metric.into();
let distance_fn = DistanceFn::new(metric);
let mut hnsw = Hnsw::<'static, f32, DistanceFn>::new(
config.m,
config.max_elements,
dimensions,
config.ef_construction,
distance_fn,
);
// Rebuild the index by inserting all vectors
let id_to_idx: DashMap<VectorId, usize> = state.id_to_idx.into_iter().collect();
let idx_to_id: DashMap<usize, VectorId> = state.idx_to_id.into_iter().collect();
// Insert vectors into HNSW in order
for entry in idx_to_id.iter() {
let idx = *entry.key();
let id = entry.value();
if let Some(vector) = state.vectors.iter().find(|(vid, _)| vid == id) {
// Use insert_data method with slice and idx
hnsw.insert_data(&vector.1, idx);
}
}
let vectors_map: DashMap<VectorId, Vec<f32>> = state.vectors.into_iter().collect();
Ok(Self {
inner: Arc::new(RwLock::new(HnswInner {
hnsw,
vectors: vectors_map,
id_to_idx,
idx_to_id,
next_idx: state.next_idx,
})),
config,
metric,
dimensions,
})
}
/// Search with custom efSearch parameter
pub fn search_with_ef(
&self,
query: &[f32],
k: usize,
ef_search: usize,
) -> Result<Vec<SearchResult>> {
if query.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: query.len(),
});
}
let inner = self.inner.read();
// Use HNSW search with custom ef parameter (knbn)
let neighbors = inner.hnsw.search(query, k, ef_search);
Ok(neighbors
.into_iter()
.filter_map(|neighbor| {
inner.idx_to_id.get(&neighbor.d_id).map(|id| SearchResult {
id: id.clone(),
score: neighbor.distance,
vector: None,
metadata: None,
})
})
.collect())
}
}
impl VectorIndex for HnswIndex {
fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
if vector.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: vector.len(),
});
}
let mut inner = self.inner.write();
let idx = inner.next_idx;
inner.next_idx += 1;
// Insert into HNSW graph using insert_data
inner.hnsw.insert_data(&vector, idx);
// Store mappings
inner.vectors.insert(id.clone(), vector);
inner.id_to_idx.insert(id.clone(), idx);
inner.idx_to_id.insert(idx, id);
Ok(())
}
fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()> {
// Validate all dimensions first
for (_, vector) in &entries {
if vector.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: vector.len(),
});
}
}
let mut inner = self.inner.write();
// Prepare batch data for insertion
// First, assign indices and collect vector data
let data_with_ids: Vec<_> = entries
.iter()
.enumerate()
.map(|(i, (id, vector))| {
let idx = inner.next_idx + i;
(id.clone(), idx, vector.clone())
})
.collect();
// Update next_idx
inner.next_idx += entries.len();
// Insert into HNSW sequentially
// Note: Using sequential insertion to avoid Send requirements with RwLock guard
// For large batches, consider restructuring to use hnsw_rs parallel_insert
for (_id, idx, vector) in &data_with_ids {
inner.hnsw.insert_data(vector, *idx);
}
// Store mappings
for (id, idx, vector) in data_with_ids {
inner.vectors.insert(id.clone(), vector);
inner.id_to_idx.insert(id.clone(), idx);
inner.idx_to_id.insert(idx, id);
}
Ok(())
}
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
// Use configured ef_search
self.search_with_ef(query, k, self.config.ef_search)
}
fn remove(&mut self, id: &VectorId) -> Result<bool> {
let inner = self.inner.write();
// Note: hnsw_rs doesn't support direct deletion
// We remove from our mappings but the graph structure remains
// This is a known limitation of HNSW
let removed = inner.vectors.remove(id).is_some();
if removed {
if let Some((_, idx)) = inner.id_to_idx.remove(id) {
inner.idx_to_id.remove(&idx);
}
}
Ok(removed)
}
fn len(&self) -> usize {
self.inner.read().vectors.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn generate_random_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
use rand::Rng;
let mut rng = rand::thread_rng();
(0..count)
.map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
.collect()
}
fn normalize_vector(v: &[f32]) -> Vec<f32> {
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
v.iter().map(|x| x / norm).collect()
} else {
v.to_vec()
}
}
#[test]
fn test_hnsw_index_creation() -> Result<()> {
let config = HnswConfig::default();
let index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
assert_eq!(index.len(), 0);
Ok(())
}
#[test]
fn test_hnsw_insert_and_search() -> Result<()> {
let config = HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 50,
max_elements: 1000,
};
let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
// Insert a few vectors
let vectors = generate_random_vectors(100, 128);
for (i, vector) in vectors.iter().enumerate() {
let normalized = normalize_vector(vector);
index.add(format!("vec_{}", i), normalized)?;
}
assert_eq!(index.len(), 100);
// Search for the first vector
let query = normalize_vector(&vectors[0]);
let results = index.search(&query, 10)?;
assert!(!results.is_empty());
assert_eq!(results[0].id, "vec_0");
Ok(())
}
#[test]
fn test_hnsw_batch_insert() -> Result<()> {
let config = HnswConfig::default();
let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
let vectors = generate_random_vectors(100, 128);
let entries: Vec<_> = vectors
.iter()
.enumerate()
.map(|(i, v)| (format!("vec_{}", i), normalize_vector(v)))
.collect();
index.add_batch(entries)?;
assert_eq!(index.len(), 100);
Ok(())
}
#[test]
fn test_hnsw_serialization() -> Result<()> {
let config = HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 50,
max_elements: 1000,
};
let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
// Insert vectors
let vectors = generate_random_vectors(50, 128);
for (i, vector) in vectors.iter().enumerate() {
let normalized = normalize_vector(vector);
index.add(format!("vec_{}", i), normalized)?;
}
// Serialize
let bytes = index.serialize()?;
// Deserialize
let restored_index = HnswIndex::deserialize(&bytes)?;
assert_eq!(restored_index.len(), 50);
// Test search on restored index
let query = normalize_vector(&vectors[0]);
let results = restored_index.search(&query, 5)?;
assert!(!results.is_empty());
Ok(())
}
#[test]
fn test_dimension_mismatch() -> Result<()> {
let config = HnswConfig::default();
let mut index = HnswIndex::new(128, DistanceMetric::Cosine, config)?;
let result = index.add("test".to_string(), vec![1.0; 64]);
assert!(result.is_err());
Ok(())
}
}

View File

@@ -0,0 +1,142 @@
//! # Ruvector Core
//!
//! High-performance Rust-native vector database with HNSW indexing and SIMD-optimized operations.
//!
//! ## Working Features (Tested & Benchmarked)
//!
//! - **HNSW Indexing**: Approximate nearest neighbor search with O(log n) complexity
//! - **SIMD Distance**: SimSIMD-powered distance calculations (~16M ops/sec for 512-dim)
//! - **Quantization**: Scalar (4x), Int4 (8x), Product (8-16x), and binary (32x) compression with distance support
//! - **Persistence**: REDB-based storage with config persistence
//! - **Search**: ~2.5K queries/sec on 10K vectors (benchmarked)
//!
//! ## ⚠️ Experimental/Incomplete Features - READ BEFORE USE
//!
//! - **AgenticDB**: ⚠️⚠️⚠️ **CRITICAL WARNING** ⚠️⚠️⚠️
//! - Uses PLACEHOLDER hash-based embeddings, NOT real semantic embeddings
//! - "dog" and "cat" will NOT be similar (different characters)
//! - "dog" and "god" WILL be similar (same characters) - **This is wrong!**
//! - **MUST integrate real embedding model for production** (ONNX, Candle, or API)
//! - See [`agenticdb`] module docs and `/examples/onnx-embeddings` for integration
//! - **Advanced Features**: Conformal prediction, hybrid search - functional but less tested
//!
//! ## What This Is NOT
//!
//! - This is NOT a complete RAG solution - you need external embedding models
//! - Examples use mock embeddings for demonstration only
#![allow(missing_docs)]
#![warn(clippy::all)]
#![allow(clippy::incompatible_msrv)]
pub mod advanced_features;
// AgenticDB requires storage feature
#[cfg(feature = "storage")]
pub mod agenticdb;
pub mod distance;
pub mod embeddings;
pub mod error;
pub mod index;
pub mod quantization;
// Storage backends - conditional compilation based on features
#[cfg(feature = "storage")]
pub mod storage;
#[cfg(not(feature = "storage"))]
pub mod storage_memory;
#[cfg(not(feature = "storage"))]
pub use storage_memory as storage;
pub mod types;
pub mod vector_db;
// Performance optimization modules
pub mod arena;
pub mod cache_optimized;
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub mod lockfree;
pub mod simd_intrinsics;
/// Unified Memory Pool and Paging System (ADR-006)
///
/// High-performance paged memory management for LLM inference:
/// - 2MB page-granular allocation with best-fit strategy
/// - Reference-counted pinning with RAII guards
/// - LRU eviction with hysteresis for thrash prevention
/// - Multi-tenant isolation with Hot/Warm/Cold residency tiers
pub mod memory;
/// Advanced techniques: hypergraphs, learned indexes, neural hashing, TDA (Phase 6)
pub mod advanced;
// Re-exports
pub use advanced_features::{
ConformalConfig, ConformalPredictor, EnhancedPQ, FilterExpression, FilterStrategy,
FilteredSearch, HybridConfig, HybridSearch, MMRConfig, MMRSearch, PQConfig, PredictionSet,
BM25,
};
#[cfg(feature = "storage")]
pub use agenticdb::{
AgenticDB, PolicyAction, PolicyEntry, PolicyMemoryStore, SessionStateIndex, SessionTurn,
WitnessEntry, WitnessLog,
};
#[cfg(feature = "api-embeddings")]
pub use embeddings::ApiEmbedding;
pub use embeddings::{BoxedEmbeddingProvider, EmbeddingProvider, HashEmbedding};
#[cfg(feature = "real-embeddings")]
pub use embeddings::CandleEmbedding;
// Compile-time warning about AgenticDB limitations
#[cfg(feature = "storage")]
#[allow(deprecated, clippy::let_unit_value)]
const _: () = {
#[deprecated(
since = "0.1.0",
note = "AgenticDB uses placeholder hash-based embeddings. For semantic search, integrate a real embedding model (ONNX, Candle, or API). See /examples/onnx-embeddings for production setup."
)]
const AGENTICDB_EMBEDDING_WARNING: () = ();
let _ = AGENTICDB_EMBEDDING_WARNING;
};
pub use error::{Result, RuvectorError};
pub use types::{DistanceMetric, SearchQuery, SearchResult, VectorEntry, VectorId};
pub use vector_db::VectorDB;
// Quantization types (ADR-001)
pub use quantization::{
BinaryQuantized, Int4Quantized, ProductQuantized, QuantizedVector, ScalarQuantized,
};
// Memory management types (ADR-001)
pub use arena::{Arena, ArenaVec, BatchVectorAllocator, CacheAlignedVec, CACHE_LINE_SIZE};
// Lock-free structures (requires parallel feature)
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub use lockfree::{
AtomicVectorPool, BatchItem, BatchResult, LockFreeBatchProcessor, LockFreeCounter,
LockFreeStats, LockFreeWorkQueue, ObjectPool, PooledObject, PooledVector, StatsSnapshot,
VectorPoolStats,
};
// Cache-optimized storage
pub use cache_optimized::SoAVectorStorage;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_version() {
// Verify version matches workspace - use dynamic check instead of hardcoded value
let version = env!("CARGO_PKG_VERSION");
assert!(!version.is_empty(), "Version should not be empty");
assert!(version.starts_with("0.1."), "Version should be 0.1.x");
}
}

View File

@@ -0,0 +1,590 @@
//! Lock-free data structures for high-concurrency operations
//!
//! This module provides lock-free implementations of common data structures
//! to minimize contention and improve scalability.
//!
//! Note: This module requires the `parallel` feature and is not available on WASM.
#![cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
use crossbeam::queue::{ArrayQueue, SegQueue};
use crossbeam::utils::CachePadded;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
/// Lock-free counter with cache padding to prevent false sharing
#[repr(align(64))]
pub struct LockFreeCounter {
value: CachePadded<AtomicU64>,
}
impl LockFreeCounter {
pub fn new(initial: u64) -> Self {
Self {
value: CachePadded::new(AtomicU64::new(initial)),
}
}
#[inline]
pub fn increment(&self) -> u64 {
self.value.fetch_add(1, Ordering::Relaxed)
}
#[inline]
pub fn get(&self) -> u64 {
self.value.load(Ordering::Relaxed)
}
#[inline]
pub fn add(&self, delta: u64) -> u64 {
self.value.fetch_add(delta, Ordering::Relaxed)
}
}
/// Lock-free statistics collector
pub struct LockFreeStats {
queries: CachePadded<AtomicU64>,
inserts: CachePadded<AtomicU64>,
deletes: CachePadded<AtomicU64>,
total_latency_ns: CachePadded<AtomicU64>,
}
impl LockFreeStats {
pub fn new() -> Self {
Self {
queries: CachePadded::new(AtomicU64::new(0)),
inserts: CachePadded::new(AtomicU64::new(0)),
deletes: CachePadded::new(AtomicU64::new(0)),
total_latency_ns: CachePadded::new(AtomicU64::new(0)),
}
}
#[inline]
pub fn record_query(&self, latency_ns: u64) {
self.queries.fetch_add(1, Ordering::Relaxed);
self.total_latency_ns
.fetch_add(latency_ns, Ordering::Relaxed);
}
#[inline]
pub fn record_insert(&self) {
self.inserts.fetch_add(1, Ordering::Relaxed);
}
#[inline]
pub fn record_delete(&self) {
self.deletes.fetch_add(1, Ordering::Relaxed);
}
pub fn snapshot(&self) -> StatsSnapshot {
let queries = self.queries.load(Ordering::Relaxed);
let total_latency = self.total_latency_ns.load(Ordering::Relaxed);
StatsSnapshot {
queries,
inserts: self.inserts.load(Ordering::Relaxed),
deletes: self.deletes.load(Ordering::Relaxed),
avg_latency_ns: if queries > 0 {
total_latency / queries
} else {
0
},
}
}
}
impl Default for LockFreeStats {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
pub queries: u64,
pub inserts: u64,
pub deletes: u64,
pub avg_latency_ns: u64,
}
/// Lock-free object pool for reducing allocations
pub struct ObjectPool<T> {
queue: Arc<SegQueue<T>>,
factory: Arc<dyn Fn() -> T + Send + Sync>,
capacity: usize,
allocated: AtomicUsize,
}
impl<T> ObjectPool<T> {
pub fn new<F>(capacity: usize, factory: F) -> Self
where
F: Fn() -> T + Send + Sync + 'static,
{
Self {
queue: Arc::new(SegQueue::new()),
factory: Arc::new(factory),
capacity,
allocated: AtomicUsize::new(0),
}
}
/// Get an object from the pool or create a new one
pub fn acquire(&self) -> PooledObject<T> {
let object = self.queue.pop().unwrap_or_else(|| {
let current = self.allocated.fetch_add(1, Ordering::Relaxed);
if current < self.capacity {
(self.factory)()
} else {
self.allocated.fetch_sub(1, Ordering::Relaxed);
// Wait for an object to be returned
loop {
if let Some(obj) = self.queue.pop() {
break obj;
}
std::hint::spin_loop();
}
}
});
PooledObject {
object: Some(object),
pool: Arc::clone(&self.queue),
}
}
}
/// RAII wrapper for pooled objects
pub struct PooledObject<T> {
object: Option<T>,
pool: Arc<SegQueue<T>>,
}
impl<T> PooledObject<T> {
pub fn get(&self) -> &T {
self.object.as_ref().unwrap()
}
pub fn get_mut(&mut self) -> &mut T {
self.object.as_mut().unwrap()
}
}
impl<T> Drop for PooledObject<T> {
fn drop(&mut self) {
if let Some(object) = self.object.take() {
self.pool.push(object);
}
}
}
impl<T> std::ops::Deref for PooledObject<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.object.as_ref().unwrap()
}
}
impl<T> std::ops::DerefMut for PooledObject<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.object.as_mut().unwrap()
}
}
/// Lock-free ring buffer for work distribution
pub struct LockFreeWorkQueue<T> {
queue: ArrayQueue<T>,
}
impl<T> LockFreeWorkQueue<T> {
pub fn new(capacity: usize) -> Self {
Self {
queue: ArrayQueue::new(capacity),
}
}
#[inline]
pub fn try_push(&self, item: T) -> Result<(), T> {
self.queue.push(item)
}
#[inline]
pub fn try_pop(&self) -> Option<T> {
self.queue.pop()
}
#[inline]
pub fn len(&self) -> usize {
self.queue.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.queue.is_empty()
}
}
/// Atomic vector pool for lock-free vector operations (ADR-001)
///
/// Provides a pool of pre-allocated vectors that can be acquired and released
/// without locking, ideal for high-throughput batch operations.
pub struct AtomicVectorPool {
/// Pool of available vectors
pool: SegQueue<Vec<f32>>,
/// Dimensions per vector
dimensions: usize,
/// Maximum pool size
max_size: usize,
/// Current pool size
size: AtomicUsize,
/// Total allocations
total_allocations: AtomicU64,
/// Pool hits (reused vectors)
pool_hits: AtomicU64,
}
impl AtomicVectorPool {
/// Create a new atomic vector pool
pub fn new(dimensions: usize, initial_size: usize, max_size: usize) -> Self {
let pool = SegQueue::new();
// Pre-allocate vectors
for _ in 0..initial_size {
pool.push(vec![0.0; dimensions]);
}
Self {
pool,
dimensions,
max_size,
size: AtomicUsize::new(initial_size),
total_allocations: AtomicU64::new(0),
pool_hits: AtomicU64::new(0),
}
}
/// Acquire a vector from the pool (or allocate new one)
pub fn acquire(&self) -> PooledVector<'_> {
self.total_allocations.fetch_add(1, Ordering::Relaxed);
let vec = if let Some(mut v) = self.pool.pop() {
self.pool_hits.fetch_add(1, Ordering::Relaxed);
// Clear the vector for reuse
v.fill(0.0);
v
} else {
// Allocate new vector
vec![0.0; self.dimensions]
};
PooledVector {
vec: Some(vec),
pool: self,
}
}
/// Return a vector to the pool
fn return_to_pool(&self, vec: Vec<f32>) {
let current_size = self.size.load(Ordering::Relaxed);
if current_size < self.max_size {
self.pool.push(vec);
self.size.fetch_add(1, Ordering::Relaxed);
}
// If pool is full, vector is dropped
}
/// Get pool statistics
pub fn stats(&self) -> VectorPoolStats {
let total = self.total_allocations.load(Ordering::Relaxed);
let hits = self.pool_hits.load(Ordering::Relaxed);
let hit_rate = if total > 0 {
hits as f64 / total as f64
} else {
0.0
};
VectorPoolStats {
total_allocations: total,
pool_hits: hits,
hit_rate,
current_size: self.size.load(Ordering::Relaxed),
max_size: self.max_size,
}
}
/// Get dimensions
pub fn dimensions(&self) -> usize {
self.dimensions
}
}
/// Statistics for the vector pool
#[derive(Debug, Clone)]
pub struct VectorPoolStats {
pub total_allocations: u64,
pub pool_hits: u64,
pub hit_rate: f64,
pub current_size: usize,
pub max_size: usize,
}
/// RAII wrapper for pooled vectors
pub struct PooledVector<'a> {
vec: Option<Vec<f32>>,
pool: &'a AtomicVectorPool,
}
impl<'a> PooledVector<'a> {
/// Get as slice
pub fn as_slice(&self) -> &[f32] {
self.vec.as_ref().unwrap()
}
/// Get as mutable slice
pub fn as_mut_slice(&mut self) -> &mut [f32] {
self.vec.as_mut().unwrap()
}
/// Copy from source slice
pub fn copy_from(&mut self, src: &[f32]) {
let vec = self.vec.as_mut().unwrap();
assert_eq!(vec.len(), src.len(), "Dimension mismatch");
vec.copy_from_slice(src);
}
/// Detach the vector from the pool (it won't be returned)
pub fn detach(mut self) -> Vec<f32> {
self.vec.take().unwrap()
}
}
impl<'a> Drop for PooledVector<'a> {
fn drop(&mut self) {
if let Some(vec) = self.vec.take() {
self.pool.return_to_pool(vec);
}
}
}
impl<'a> std::ops::Deref for PooledVector<'a> {
type Target = [f32];
fn deref(&self) -> &[f32] {
self.as_slice()
}
}
impl<'a> std::ops::DerefMut for PooledVector<'a> {
fn deref_mut(&mut self) -> &mut [f32] {
self.as_mut_slice()
}
}
/// Lock-free batch processor for parallel vector operations (ADR-001)
///
/// Distributes work across multiple workers without contention.
pub struct LockFreeBatchProcessor {
/// Work queue for pending items
work_queue: ArrayQueue<BatchItem>,
/// Results queue
results_queue: SegQueue<BatchResult>,
/// Pending count
pending: AtomicUsize,
/// Completed count
completed: AtomicUsize,
}
/// Item in the batch work queue
#[derive(Debug)]
pub struct BatchItem {
pub id: u64,
pub data: Vec<f32>,
}
/// Result from batch processing
pub struct BatchResult {
pub id: u64,
pub result: Vec<f32>,
}
impl LockFreeBatchProcessor {
/// Create a new batch processor with given capacity
pub fn new(capacity: usize) -> Self {
Self {
work_queue: ArrayQueue::new(capacity),
results_queue: SegQueue::new(),
pending: AtomicUsize::new(0),
completed: AtomicUsize::new(0),
}
}
/// Submit a batch item for processing
pub fn submit(&self, item: BatchItem) -> Result<(), BatchItem> {
self.pending.fetch_add(1, Ordering::Relaxed);
self.work_queue.push(item)
}
/// Try to get a work item (for workers)
pub fn try_get_work(&self) -> Option<BatchItem> {
self.work_queue.pop()
}
/// Submit a result (from workers)
pub fn submit_result(&self, result: BatchResult) {
self.completed.fetch_add(1, Ordering::Relaxed);
self.results_queue.push(result);
}
/// Collect all available results
pub fn collect_results(&self) -> Vec<BatchResult> {
let mut results = Vec::new();
while let Some(result) = self.results_queue.pop() {
results.push(result);
}
results
}
/// Get pending count
pub fn pending(&self) -> usize {
self.pending.load(Ordering::Relaxed)
}
/// Get completed count
pub fn completed(&self) -> usize {
self.completed.load(Ordering::Relaxed)
}
/// Check if all work is done
pub fn is_done(&self) -> bool {
self.pending() == self.completed()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
#[test]
fn test_lockfree_counter() {
let counter = Arc::new(LockFreeCounter::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter_clone = Arc::clone(&counter);
handles.push(thread::spawn(move || {
for _ in 0..1000 {
counter_clone.increment();
}
}));
}
for handle in handles {
handle.join().unwrap();
}
assert_eq!(counter.get(), 10000);
}
#[test]
fn test_object_pool() {
let pool = ObjectPool::new(4, || Vec::<u8>::with_capacity(1024));
let mut obj1 = pool.acquire();
obj1.push(1);
assert_eq!(obj1.len(), 1);
drop(obj1);
let obj2 = pool.acquire();
// Object should be reused (but cleared state is not guaranteed)
assert!(obj2.capacity() >= 1024);
}
#[test]
fn test_stats_collector() {
let stats = LockFreeStats::new();
stats.record_query(1000);
stats.record_query(2000);
stats.record_insert();
let snapshot = stats.snapshot();
assert_eq!(snapshot.queries, 2);
assert_eq!(snapshot.inserts, 1);
assert_eq!(snapshot.avg_latency_ns, 1500);
}
#[test]
fn test_atomic_vector_pool() {
let pool = AtomicVectorPool::new(4, 2, 10);
// Acquire first vector
let mut v1 = pool.acquire();
v1.copy_from(&[1.0, 2.0, 3.0, 4.0]);
assert_eq!(v1.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
// Acquire second vector
let mut v2 = pool.acquire();
v2.copy_from(&[5.0, 6.0, 7.0, 8.0]);
// Stats should show allocations
let stats = pool.stats();
assert_eq!(stats.total_allocations, 2);
}
#[test]
fn test_vector_pool_reuse() {
let pool = AtomicVectorPool::new(3, 1, 5);
// Acquire and release
{
let mut v = pool.acquire();
v.copy_from(&[1.0, 2.0, 3.0]);
} // v is returned to pool here
// Acquire again - should be a pool hit
let _v2 = pool.acquire();
let stats = pool.stats();
assert_eq!(stats.total_allocations, 2);
assert!(stats.pool_hits >= 1, "Should have at least one pool hit");
}
#[test]
fn test_batch_processor() {
let processor = LockFreeBatchProcessor::new(10);
// Submit work items
processor
.submit(BatchItem {
id: 1,
data: vec![1.0, 2.0],
})
.unwrap();
processor
.submit(BatchItem {
id: 2,
data: vec![3.0, 4.0],
})
.unwrap();
assert_eq!(processor.pending(), 2);
// Process work
while let Some(item) = processor.try_get_work() {
let result = BatchResult {
id: item.id,
result: item.data.iter().map(|x| x * 2.0).collect(),
};
processor.submit_result(result);
}
assert!(processor.is_done());
assert_eq!(processor.completed(), 2);
// Collect results
let results = processor.collect_results();
assert_eq!(results.len(), 2);
}
}

View File

@@ -0,0 +1,38 @@
//! Memory management utilities for ruvector-core
//!
//! This module provides memory-efficient data structures and utilities
//! for vector storage operations.
/// Memory pool for vector allocations.
#[derive(Debug, Default)]
pub struct MemoryPool {
/// Total allocated bytes.
allocated: usize,
/// Maximum allocation limit.
limit: Option<usize>,
}
impl MemoryPool {
/// Create a new memory pool.
pub fn new() -> Self {
Self::default()
}
/// Create a memory pool with a limit.
pub fn with_limit(limit: usize) -> Self {
Self {
allocated: 0,
limit: Some(limit),
}
}
/// Get currently allocated bytes.
pub fn allocated(&self) -> usize {
self.allocated
}
/// Get the allocation limit, if any.
pub fn limit(&self) -> Option<usize> {
self.limit
}
}

View File

@@ -0,0 +1,934 @@
//! Quantization techniques for memory compression
//!
//! This module provides tiered quantization strategies as specified in ADR-001:
//!
//! | Quantization | Compression | Use Case |
//! |--------------|-------------|----------|
//! | Scalar (u8) | 4x | Warm data (40-80% access) |
//! | Int4 | 8x | Cool data (10-40% access) |
//! | Product | 8-16x | Cold data (1-10% access) |
//! | Binary | 32x | Archive (<1% access) |
//!
//! ## Performance Optimizations v2
//!
//! - SIMD-accelerated distance calculations for scalar (int8) quantization
//! - SIMD popcnt for binary hamming distance
//! - 4x loop unrolling for better instruction-level parallelism
//! - Separate accumulator strategy to reduce data dependencies
use crate::error::Result;
use serde::{Deserialize, Serialize};
/// Trait for quantized vector representations
pub trait QuantizedVector: Send + Sync {
/// Quantize a full-precision vector
fn quantize(vector: &[f32]) -> Self;
/// Calculate distance to another quantized vector
fn distance(&self, other: &Self) -> f32;
/// Reconstruct approximate full-precision vector
fn reconstruct(&self) -> Vec<f32>;
}
/// Scalar quantization to int8 (4x compression)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantized {
/// Quantized values (int8)
pub data: Vec<u8>,
/// Minimum value for dequantization
pub min: f32,
/// Scale factor for dequantization
pub scale: f32,
}
impl QuantizedVector for ScalarQuantized {
fn quantize(vector: &[f32]) -> Self {
let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
// Handle edge case where all values are the same (scale = 0)
let scale = if (max - min).abs() < f32::EPSILON {
1.0 // Arbitrary non-zero scale when all values are identical
} else {
(max - min) / 255.0
};
let data = vector
.iter()
.map(|&v| ((v - min) / scale).round().clamp(0.0, 255.0) as u8)
.collect();
Self { data, min, scale }
}
fn distance(&self, other: &Self) -> f32 {
// Fast int8 distance calculation with SIMD optimization
// Use i32 to avoid overflow: max diff is 255, and 255*255=65025 fits in i32
// Scale handling: We use the average of both scales for balanced comparison.
// Using max(scale) would bias toward the vector with larger range,
// while average provides a more symmetric distance metric.
// This ensures distance(a, b) ≈ distance(b, a) in the reconstructed space.
let avg_scale = (self.scale + other.scale) / 2.0;
// Use SIMD-optimized version for larger vectors
#[cfg(target_arch = "aarch64")]
{
if self.data.len() >= 16 {
return unsafe { scalar_distance_neon(&self.data, &other.data) }.sqrt() * avg_scale;
}
}
#[cfg(target_arch = "x86_64")]
{
if self.data.len() >= 32 && is_x86_feature_detected!("avx2") {
return unsafe { scalar_distance_avx2(&self.data, &other.data) }.sqrt() * avg_scale;
}
}
// Scalar fallback with 4x loop unrolling for better ILP
scalar_distance_scalar(&self.data, &other.data).sqrt() * avg_scale
}
fn reconstruct(&self) -> Vec<f32> {
self.data
.iter()
.map(|&v| self.min + (v as f32) * self.scale)
.collect()
}
}
/// Product quantization (8-16x compression)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantized {
/// Quantized codes (one per subspace)
pub codes: Vec<u8>,
/// Codebooks for each subspace
pub codebooks: Vec<Vec<Vec<f32>>>,
}
impl ProductQuantized {
/// Train product quantization on a set of vectors
pub fn train(
vectors: &[Vec<f32>],
num_subspaces: usize,
codebook_size: usize,
iterations: usize,
) -> Result<Self> {
if vectors.is_empty() {
return Err(crate::error::RuvectorError::InvalidInput(
"Cannot train on empty vector set".into(),
));
}
if vectors[0].is_empty() {
return Err(crate::error::RuvectorError::InvalidInput(
"Cannot train on vectors with zero dimensions".into(),
));
}
if codebook_size > 256 {
return Err(crate::error::RuvectorError::InvalidParameter(format!(
"Codebook size {} exceeds u8 maximum of 256",
codebook_size
)));
}
let dimensions = vectors[0].len();
let subspace_dim = dimensions / num_subspaces;
let mut codebooks = Vec::with_capacity(num_subspaces);
// Train codebook for each subspace using k-means
for subspace_idx in 0..num_subspaces {
let start = subspace_idx * subspace_dim;
let end = start + subspace_dim;
// Extract subspace vectors
let subspace_vectors: Vec<Vec<f32>> =
vectors.iter().map(|v| v[start..end].to_vec()).collect();
// Run k-means
let codebook = kmeans_clustering(&subspace_vectors, codebook_size, iterations);
codebooks.push(codebook);
}
Ok(Self {
codes: vec![],
codebooks,
})
}
/// Quantize a vector using trained codebooks
pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
let num_subspaces = self.codebooks.len();
let subspace_dim = vector.len() / num_subspaces;
let mut codes = Vec::with_capacity(num_subspaces);
for (subspace_idx, codebook) in self.codebooks.iter().enumerate() {
let start = subspace_idx * subspace_dim;
let end = start + subspace_dim;
let subvector = &vector[start..end];
// Find nearest centroid
let code = codebook
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let dist_a = euclidean_squared(subvector, a);
let dist_b = euclidean_squared(subvector, b);
dist_a.partial_cmp(&dist_b).unwrap()
})
.map(|(idx, _)| idx as u8)
.unwrap_or(0);
codes.push(code);
}
codes
}
}
/// Int4 quantization (8x compression)
///
/// Quantizes f32 to 4-bit integers (0-15), packing 2 values per byte.
/// Provides 8x compression with better precision than binary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Int4Quantized {
/// Packed 4-bit values (2 per byte)
pub data: Vec<u8>,
/// Minimum value for dequantization
pub min: f32,
/// Scale factor for dequantization
pub scale: f32,
/// Number of dimensions
pub dimensions: usize,
}
impl Int4Quantized {
/// Quantize a vector to 4-bit representation
pub fn quantize(vector: &[f32]) -> Self {
let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
// Handle edge case where all values are the same
let scale = if (max - min).abs() < f32::EPSILON {
1.0
} else {
(max - min) / 15.0 // 4-bit gives 0-15 range
};
let dimensions = vector.len();
let num_bytes = dimensions.div_ceil(2);
let mut data = vec![0u8; num_bytes];
for (i, &v) in vector.iter().enumerate() {
let quantized = ((v - min) / scale).round().clamp(0.0, 15.0) as u8;
let byte_idx = i / 2;
if i % 2 == 0 {
// Low nibble
data[byte_idx] |= quantized;
} else {
// High nibble
data[byte_idx] |= quantized << 4;
}
}
Self {
data,
min,
scale,
dimensions,
}
}
/// Calculate distance to another Int4 quantized vector
pub fn distance(&self, other: &Self) -> f32 {
assert_eq!(self.dimensions, other.dimensions);
// Use average scale for balanced comparison
let avg_scale = (self.scale + other.scale) / 2.0;
let _avg_min = (self.min + other.min) / 2.0;
let mut sum_sq = 0i32;
for i in 0..self.dimensions {
let byte_idx = i / 2;
let shift = if i % 2 == 0 { 0 } else { 4 };
let a = ((self.data[byte_idx] >> shift) & 0x0F) as i32;
let b = ((other.data[byte_idx] >> shift) & 0x0F) as i32;
let diff = a - b;
sum_sq += diff * diff;
}
(sum_sq as f32).sqrt() * avg_scale
}
/// Reconstruct approximate full-precision vector
pub fn reconstruct(&self) -> Vec<f32> {
let mut result = Vec::with_capacity(self.dimensions);
for i in 0..self.dimensions {
let byte_idx = i / 2;
let shift = if i % 2 == 0 { 0 } else { 4 };
let quantized = (self.data[byte_idx] >> shift) & 0x0F;
result.push(self.min + (quantized as f32) * self.scale);
}
result
}
/// Get compression ratio (8x for Int4)
pub fn compression_ratio() -> f32 {
8.0 // f32 (4 bytes) -> 4 bits (0.5 bytes)
}
}
/// Binary quantization (32x compression)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryQuantized {
/// Binary representation (1 bit per dimension, packed into bytes)
pub bits: Vec<u8>,
/// Number of dimensions
pub dimensions: usize,
}
impl QuantizedVector for BinaryQuantized {
fn quantize(vector: &[f32]) -> Self {
let dimensions = vector.len();
let num_bytes = dimensions.div_ceil(8);
let mut bits = vec![0u8; num_bytes];
for (i, &v) in vector.iter().enumerate() {
if v > 0.0 {
let byte_idx = i / 8;
let bit_idx = i % 8;
bits[byte_idx] |= 1 << bit_idx;
}
}
Self { bits, dimensions }
}
fn distance(&self, other: &Self) -> f32 {
// Hamming distance using SIMD-friendly operations
Self::hamming_distance_fast(&self.bits, &other.bits) as f32
}
fn reconstruct(&self) -> Vec<f32> {
let mut result = Vec::with_capacity(self.dimensions);
for i in 0..self.dimensions {
let byte_idx = i / 8;
let bit_idx = i % 8;
let bit = (self.bits[byte_idx] >> bit_idx) & 1;
result.push(if bit == 1 { 1.0 } else { -1.0 });
}
result
}
}
impl BinaryQuantized {
/// Fast hamming distance using SIMD-optimized operations
///
/// Uses hardware POPCNT on x86_64 or NEON vcnt on ARM64 for optimal performance.
/// Processes 16 bytes at a time on ARM64, 8 bytes at a time on x86_64.
/// Falls back to 64-bit operations for remainders.
pub fn hamming_distance_fast(a: &[u8], b: &[u8]) -> u32 {
// Use SIMD-optimized version based on architecture
#[cfg(target_arch = "aarch64")]
{
if a.len() >= 16 {
return unsafe { hamming_distance_neon(a, b) };
}
}
#[cfg(target_arch = "x86_64")]
{
if a.len() >= 8 && is_x86_feature_detected!("popcnt") {
return unsafe { hamming_distance_simd_x86(a, b) };
}
}
// Scalar fallback using 64-bit operations
let mut distance = 0u32;
// Process 8 bytes at a time using u64
let chunks_a = a.chunks_exact(8);
let chunks_b = b.chunks_exact(8);
let remainder_a = chunks_a.remainder();
let remainder_b = chunks_b.remainder();
for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
distance += (a_u64 ^ b_u64).count_ones();
}
// Handle remainder bytes
for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
distance += (a_byte ^ b_byte).count_ones();
}
distance
}
/// Compute normalized hamming similarity (0.0 to 1.0)
pub fn similarity(&self, other: &Self) -> f32 {
let distance = self.distance(other);
1.0 - (distance / self.dimensions as f32)
}
/// Get compression ratio (32x for binary)
pub fn compression_ratio() -> f32 {
32.0 // f32 (4 bytes = 32 bits) -> 1 bit
}
/// Convert to bytes for storage
pub fn to_bytes(&self) -> &[u8] {
&self.bits
}
/// Create from bytes
pub fn from_bytes(bits: Vec<u8>, dimensions: usize) -> Self {
Self { bits, dimensions }
}
}
// ============================================================================
// Helper functions for scalar quantization distance
// ============================================================================
/// Scalar fallback for scalar quantization distance (sum of squared differences)
fn scalar_distance_scalar(a: &[u8], b: &[u8]) -> f32 {
let mut sum_sq = 0i32;
// 4x loop unrolling for better ILP
let chunks = a.len() / 4;
for i in 0..chunks {
let idx = i * 4;
let d0 = (a[idx] as i32) - (b[idx] as i32);
let d1 = (a[idx + 1] as i32) - (b[idx + 1] as i32);
let d2 = (a[idx + 2] as i32) - (b[idx + 2] as i32);
let d3 = (a[idx + 3] as i32) - (b[idx + 3] as i32);
sum_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
}
// Handle remainder
for i in (chunks * 4)..a.len() {
let diff = (a[i] as i32) - (b[i] as i32);
sum_sq += diff * diff;
}
sum_sq as f32
}
/// NEON SIMD distance for scalar quantization
///
/// # Safety
/// Caller must ensure a.len() == b.len()
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn scalar_distance_neon(a: &[u8], b: &[u8]) -> f32 {
use std::arch::aarch64::*;
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut sum = vdupq_n_s32(0);
// Process 8 bytes at a time
let chunks = len / 8;
let mut idx = 0usize;
for _ in 0..chunks {
// Load 8 u8 values
let va = vld1_u8(a_ptr.add(idx));
let vb = vld1_u8(b_ptr.add(idx));
// Zero-extend u8 to u16
let va_u16 = vmovl_u8(va);
let vb_u16 = vmovl_u8(vb);
// Convert to signed for subtraction
let va_s16 = vreinterpretq_s16_u16(va_u16);
let vb_s16 = vreinterpretq_s16_u16(vb_u16);
// Compute difference
let diff = vsubq_s16(va_s16, vb_s16);
// Square and accumulate
let prod_lo = vmull_s16(vget_low_s16(diff), vget_low_s16(diff));
let prod_hi = vmull_s16(vget_high_s16(diff), vget_high_s16(diff));
sum = vaddq_s32(sum, prod_lo);
sum = vaddq_s32(sum, prod_hi);
idx += 8;
}
let mut total = vaddvq_s32(sum);
// Handle remainder with bounds-check elimination
for i in (chunks * 8)..len {
let diff = (*a.get_unchecked(i) as i32) - (*b.get_unchecked(i) as i32);
total += diff * diff;
}
total as f32
}
/// AVX2 SIMD distance for scalar quantization
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scalar_distance_avx2(a: &[u8], b: &[u8]) -> f32 {
use std::arch::x86_64::*;
let len = a.len();
let mut sum = _mm256_setzero_si256();
// Process 16 bytes at a time
let chunks = len / 16;
for i in 0..chunks {
let idx = i * 16;
// Load 16 u8 values
let va = _mm_loadu_si128(a.as_ptr().add(idx) as *const __m128i);
let vb = _mm_loadu_si128(b.as_ptr().add(idx) as *const __m128i);
// Zero-extend u8 to i16 (low and high halves)
let va_lo = _mm256_cvtepu8_epi16(va);
let vb_lo = _mm256_cvtepu8_epi16(vb);
// Compute difference
let diff = _mm256_sub_epi16(va_lo, vb_lo);
// Square (multiply i16 * i16 -> i32)
let prod = _mm256_madd_epi16(diff, diff);
// Accumulate
sum = _mm256_add_epi32(sum, prod);
}
// Horizontal sum
let sum_lo = _mm256_castsi256_si128(sum);
let sum_hi = _mm256_extracti128_si256(sum, 1);
let sum_128 = _mm_add_epi32(sum_lo, sum_hi);
let shuffle = _mm_shuffle_epi32(sum_128, 0b10_11_00_01);
let sum_64 = _mm_add_epi32(sum_128, shuffle);
let shuffle2 = _mm_shuffle_epi32(sum_64, 0b00_00_10_10);
let final_sum = _mm_add_epi32(sum_64, shuffle2);
let mut total = _mm_cvtsi128_si32(final_sum);
// Handle remainder
for i in (chunks * 16)..len {
let diff = (a[i] as i32) - (b[i] as i32);
total += diff * diff;
}
total as f32
}
// Helper functions
fn euclidean_squared(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b)
.map(|(&x, &y)| {
let diff = x - y;
diff * diff
})
.sum()
}
fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
use rand::seq::SliceRandom;
use rand::thread_rng;
let mut rng = thread_rng();
// Initialize centroids randomly
let mut centroids: Vec<Vec<f32>> = vectors.choose_multiple(&mut rng, k).cloned().collect();
for _ in 0..iterations {
// Assign vectors to nearest centroid
let mut assignments = vec![Vec::new(); k];
for vector in vectors {
let nearest = centroids
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| {
let dist_a = euclidean_squared(vector, a);
let dist_b = euclidean_squared(vector, b);
dist_a.partial_cmp(&dist_b).unwrap()
})
.map(|(idx, _)| idx)
.unwrap_or(0);
assignments[nearest].push(vector.clone());
}
// Update centroids
for (centroid, assigned) in centroids.iter_mut().zip(&assignments) {
if !assigned.is_empty() {
let dim = centroid.len();
*centroid = vec![0.0; dim];
for vector in assigned {
for (i, &v) in vector.iter().enumerate() {
centroid[i] += v;
}
}
let count = assigned.len() as f32;
for v in centroid.iter_mut() {
*v /= count;
}
}
}
}
centroids
}
// =============================================================================
// SIMD-Optimized Distance Calculations for Quantized Vectors
// =============================================================================
// NOTE: scalar_distance_scalar is already defined above (lines 404-425)
// NOTE: scalar_distance_neon is already defined above (lines 430-473)
// NOTE: scalar_distance_avx2 is already defined above (lines 479-540)
// This section uses the existing implementations for consistency
/// SIMD-optimized hamming distance using popcnt
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "popcnt")]
#[inline]
unsafe fn hamming_distance_simd_x86(a: &[u8], b: &[u8]) -> u32 {
use std::arch::x86_64::*;
let mut distance = 0u64;
// Process 8 bytes at a time using u64 with hardware popcnt
let chunks_a = a.chunks_exact(8);
let chunks_b = b.chunks_exact(8);
let remainder_a = chunks_a.remainder();
let remainder_b = chunks_b.remainder();
for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
let a_u64 = u64::from_le_bytes(chunk_a.try_into().unwrap());
let b_u64 = u64::from_le_bytes(chunk_b.try_into().unwrap());
distance += _popcnt64((a_u64 ^ b_u64) as i64) as u64;
}
// Handle remainder
for (&a_byte, &b_byte) in remainder_a.iter().zip(remainder_b) {
distance += (a_byte ^ b_byte).count_ones() as u64;
}
distance as u32
}
/// NEON-optimized hamming distance for ARM64
///
/// # Safety
/// Caller must ensure a.len() == b.len()
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
use std::arch::aarch64::*;
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks = len / 16;
let mut idx = 0usize;
let mut sum = vdupq_n_u8(0);
for _ in 0..chunks {
// Load 16 bytes
let a_vec = vld1q_u8(a_ptr.add(idx));
let b_vec = vld1q_u8(b_ptr.add(idx));
// XOR and count bits using vcntq_u8 (population count)
let xor_result = veorq_u8(a_vec, b_vec);
let bits = vcntq_u8(xor_result);
// Accumulate
sum = vaddq_u8(sum, bits);
idx += 16;
}
// Horizontal sum
let sum_val = vaddvq_u8(sum) as u32;
// Handle remainder with bounds-check elimination
let mut remainder_sum = 0u32;
let start = chunks * 16;
for i in start..len {
remainder_sum += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones();
}
sum_val + remainder_sum
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_scalar_quantization() {
let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let quantized = ScalarQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
// Check approximate reconstruction
for (orig, recon) in vector.iter().zip(&reconstructed) {
assert!((orig - recon).abs() < 0.1);
}
}
#[test]
fn test_binary_quantization() {
let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5];
let quantized = BinaryQuantized::quantize(&vector);
assert_eq!(quantized.dimensions, 5);
assert_eq!(quantized.bits.len(), 1); // 5 bits fit in 1 byte
}
#[test]
fn test_binary_distance() {
let v1 = vec![1.0, 1.0, 1.0, 1.0];
let v2 = vec![1.0, 1.0, -1.0, -1.0];
let q1 = BinaryQuantized::quantize(&v1);
let q2 = BinaryQuantized::quantize(&v2);
let dist = q1.distance(&q2);
assert_eq!(dist, 2.0); // 2 bits differ
}
#[test]
fn test_scalar_quantization_roundtrip() {
// Test that quantize -> reconstruct produces values close to original
let test_vectors = vec![
vec![1.0, 2.0, 3.0, 4.0, 5.0],
vec![-10.0, -5.0, 0.0, 5.0, 10.0],
vec![0.1, 0.2, 0.3, 0.4, 0.5],
vec![100.0, 200.0, 300.0, 400.0, 500.0],
];
for vector in test_vectors {
let quantized = ScalarQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
assert_eq!(vector.len(), reconstructed.len());
for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
// With 8-bit quantization, max error is roughly (max-min)/255
let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
let max_error = (max - min) / 255.0 * 2.0; // Allow 2x for rounding
assert!(
(orig - recon).abs() < max_error,
"Roundtrip error too large: orig={}, recon={}, error={}",
orig,
recon,
(orig - recon).abs()
);
}
}
}
#[test]
fn test_scalar_distance_symmetry() {
// Test that distance(a, b) == distance(b, a)
let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];
let q1 = ScalarQuantized::quantize(&v1);
let q2 = ScalarQuantized::quantize(&v2);
let dist_ab = q1.distance(&q2);
let dist_ba = q2.distance(&q1);
// Distance should be symmetric (within floating point precision)
assert!(
(dist_ab - dist_ba).abs() < 0.01,
"Distance is not symmetric: d(a,b)={}, d(b,a)={}",
dist_ab,
dist_ba
);
}
#[test]
fn test_scalar_distance_different_scales() {
// Test distance calculation with vectors that have different scales
let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // range: 4.0
let v2 = vec![10.0, 20.0, 30.0, 40.0, 50.0]; // range: 40.0
let q1 = ScalarQuantized::quantize(&v1);
let q2 = ScalarQuantized::quantize(&v2);
let dist_ab = q1.distance(&q2);
let dist_ba = q2.distance(&q1);
// With average scaling, symmetry should be maintained
assert!(
(dist_ab - dist_ba).abs() < 0.01,
"Distance with different scales not symmetric: d(a,b)={}, d(b,a)={}",
dist_ab,
dist_ba
);
}
#[test]
fn test_scalar_quantization_edge_cases() {
// Test with all same values
let same_values = vec![5.0, 5.0, 5.0, 5.0];
let quantized = ScalarQuantized::quantize(&same_values);
let reconstructed = quantized.reconstruct();
for (orig, recon) in same_values.iter().zip(reconstructed.iter()) {
assert!((orig - recon).abs() < 0.01);
}
// Test with extreme ranges
let extreme = vec![f32::MIN / 1e10, 0.0, f32::MAX / 1e10];
let quantized = ScalarQuantized::quantize(&extreme);
let reconstructed = quantized.reconstruct();
assert_eq!(extreme.len(), reconstructed.len());
}
#[test]
fn test_binary_distance_symmetry() {
// Test that binary distance is symmetric
let v1 = vec![1.0, -1.0, 1.0, -1.0];
let v2 = vec![1.0, 1.0, -1.0, -1.0];
let q1 = BinaryQuantized::quantize(&v1);
let q2 = BinaryQuantized::quantize(&v2);
let dist_ab = q1.distance(&q2);
let dist_ba = q2.distance(&q1);
assert_eq!(
dist_ab, dist_ba,
"Binary distance not symmetric: d(a,b)={}, d(b,a)={}",
dist_ab, dist_ba
);
}
#[test]
fn test_int4_quantization() {
let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let quantized = Int4Quantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
assert_eq!(quantized.dimensions, 5);
// 5 dimensions = 3 bytes (2 per byte, last byte has 1)
assert_eq!(quantized.data.len(), 3);
// Check approximate reconstruction
for (orig, recon) in vector.iter().zip(&reconstructed) {
// With 4-bit quantization, max error is roughly (max-min)/15
let max_error = (5.0 - 1.0) / 15.0 * 2.0;
assert!(
(orig - recon).abs() < max_error,
"Int4 roundtrip error too large: orig={}, recon={}",
orig,
recon
);
}
}
#[test]
fn test_int4_distance() {
// Use vectors with different quantized patterns
// v1 spans [0.0, 15.0] -> quantizes to [0, 1, 2, ..., 15] (linear mapping)
// v2 spans [0.0, 15.0] but with different distribution
let v1 = vec![0.0, 5.0, 10.0, 15.0];
let v2 = vec![0.0, 3.0, 12.0, 15.0]; // Different middle values
let q1 = Int4Quantized::quantize(&v1);
let q2 = Int4Quantized::quantize(&v2);
let dist = q1.distance(&q2);
// The quantized values differ in the middle, so distance should be positive
assert!(
dist > 0.0,
"Distance should be positive, got {}. q1.data={:?}, q2.data={:?}",
dist,
q1.data,
q2.data
);
}
#[test]
fn test_int4_distance_symmetry() {
let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = vec![2.0, 3.0, 4.0, 5.0, 6.0];
let q1 = Int4Quantized::quantize(&v1);
let q2 = Int4Quantized::quantize(&v2);
let dist_ab = q1.distance(&q2);
let dist_ba = q2.distance(&q1);
assert!(
(dist_ab - dist_ba).abs() < 0.01,
"Int4 distance not symmetric: d(a,b)={}, d(b,a)={}",
dist_ab,
dist_ba
);
}
#[test]
fn test_int4_compression_ratio() {
assert_eq!(Int4Quantized::compression_ratio(), 8.0);
}
#[test]
fn test_binary_fast_hamming() {
// Test fast hamming distance with various sizes
let a = vec![0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xAA];
let b = vec![0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x55];
let distance = BinaryQuantized::hamming_distance_fast(&a, &b);
// All bits differ: 9 bytes * 8 bits = 72 bits
assert_eq!(distance, 72);
}
#[test]
fn test_binary_similarity() {
let v1 = vec![1.0; 8]; // All positive
let v2 = vec![1.0; 8]; // Same
let q1 = BinaryQuantized::quantize(&v1);
let q2 = BinaryQuantized::quantize(&v2);
let sim = q1.similarity(&q2);
assert!(
(sim - 1.0).abs() < 0.001,
"Same vectors should have similarity 1.0"
);
}
#[test]
fn test_binary_compression_ratio() {
assert_eq!(BinaryQuantized::compression_ratio(), 32.0);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,446 @@
//! Storage layer with redb for metadata and memory-mapped vectors
//!
//! This module is only available when the "storage" feature is enabled.
//! For WASM builds, use the in-memory storage backend instead.
#[cfg(feature = "storage")]
use crate::error::{Result, RuvectorError};
#[cfg(feature = "storage")]
use crate::types::{DbOptions, VectorEntry, VectorId};
#[cfg(feature = "storage")]
use bincode::config;
#[cfg(feature = "storage")]
use once_cell::sync::Lazy;
#[cfg(feature = "storage")]
use parking_lot::Mutex;
#[cfg(feature = "storage")]
use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
#[cfg(feature = "storage")]
use serde_json;
#[cfg(feature = "storage")]
use std::collections::HashMap;
#[cfg(feature = "storage")]
use std::path::{Path, PathBuf};
#[cfg(feature = "storage")]
use std::sync::Arc;
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");
const CONFIG_TABLE: TableDefinition<&str, &str> = TableDefinition::new("config");
/// Key used to store database configuration in CONFIG_TABLE
const DB_CONFIG_KEY: &str = "__ruvector_db_config__";
// Global database connection pool to allow multiple VectorDB instances
// to share the same underlying database file
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
/// Storage backend for vector database
pub struct VectorStorage {
db: Arc<Database>,
dimensions: usize,
}
impl VectorStorage {
/// Create or open a vector storage at the given path
///
/// This method uses a global connection pool to allow multiple VectorDB
/// instances to share the same underlying database file, fixing the
/// "Database already open. Cannot acquire lock" error.
pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
// SECURITY: Validate path to prevent directory traversal attacks
let path_ref = path.as_ref();
// Create parent directories if they don't exist (needed for canonicalize)
if let Some(parent) = path_ref.parent() {
if !parent.as_os_str().is_empty() && !parent.exists() {
std::fs::create_dir_all(parent).map_err(|e| {
RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
})?;
}
}
// Convert to absolute path first, then validate
let path_buf = if path_ref.is_absolute() {
path_ref.to_path_buf()
} else {
std::env::current_dir()
.map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
.join(path_ref)
};
// SECURITY: Check for path traversal attempts (e.g., "../../../etc/passwd")
// Only reject paths that contain ".." components trying to escape
let path_str = path_ref.to_string_lossy();
if path_str.contains("..") {
// Verify the resolved path doesn't escape intended boundaries
// For absolute paths, we allow them as-is (user explicitly specified)
// For relative paths with "..", check they don't escape cwd
if !path_ref.is_absolute() {
if let Ok(cwd) = std::env::current_dir() {
// Normalize the path by resolving .. components
let mut normalized = cwd.clone();
for component in path_ref.components() {
match component {
std::path::Component::ParentDir => {
if !normalized.pop() || !normalized.starts_with(&cwd) {
return Err(RuvectorError::InvalidPath(
"Path traversal attempt detected".to_string(),
));
}
}
std::path::Component::Normal(c) => normalized.push(c),
_ => {}
}
}
}
}
}
// Check if we already have a Database instance for this path
let db = {
let mut pool = DB_POOL.lock();
if let Some(existing_db) = pool.get(&path_buf) {
// Reuse existing database connection
Arc::clone(existing_db)
} else {
// Create new database and add to pool
let new_db = Arc::new(Database::create(&path_buf)?);
// Initialize tables
let write_txn = new_db.begin_write()?;
{
let _ = write_txn.open_table(VECTORS_TABLE)?;
let _ = write_txn.open_table(METADATA_TABLE)?;
let _ = write_txn.open_table(CONFIG_TABLE)?;
}
write_txn.commit()?;
pool.insert(path_buf, Arc::clone(&new_db));
new_db
}
};
Ok(Self { db, dimensions })
}
/// Insert a vector entry
pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
if entry.vector.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: entry.vector.len(),
});
}
let id = entry
.id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let write_txn = self.db.begin_write()?;
{
let mut table = write_txn.open_table(VECTORS_TABLE)?;
// Serialize vector data
let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
table.insert(id.as_str(), vector_data.as_slice())?;
// Store metadata if present
if let Some(metadata) = &entry.metadata {
let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
let metadata_json = serde_json::to_string(metadata)
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
meta_table.insert(id.as_str(), metadata_json.as_str())?;
}
}
write_txn.commit()?;
Ok(id)
}
/// Insert multiple vectors in a batch
pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
let write_txn = self.db.begin_write()?;
let mut ids = Vec::with_capacity(entries.len());
{
let mut table = write_txn.open_table(VECTORS_TABLE)?;
let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
for entry in entries {
if entry.vector.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: entry.vector.len(),
});
}
let id = entry
.id
.clone()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
// Serialize and insert vector
let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
table.insert(id.as_str(), vector_data.as_slice())?;
// Insert metadata if present
if let Some(metadata) = &entry.metadata {
let metadata_json = serde_json::to_string(metadata)
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
meta_table.insert(id.as_str(), metadata_json.as_str())?;
}
ids.push(id);
}
}
write_txn.commit()?;
Ok(ids)
}
/// Get a vector by ID
pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(VECTORS_TABLE)?;
let Some(vector_data) = table.get(id)? else {
return Ok(None);
};
let (vector, _): (Vec<f32>, usize) =
bincode::decode_from_slice(vector_data.value(), config::standard())
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
// Try to get metadata
let meta_table = read_txn.open_table(METADATA_TABLE)?;
let metadata = if let Some(meta_data) = meta_table.get(id)? {
let meta_str = meta_data.value();
Some(
serde_json::from_str(meta_str)
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
)
} else {
None
};
Ok(Some(VectorEntry {
id: Some(id.to_string()),
vector,
metadata,
}))
}
/// Delete a vector by ID
pub fn delete(&self, id: &str) -> Result<bool> {
let write_txn = self.db.begin_write()?;
let deleted;
{
let mut table = write_txn.open_table(VECTORS_TABLE)?;
deleted = table.remove(id)?.is_some();
let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
let _ = meta_table.remove(id)?;
}
write_txn.commit()?;
Ok(deleted)
}
/// Get the number of vectors stored
pub fn len(&self) -> Result<usize> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(VECTORS_TABLE)?;
Ok(table.len()? as usize)
}
/// Check if storage is empty
pub fn is_empty(&self) -> Result<bool> {
Ok(self.len()? == 0)
}
/// Get all vector IDs
pub fn all_ids(&self) -> Result<Vec<VectorId>> {
let read_txn = self.db.begin_read()?;
let table = read_txn.open_table(VECTORS_TABLE)?;
let mut ids = Vec::new();
let iter = table.iter()?;
for item in iter {
let (key, _) = item?;
ids.push(key.value().to_string());
}
Ok(ids)
}
/// Save database configuration to persistent storage
pub fn save_config(&self, options: &DbOptions) -> Result<()> {
let config_json = serde_json::to_string(options)
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
let write_txn = self.db.begin_write()?;
{
let mut table = write_txn.open_table(CONFIG_TABLE)?;
table.insert(DB_CONFIG_KEY, config_json.as_str())?;
}
write_txn.commit()?;
Ok(())
}
/// Load database configuration from persistent storage
pub fn load_config(&self) -> Result<Option<DbOptions>> {
let read_txn = self.db.begin_read()?;
// Try to open config table - may not exist in older databases
let table = match read_txn.open_table(CONFIG_TABLE) {
Ok(t) => t,
Err(_) => return Ok(None),
};
let Some(config_data) = table.get(DB_CONFIG_KEY)? else {
return Ok(None);
};
let config: DbOptions = serde_json::from_str(config_data.value())
.map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
Ok(Some(config))
}
/// Get the stored dimensions
pub fn dimensions(&self) -> usize {
self.dimensions
}
}
// Add uuid dependency
use uuid;
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
fn test_insert_and_get() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let entry = VectorEntry {
id: Some("test1".to_string()),
vector: vec![1.0, 2.0, 3.0],
metadata: None,
};
let id = storage.insert(&entry)?;
assert_eq!(id, "test1");
let retrieved = storage.get("test1")?;
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);
Ok(())
}
#[test]
fn test_batch_insert() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let entries = vec![
VectorEntry {
id: None,
vector: vec![1.0, 2.0, 3.0],
metadata: None,
},
VectorEntry {
id: None,
vector: vec![4.0, 5.0, 6.0],
metadata: None,
},
];
let ids = storage.insert_batch(&entries)?;
assert_eq!(ids.len(), 2);
assert_eq!(storage.len()?, 2);
Ok(())
}
#[test]
fn test_delete() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let entry = VectorEntry {
id: Some("test1".to_string()),
vector: vec![1.0, 2.0, 3.0],
metadata: None,
};
storage.insert(&entry)?;
assert_eq!(storage.len()?, 1);
let deleted = storage.delete("test1")?;
assert!(deleted);
assert_eq!(storage.len()?, 0);
Ok(())
}
#[test]
fn test_multiple_instances_same_path() -> Result<()> {
// This test verifies the fix for the database locking bug
// Multiple VectorStorage instances should be able to share the same database file
let dir = tempdir().unwrap();
let db_path = dir.path().join("shared.db");
// Create first instance
let storage1 = VectorStorage::new(&db_path, 3)?;
// Insert data with first instance
storage1.insert(&VectorEntry {
id: Some("test1".to_string()),
vector: vec![1.0, 2.0, 3.0],
metadata: None,
})?;
// Create second instance with SAME path - this should NOT fail
let storage2 = VectorStorage::new(&db_path, 3)?;
// Both instances should see the same data
assert_eq!(storage1.len()?, 1);
assert_eq!(storage2.len()?, 1);
// Insert with second instance
storage2.insert(&VectorEntry {
id: Some("test2".to_string()),
vector: vec![4.0, 5.0, 6.0],
metadata: None,
})?;
// Both instances should see both records
assert_eq!(storage1.len()?, 2);
assert_eq!(storage2.len()?, 2);
// Verify data integrity
let retrieved1 = storage1.get("test1")?;
assert!(retrieved1.is_some());
let retrieved2 = storage2.get("test2")?;
assert!(retrieved2.is_some());
Ok(())
}
}

View File

@@ -0,0 +1,79 @@
//! Storage compatibility layer
//!
//! This module provides a unified interface that works with both
//! file-based (redb) and in-memory storage backends.
use crate::error::Result;
use crate::types::{VectorEntry, VectorId};
#[cfg(feature = "storage")]
pub use crate::storage::VectorStorage;
#[cfg(not(feature = "storage"))]
pub use crate::storage_memory::MemoryStorage as VectorStorage;
/// Unified storage trait
pub trait StorageBackend {
fn insert(&self, entry: &VectorEntry) -> Result<VectorId>;
fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>>;
fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
fn delete(&self, id: &str) -> Result<bool>;
fn len(&self) -> Result<usize>;
fn is_empty(&self) -> Result<bool>;
}
// Implement trait for redb-based storage
#[cfg(feature = "storage")]
impl StorageBackend for crate::storage::VectorStorage {
fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
self.insert(entry)
}
fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
self.insert_batch(entries)
}
fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
self.get(id)
}
fn delete(&self, id: &str) -> Result<bool> {
self.delete(id)
}
fn len(&self) -> Result<usize> {
self.len()
}
fn is_empty(&self) -> Result<bool> {
self.is_empty()
}
}
// Implement trait for memory storage
#[cfg(not(feature = "storage"))]
impl StorageBackend for crate::storage_memory::MemoryStorage {
fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
self.insert(entry)
}
fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
self.insert_batch(entries)
}
fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
self.get(id)
}
fn delete(&self, id: &str) -> Result<bool> {
self.delete(id)
}
fn len(&self) -> Result<usize> {
self.len()
}
fn is_empty(&self) -> Result<bool> {
self.is_empty()
}
}

View File

@@ -0,0 +1,257 @@
//! In-memory storage backend for WASM and testing
//!
//! This storage implementation doesn't require file system access,
//! making it suitable for WebAssembly environments.
use crate::error::{Result, RuvectorError};
use crate::types::{VectorEntry, VectorId};
use dashmap::DashMap;
use serde_json::Value as JsonValue;
use std::sync::atomic::{AtomicU64, Ordering};
/// In-memory storage backend using DashMap for thread-safe concurrent access
pub struct MemoryStorage {
vectors: DashMap<String, Vec<f32>>,
metadata: DashMap<String, JsonValue>,
dimensions: usize,
counter: AtomicU64,
}
impl MemoryStorage {
/// Create a new in-memory storage
pub fn new(dimensions: usize) -> Result<Self> {
Ok(Self {
vectors: DashMap::new(),
metadata: DashMap::new(),
dimensions,
counter: AtomicU64::new(0),
})
}
/// Generate a new unique ID
fn generate_id(&self) -> String {
let id = self.counter.fetch_add(1, Ordering::SeqCst);
format!("vec_{}", id)
}
/// Insert a vector entry
pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
if entry.vector.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: entry.vector.len(),
});
}
let id = entry.id.clone().unwrap_or_else(|| self.generate_id());
// Insert vector
self.vectors.insert(id.clone(), entry.vector.clone());
// Insert metadata if present
if let Some(metadata) = &entry.metadata {
self.metadata.insert(
id.clone(),
serde_json::Value::Object(
metadata
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
),
);
}
Ok(id)
}
/// Insert multiple vectors in a batch
pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
let mut ids = Vec::with_capacity(entries.len());
for entry in entries {
if entry.vector.len() != self.dimensions {
return Err(RuvectorError::DimensionMismatch {
expected: self.dimensions,
actual: entry.vector.len(),
});
}
let id = entry.id.clone().unwrap_or_else(|| self.generate_id());
self.vectors.insert(id.clone(), entry.vector.clone());
if let Some(metadata) = &entry.metadata {
self.metadata.insert(
id.clone(),
serde_json::Value::Object(
metadata
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
),
);
}
ids.push(id);
}
Ok(ids)
}
/// Get a vector by ID
pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
if let Some(vector_ref) = self.vectors.get(id) {
let vector = vector_ref.clone();
let metadata = self.metadata.get(id).and_then(|m| {
if let serde_json::Value::Object(map) = m.value() {
Some(map.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
} else {
None
}
});
Ok(Some(VectorEntry {
id: Some(id.to_string()),
vector,
metadata,
}))
} else {
Ok(None)
}
}
/// Delete a vector by ID
pub fn delete(&self, id: &str) -> Result<bool> {
let vector_removed = self.vectors.remove(id).is_some();
self.metadata.remove(id);
Ok(vector_removed)
}
/// Get the number of vectors stored
pub fn len(&self) -> Result<usize> {
Ok(self.vectors.len())
}
/// Check if the storage is empty
pub fn is_empty(&self) -> Result<bool> {
Ok(self.vectors.is_empty())
}
/// Get all vector IDs (for iteration)
pub fn keys(&self) -> Vec<String> {
self.vectors
.iter()
.map(|entry| entry.key().clone())
.collect()
}
/// Get all vector IDs (alias for keys, for API compatibility with VectorStorage)
pub fn all_ids(&self) -> Result<Vec<String>> {
Ok(self.keys())
}
/// Clear all data
pub fn clear(&self) -> Result<()> {
self.vectors.clear();
self.metadata.clear();
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_insert_and_get() {
let storage = MemoryStorage::new(128).unwrap();
let entry = VectorEntry {
id: Some("test_1".to_string()),
vector: vec![0.1; 128],
metadata: Some(json!({"key": "value"})),
};
let id = storage.insert(&entry).unwrap();
assert_eq!(id, "test_1");
let retrieved = storage.get("test_1").unwrap().unwrap();
assert_eq!(retrieved.vector.len(), 128);
assert!(retrieved.metadata.is_some());
}
#[test]
fn test_batch_insert() {
let storage = MemoryStorage::new(64).unwrap();
let entries: Vec<_> = (0..10)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32; 64],
metadata: None,
})
.collect();
let ids = storage.insert_batch(&entries).unwrap();
assert_eq!(ids.len(), 10);
assert_eq!(storage.len().unwrap(), 10);
}
#[test]
fn test_delete() {
let storage = MemoryStorage::new(32).unwrap();
let entry = VectorEntry {
id: Some("delete_me".to_string()),
vector: vec![1.0; 32],
metadata: None,
};
storage.insert(&entry).unwrap();
assert_eq!(storage.len().unwrap(), 1);
let deleted = storage.delete("delete_me").unwrap();
assert!(deleted);
assert_eq!(storage.len().unwrap(), 0);
}
#[test]
fn test_auto_id_generation() {
let storage = MemoryStorage::new(16).unwrap();
let entry = VectorEntry {
id: None,
vector: vec![0.5; 16],
metadata: None,
};
let id1 = storage.insert(&entry).unwrap();
let id2 = storage.insert(&entry).unwrap();
assert_ne!(id1, id2);
assert!(id1.starts_with("vec_"));
assert!(id2.starts_with("vec_"));
}
#[test]
fn test_dimension_mismatch() {
let storage = MemoryStorage::new(128).unwrap();
let entry = VectorEntry {
id: Some("bad".to_string()),
vector: vec![0.1; 64], // Wrong dimension
metadata: None,
};
let result = storage.insert(&entry);
assert!(result.is_err());
if let Err(RuvectorError::DimensionMismatch { expected, actual }) = result {
assert_eq!(expected, 128);
assert_eq!(actual, 64);
} else {
panic!("Expected DimensionMismatch error");
}
}
}

View File

@@ -0,0 +1,126 @@
//! Core types and data structures
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Unique identifier for vectors
pub type VectorId = String;
/// Distance metric for similarity calculation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
/// Euclidean (L2) distance
Euclidean,
/// Cosine similarity (converted to distance)
Cosine,
/// Dot product (converted to distance for maximization)
DotProduct,
/// Manhattan (L1) distance
Manhattan,
}
/// Vector entry with metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEntry {
/// Optional ID (auto-generated if not provided)
pub id: Option<VectorId>,
/// Vector data
pub vector: Vec<f32>,
/// Optional metadata
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Search query parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQuery {
/// Query vector
pub vector: Vec<f32>,
/// Number of results to return (top-k)
pub k: usize,
/// Optional metadata filters
pub filter: Option<HashMap<String, serde_json::Value>>,
/// Optional ef_search parameter for HNSW (overrides default)
pub ef_search: Option<usize>,
}
/// Search result with similarity score
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// Vector ID
pub id: VectorId,
/// Distance/similarity score (lower is better for distance metrics)
pub score: f32,
/// Vector data (optional)
pub vector: Option<Vec<f32>>,
/// Metadata (optional)
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Database configuration options
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DbOptions {
/// Vector dimensions
pub dimensions: usize,
/// Distance metric
pub distance_metric: DistanceMetric,
/// Storage path
pub storage_path: String,
/// HNSW configuration
pub hnsw_config: Option<HnswConfig>,
/// Quantization configuration
pub quantization: Option<QuantizationConfig>,
}
/// HNSW index configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
/// Number of connections per layer (M)
pub m: usize,
/// Size of dynamic candidate list during construction (efConstruction)
pub ef_construction: usize,
/// Size of dynamic candidate list during search (efSearch)
pub ef_search: usize,
/// Maximum number of elements
pub max_elements: usize,
}
impl Default for HnswConfig {
fn default() -> Self {
Self {
m: 32,
ef_construction: 200,
ef_search: 100,
max_elements: 10_000_000,
}
}
}
/// Quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
/// No quantization (full precision)
None,
/// Scalar quantization to int8 (4x compression)
Scalar,
/// Product quantization
Product {
/// Number of subspaces
subspaces: usize,
/// Codebook size (typically 256)
k: usize,
},
/// Binary quantization (32x compression)
Binary,
}
impl Default for DbOptions {
fn default() -> Self {
Self {
dimensions: 384,
distance_metric: DistanceMetric::Cosine,
storage_path: "./ruvector.db".to_string(),
hnsw_config: Some(HnswConfig::default()),
quantization: Some(QuantizationConfig::Scalar),
}
}
}

View File

@@ -0,0 +1,391 @@
//! Main VectorDB interface
use crate::error::Result;
use crate::index::flat::FlatIndex;
#[cfg(feature = "hnsw")]
use crate::index::hnsw::HnswIndex;
use crate::index::VectorIndex;
use crate::types::*;
use parking_lot::RwLock;
use std::sync::Arc;
// Import appropriate storage backend based on features
#[cfg(feature = "storage")]
use crate::storage::VectorStorage;
#[cfg(not(feature = "storage"))]
use crate::storage_memory::MemoryStorage as VectorStorage;
/// Main vector database
pub struct VectorDB {
storage: Arc<VectorStorage>,
index: Arc<RwLock<Box<dyn VectorIndex>>>,
options: DbOptions,
}
impl VectorDB {
/// Create a new vector database with the given options
///
/// If a storage path is provided and contains persisted vectors,
/// the HNSW index will be automatically rebuilt from storage.
/// If opening an existing database, the stored configuration (dimensions,
/// distance metric, etc.) will be used instead of the provided options.
#[allow(unused_mut)] // `options` is mutated only when feature = "storage"
pub fn new(mut options: DbOptions) -> Result<Self> {
#[cfg(feature = "storage")]
let storage = {
// First, try to load existing configuration from the database
// We create a temporary storage to check for config
let temp_storage = VectorStorage::new(&options.storage_path, options.dimensions)?;
let stored_config = temp_storage.load_config()?;
if let Some(config) = stored_config {
// Existing database - use stored configuration
tracing::info!(
"Loading existing database with {} dimensions",
config.dimensions
);
options = DbOptions {
// Keep the provided storage path (may have changed)
storage_path: options.storage_path.clone(),
// Use stored configuration for everything else
dimensions: config.dimensions,
distance_metric: config.distance_metric,
hnsw_config: config.hnsw_config,
quantization: config.quantization,
};
// Recreate storage with correct dimensions
Arc::new(VectorStorage::new(
&options.storage_path,
options.dimensions,
)?)
} else {
// New database - save the configuration
tracing::info!(
"Creating new database with {} dimensions",
options.dimensions
);
temp_storage.save_config(&options)?;
Arc::new(temp_storage)
}
};
#[cfg(not(feature = "storage"))]
let storage = Arc::new(VectorStorage::new(options.dimensions)?);
// Choose index based on configuration and available features
#[allow(unused_mut)] // `index` is mutated only when feature = "storage"
let mut index: Box<dyn VectorIndex> = if let Some(hnsw_config) = &options.hnsw_config {
#[cfg(feature = "hnsw")]
{
Box::new(HnswIndex::new(
options.dimensions,
options.distance_metric,
hnsw_config.clone(),
)?)
}
#[cfg(not(feature = "hnsw"))]
{
// Fall back to flat index if HNSW is not available
tracing::warn!("HNSW requested but not available (WASM build), using flat index");
Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
}
} else {
Box::new(FlatIndex::new(options.dimensions, options.distance_metric))
};
// Rebuild index from persisted vectors if storage is not empty
// This fixes the bug where search() returns empty results after restart
#[cfg(feature = "storage")]
{
let stored_ids = storage.all_ids()?;
if !stored_ids.is_empty() {
tracing::info!(
"Rebuilding index from {} persisted vectors",
stored_ids.len()
);
// Batch load all vectors for efficient index rebuilding
let mut entries = Vec::with_capacity(stored_ids.len());
for id in stored_ids {
if let Some(entry) = storage.get(&id)? {
entries.push((id, entry.vector));
}
}
// Add all vectors to index in batch for better performance
index.add_batch(entries)?;
tracing::info!("Index rebuilt successfully");
}
}
Ok(Self {
storage,
index: Arc::new(RwLock::new(index)),
options,
})
}
/// Create with default options
pub fn with_dimensions(dimensions: usize) -> Result<Self> {
let options = DbOptions {
dimensions,
..DbOptions::default()
};
Self::new(options)
}
/// Insert a vector entry
pub fn insert(&self, entry: VectorEntry) -> Result<VectorId> {
let id = self.storage.insert(&entry)?;
// Add to index
let mut index = self.index.write();
index.add(id.clone(), entry.vector)?;
Ok(id)
}
/// Insert multiple vectors in a batch
pub fn insert_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<VectorId>> {
let ids = self.storage.insert_batch(&entries)?;
// Add to index
let mut index = self.index.write();
let index_entries: Vec<_> = ids
.iter()
.zip(entries.iter())
.map(|(id, entry)| (id.clone(), entry.vector.clone()))
.collect();
index.add_batch(index_entries)?;
Ok(ids)
}
/// Search for similar vectors
pub fn search(&self, query: SearchQuery) -> Result<Vec<SearchResult>> {
let index = self.index.read();
let mut results = index.search(&query.vector, query.k)?;
// Enrich results with full data if needed
for result in &mut results {
if let Ok(Some(entry)) = self.storage.get(&result.id) {
result.vector = Some(entry.vector);
result.metadata = entry.metadata;
}
}
// Apply metadata filters if specified
if let Some(filter) = &query.filter {
results.retain(|r| {
if let Some(metadata) = &r.metadata {
filter
.iter()
.all(|(key, value)| metadata.get(key).is_some_and(|v| v == value))
} else {
false
}
});
}
Ok(results)
}
/// Delete a vector by ID
pub fn delete(&self, id: &str) -> Result<bool> {
let deleted_storage = self.storage.delete(id)?;
if deleted_storage {
let mut index = self.index.write();
let _ = index.remove(&id.to_string())?;
}
Ok(deleted_storage)
}
/// Get a vector by ID
pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
self.storage.get(id)
}
/// Get the number of vectors
pub fn len(&self) -> Result<usize> {
self.storage.len()
}
/// Check if database is empty
pub fn is_empty(&self) -> Result<bool> {
self.storage.is_empty()
}
/// Get database options
pub fn options(&self) -> &DbOptions {
&self.options
}
/// Get all vector IDs (for iteration/serialization)
pub fn keys(&self) -> Result<Vec<String>> {
self.storage.all_ids()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::Path;
use tempfile::tempdir;
#[test]
fn test_vector_db_creation() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
let db = VectorDB::new(options)?;
assert!(db.is_empty()?);
Ok(())
}
#[test]
fn test_insert_and_search() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
options.distance_metric = DistanceMetric::Euclidean; // Use Euclidean for clearer test
options.hnsw_config = None; // Use flat index for testing
let db = VectorDB::new(options)?;
// Insert vectors
db.insert(VectorEntry {
id: Some("v1".to_string()),
vector: vec![1.0, 0.0, 0.0],
metadata: None,
})?;
db.insert(VectorEntry {
id: Some("v2".to_string()),
vector: vec![0.0, 1.0, 0.0],
metadata: None,
})?;
db.insert(VectorEntry {
id: Some("v3".to_string()),
vector: vec![0.0, 0.0, 1.0],
metadata: None,
})?;
// Search for exact match
let results = db.search(SearchQuery {
vector: vec![1.0, 0.0, 0.0],
k: 2,
filter: None,
ef_search: None,
})?;
assert!(results.len() >= 1);
assert_eq!(results[0].id, "v1", "First result should be exact match");
assert!(
results[0].score < 0.01,
"Exact match should have ~0 distance"
);
Ok(())
}
/// Test that search works after simulated restart (new VectorDB instance)
/// This verifies the fix for issue #30: HNSW index not rebuilt from storage
#[test]
#[cfg(feature = "storage")]
fn test_search_after_restart() -> Result<()> {
let dir = tempdir().unwrap();
let db_path = dir.path().join("persist.db").to_string_lossy().to_string();
// Phase 1: Create database and insert vectors
{
let mut options = DbOptions::default();
options.storage_path = db_path.clone();
options.dimensions = 3;
options.distance_metric = DistanceMetric::Euclidean;
options.hnsw_config = None;
let db = VectorDB::new(options)?;
db.insert(VectorEntry {
id: Some("v1".to_string()),
vector: vec![1.0, 0.0, 0.0],
metadata: None,
})?;
db.insert(VectorEntry {
id: Some("v2".to_string()),
vector: vec![0.0, 1.0, 0.0],
metadata: None,
})?;
db.insert(VectorEntry {
id: Some("v3".to_string()),
vector: vec![0.7, 0.7, 0.0],
metadata: None,
})?;
// Verify search works before "restart"
let results = db.search(SearchQuery {
vector: vec![0.8, 0.6, 0.0],
k: 3,
filter: None,
ef_search: None,
})?;
assert_eq!(results.len(), 3, "Should find all 3 vectors before restart");
}
// db is dropped here, simulating application shutdown
// Phase 2: Create new database instance (simulates restart)
{
let mut options = DbOptions::default();
options.storage_path = db_path.clone();
options.dimensions = 3;
options.distance_metric = DistanceMetric::Euclidean;
options.hnsw_config = None;
let db = VectorDB::new(options)?;
// Verify vectors are still accessible
assert_eq!(db.len()?, 3, "Should have 3 vectors after restart");
// Verify get() works
let v1 = db.get("v1")?;
assert!(v1.is_some(), "get() should work after restart");
// Verify search() works - THIS WAS THE BUG
let results = db.search(SearchQuery {
vector: vec![0.8, 0.6, 0.0],
k: 3,
filter: None,
ef_search: None,
})?;
assert_eq!(
results.len(),
3,
"search() should return results after restart (was returning 0 before fix)"
);
// v3 should be closest to query [0.8, 0.6, 0.0]
assert_eq!(
results[0].id, "v3",
"v3 [0.7, 0.7, 0.0] should be closest to query [0.8, 0.6, 0.0]"
);
}
Ok(())
}
}

View File

@@ -0,0 +1,251 @@
# Ruvector Core Test Suite
## Overview
This directory contains a comprehensive Test-Driven Development (TDD) test suite following the London School approach. The test suite covers unit tests, integration tests, property-based tests, stress tests, and concurrent access tests.
## Test Files
### 1. `unit_tests.rs` - Unit Tests with Mocking (London School)
Comprehensive unit tests using `mockall` for mocking dependencies:
- **Distance Metric Tests**: Tests for all 4 distance metrics (Euclidean, Cosine, Dot Product, Manhattan)
- Self-distance verification
- Symmetry properties
- Orthogonal and parallel vector cases
- Dimension mismatch error handling
- **Quantization Tests**: Tests for scalar and binary quantization
- Round-trip reconstruction accuracy
- Distance calculation correctness
- Sign preservation (binary quantization)
- Hamming distance validation
- **Storage Layer Tests**: Tests for VectorStorage
- Insert with explicit and auto-generated IDs
- Metadata handling
- Dimension validation
- Batch operations
- Delete operations
- Error cases (non-existent vectors, dimension mismatches)
- **VectorDB Tests**: High-level API tests
- Empty database operations
- Insert/delete with len() tracking
- Search functionality
- Metadata filtering
- Batch insert operations
### 2. `integration_tests.rs` - End-to-End Integration Tests
Full workflow tests that verify all components work together:
- **Complete Workflows**: Insert + search + retrieve with metadata
- **Large Batch Operations**: 10K+ vector batch insertions
- **Persistence**: Database save and reload verification
- **Mixed Operations**: Combined insert, delete, and search operations
- **Distance Metrics**: Tests for all 4 metrics end-to-end
- **HNSW Configurations**: Different HNSW parameter combinations
- **Metadata Filtering**: Complex filtering scenarios
- **Error Handling**: Dimension validation, wrong query dimensions
### 3. `property_tests.rs` - Property-Based Tests (proptest)
Mathematical property verification using proptest:
- **Distance Metric Properties**:
- Self-distance is zero
- Symmetry: d(a,b) = d(b,a)
- Triangle inequality: d(a,c) ≤ d(a,b) + d(b,c)
- Non-negativity
- Scale invariance (cosine)
- Translation invariance (Euclidean)
- **Quantization Properties**:
- Round-trip reconstruction bounds
- Sign preservation (binary)
- Self-distance is zero
- Symmetry
- Distance bounds
- **Batch Operations**:
- Consistency between batch and individual operations
- **Dimension Handling**:
- Mismatch error detection
- Success on matching dimensions
### 4. `stress_tests.rs` - Scalability and Performance Stress Tests
Tests that push the system to its limits:
- **Million Vector Insertion** (ignored by default): Insert 1M vectors in batches
- **Concurrent Queries**: 10 threads × 100 queries each
- **Concurrent Mixed Operations**: Simultaneous readers and writers
- **Memory Pressure**: Large 2048-dimensional vectors
- **Error Recovery**: Invalid operations don't crash the system
- **Repeated Operations**: Same operation executed many times
- **Extreme Parameters**: k values larger than database size
### 5. `concurrent_tests.rs` - Thread-Safety Tests
Multi-threaded access patterns:
- **Concurrent Reads**: Multiple threads reading simultaneously
- **Concurrent Writes**: Non-overlapping writes from multiple threads
- **Mixed Read/Write**: Concurrent reads and writes
- **Delete and Insert**: Simultaneous deletes and inserts
- **Search and Insert**: Searching while inserting
- **Batch Atomicity**: Verifying batch operations are atomic
- **Read-Write Consistency**: Ensuring no data corruption
- **Metadata Updates**: Concurrent metadata modifications
## Benchmarks
### 6. `benches/quantization_bench.rs` - Quantization Performance
Criterion benchmarks for quantization operations:
- Scalar quantization encode/decode/distance
- Binary quantization encode/decode/distance
- Compression ratio comparisons
### 7. `benches/batch_operations.rs` - Batch Operation Performance
Criterion benchmarks for batch operations:
- Batch insert at various scales (100, 1K, 10K)
- Individual vs batch insert comparison
- Parallel search performance
- Batch delete operations
## Running Tests
### Run All Tests
```bash
cargo test --package ruvector-core
```
### Run Specific Test Suites
```bash
# Unit tests only
cargo test --test unit_tests
# Integration tests only
cargo test --test integration_tests
# Property tests only
cargo test --test property_tests
# Concurrent tests only
cargo test --test concurrent_tests
# Stress tests (including ignored tests)
cargo test --test stress_tests -- --ignored --test-threads=1
```
### Run Benchmarks
```bash
# Distance metrics (existing)
cargo bench --bench distance_metrics
# HNSW search (existing)
cargo bench --bench hnsw_search
# Quantization (new)
cargo bench --bench quantization_bench
# Batch operations (new)
cargo bench --bench batch_operations
```
### Generate Coverage Report
```bash
# Install tarpaulin if not already installed
cargo install cargo-tarpaulin
# Generate HTML coverage report
cargo tarpaulin --out Html --output-dir target/coverage
# Open coverage report
open target/coverage/index.html
```
## Test Coverage Goals
- **Target**: 90%+ code coverage
- **Focus Areas**:
- Distance calculations
- Index operations
- Storage layer
- Error handling paths
- Edge cases
## Known Issues
As of the current implementation, there are pre-existing compilation errors in the codebase that prevent some tests from running:
1. **HNSW Index**: `DataId::new` construction issues in `src/index/hnsw.rs`
2. **AgenticDB**: Missing `Encode`/`Decode` trait implementations for `ReflexionEpisode`
These issues exist in the main codebase and need to be fixed before the full test suite can execute.
## Test Organization
Tests are organized by purpose and scope:
1. **Unit Tests**: Test individual components in isolation with mocking
2. **Integration Tests**: Test component interactions and workflows
3. **Property Tests**: Test mathematical properties and invariants
4. **Stress Tests**: Test performance limits and edge cases
5. **Concurrent Tests**: Test thread-safety and concurrent access patterns
## Dependencies
Test dependencies (in `Cargo.toml`):
```toml
[dev-dependencies]
criterion = { workspace = true }
proptest = { workspace = true }
mockall = { workspace = true }
tempfile = "3.13"
tracing-subscriber = { workspace = true }
```
## Contributing
When adding new tests:
1. Follow the existing structure and naming conventions
2. Add tests to the appropriate file (unit, integration, property, etc.)
3. Document the test purpose clearly
4. Ensure tests are deterministic and don't depend on timing
5. Use `tempdir()` for database paths in tests
6. Clean up resources properly
## CI/CD Integration
Recommended CI pipeline:
```yaml
test:
script:
- cargo test --all-features
- cargo tarpaulin --out Xml
coverage: '/\d+\.\d+% coverage/'
bench:
script:
- cargo bench --no-run
```
## Performance Expectations
Based on stress tests:
- **Insert**: ~10K vectors/second (batch mode)
- **Search**: ~1K queries/second (k=10, HNSW)
- **Concurrent**: 10+ threads without performance degradation
- **Memory**: ~4KB per 384-dim vector (uncompressed)
## Additional Resources
- [Mockall Documentation](https://docs.rs/mockall/)
- [Proptest Guide](https://altsysrq.github.io/proptest-book/)
- [Criterion.rs Guide](https://bheisler.github.io/criterion.rs/book/)
- [Cargo Tarpaulin](https://github.com/xd009642/tarpaulin)

View File

@@ -0,0 +1,550 @@
//! Integration tests for advanced features
//!
//! Tests Enhanced PQ, Filtered Search, MMR, Hybrid Search, and Conformal Prediction
//! across multiple vector dimensions (128D, 384D, 768D)
use ruvector_core::advanced_features::*;
use ruvector_core::types::{DistanceMetric, SearchResult};
use ruvector_core::{Result, RuvectorError};
use std::collections::HashMap;
// Helper function to generate random vectors
fn generate_vectors(count: usize, dimensions: usize) -> Vec<Vec<f32>> {
use rand::Rng;
let mut rng = rand::thread_rng();
(0..count)
.map(|_| (0..dimensions).map(|_| rng.gen::<f32>()).collect())
.collect()
}
// Helper function to normalize vectors
fn normalize_vector(v: &[f32]) -> Vec<f32> {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
v.iter().map(|x| x / norm).collect()
} else {
v.to_vec()
}
}
#[test]
fn test_enhanced_pq_128d() {
let dimensions = 128;
let num_vectors = 1000;
let config = PQConfig {
num_subspaces: 8,
codebook_size: 256,
num_iterations: 10,
metric: DistanceMetric::Euclidean,
};
let mut pq = EnhancedPQ::new(dimensions, config).unwrap();
// Generate training data
let training_data = generate_vectors(num_vectors, dimensions);
pq.train(&training_data).unwrap();
// Test encoding and search
let query = normalize_vector(&generate_vectors(1, dimensions)[0]);
// Add quantized vectors
for (i, vector) in training_data.iter().enumerate() {
pq.add_quantized(format!("vec_{}", i), vector).unwrap();
}
// Perform search
let results = pq.search(&query, 10).unwrap();
assert_eq!(results.len(), 10);
// Test compression ratio
let compression_ratio = pq.compression_ratio();
assert!(compression_ratio >= 8.0); // Should be 16x for 128D with 8 subspaces
println!(
"✓ Enhanced PQ 128D: compression ratio = {:.1}x",
compression_ratio
);
}
#[test]
fn test_enhanced_pq_384d() {
let dimensions = 384;
let num_vectors = 500;
let config = PQConfig {
num_subspaces: 8,
codebook_size: 256,
num_iterations: 10,
metric: DistanceMetric::Cosine,
};
let mut pq = EnhancedPQ::new(dimensions, config).unwrap();
// Generate training data
let training_data: Vec<Vec<f32>> = generate_vectors(num_vectors, dimensions)
.into_iter()
.map(|v| normalize_vector(&v))
.collect();
pq.train(&training_data).unwrap();
// Test reconstruction
let test_vector = &training_data[0];
let codes = pq.encode(test_vector).unwrap();
let reconstructed = pq.reconstruct(&codes).unwrap();
assert_eq!(reconstructed.len(), dimensions);
// Calculate reconstruction error
let error: f32 = test_vector
.iter()
.zip(&reconstructed)
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
.sqrt();
println!("✓ Enhanced PQ 384D: reconstruction error = {:.4}", error);
assert!(error < 5.0); // Reasonable reconstruction error
}
#[test]
fn test_enhanced_pq_768d() {
let dimensions = 768;
let num_vectors = 300; // Increased to ensure we have enough vectors for search
let config = PQConfig {
num_subspaces: 16,
codebook_size: 256,
num_iterations: 10,
metric: DistanceMetric::Euclidean,
};
let mut pq = EnhancedPQ::new(dimensions, config).unwrap();
let training_data = generate_vectors(num_vectors, dimensions);
pq.train(&training_data).unwrap();
// Test lookup table creation
let query = generate_vectors(1, dimensions)[0].clone();
let lookup_table = pq.create_lookup_table(&query).unwrap();
assert_eq!(lookup_table.tables.len(), 16);
assert_eq!(lookup_table.tables[0].len(), 256);
println!("✓ Enhanced PQ 768D: lookup table created successfully");
}
#[test]
fn test_filtered_search_pre_filter() {
use serde_json::json;
// Create metadata store
let mut metadata_store = HashMap::new();
for i in 0..100 {
let mut metadata = HashMap::new();
metadata.insert(
"category".to_string(),
json!(if i % 3 == 0 { "A" } else { "B" }),
);
metadata.insert("price".to_string(), json!(i as f32 * 10.0));
metadata_store.insert(format!("vec_{}", i), metadata);
}
// Create filter: category == "A" AND price < 500
let filter = FilterExpression::And(vec![
FilterExpression::Eq("category".to_string(), json!("A")),
FilterExpression::Lt("price".to_string(), json!(500.0)),
]);
let search = FilteredSearch::new(filter, FilterStrategy::PreFilter, metadata_store);
// Test pre-filtering
let filtered_ids = search.get_filtered_ids();
assert!(!filtered_ids.is_empty());
assert!(filtered_ids.len() < 50); // Should be selective
println!(
"✓ Filtered Search (Pre-filter): {} matching documents",
filtered_ids.len()
);
}
#[test]
fn test_filtered_search_auto_strategy() {
use serde_json::json;
let mut metadata_store = HashMap::new();
for i in 0..1000 {
let mut metadata = HashMap::new();
metadata.insert("id".to_string(), json!(i));
metadata_store.insert(format!("vec_{}", i), metadata);
}
// Highly selective filter (should choose pre-filter)
let selective_filter = FilterExpression::Eq("id".to_string(), json!(42));
let search1 = FilteredSearch::new(
selective_filter,
FilterStrategy::Auto,
metadata_store.clone(),
);
assert_eq!(search1.auto_select_strategy(), FilterStrategy::PreFilter);
// Less selective filter (should choose post-filter)
let broad_filter = FilterExpression::Gte("id".to_string(), json!(0));
let search2 = FilteredSearch::new(broad_filter, FilterStrategy::Auto, metadata_store);
assert_eq!(search2.auto_select_strategy(), FilterStrategy::PostFilter);
println!("✓ Filtered Search: automatic strategy selection working");
}
#[test]
fn test_mmr_diversity_128d() {
let dimensions = 128;
let config = MMRConfig {
lambda: 0.5, // Balance relevance and diversity
metric: DistanceMetric::Cosine,
fetch_multiplier: 2.0,
};
let mmr = MMRSearch::new(config).unwrap();
// Create query and candidates
let query = normalize_vector(&generate_vectors(1, dimensions)[0]);
let candidates: Vec<SearchResult> = (0..20)
.map(|i| SearchResult {
id: format!("doc_{}", i),
score: i as f32 * 0.05,
vector: Some(normalize_vector(&generate_vectors(1, dimensions)[0])),
metadata: None,
})
.collect();
// Rerank using MMR
let results = mmr.rerank(&query, candidates, 10).unwrap();
assert_eq!(results.len(), 10);
// Results should be diverse (not just top-10 by relevance)
println!("✓ MMR 128D: diversified {} results", results.len());
}
#[test]
fn test_mmr_lambda_variations() {
let dimensions = 64;
// Test with pure relevance (lambda = 1.0)
let config_relevance = MMRConfig {
lambda: 1.0,
metric: DistanceMetric::Euclidean,
fetch_multiplier: 2.0,
};
let mmr_relevance = MMRSearch::new(config_relevance).unwrap();
// Test with pure diversity (lambda = 0.0)
let config_diversity = MMRConfig {
lambda: 0.0,
metric: DistanceMetric::Euclidean,
fetch_multiplier: 2.0,
};
let mmr_diversity = MMRSearch::new(config_diversity).unwrap();
let query = generate_vectors(1, dimensions)[0].clone();
let candidates: Vec<SearchResult> = (0..10)
.map(|i| SearchResult {
id: format!("doc_{}", i),
score: i as f32 * 0.1,
vector: Some(generate_vectors(1, dimensions)[0].clone()),
metadata: None,
})
.collect();
let results_relevance = mmr_relevance.rerank(&query, candidates.clone(), 5).unwrap();
let results_diversity = mmr_diversity.rerank(&query, candidates, 5).unwrap();
assert_eq!(results_relevance.len(), 5);
assert_eq!(results_diversity.len(), 5);
println!("✓ MMR: lambda variations tested successfully");
}
#[test]
fn test_hybrid_search_basic() {
let config = HybridConfig {
vector_weight: 0.7,
keyword_weight: 0.3,
normalization: NormalizationStrategy::MinMax,
};
let mut hybrid = HybridSearch::new(config);
// Index documents
hybrid.index_document("doc1".to_string(), "rust programming language".to_string());
hybrid.index_document("doc2".to_string(), "python machine learning".to_string());
hybrid.index_document("doc3".to_string(), "rust systems programming".to_string());
hybrid.finalize_indexing();
// Test BM25 scoring
let score = hybrid.bm25.score(
"rust programming",
&"doc1".to_string(),
"rust programming language",
);
assert!(score > 0.0);
println!(
"✓ Hybrid Search: indexed {} documents",
hybrid.doc_texts.len()
);
}
#[test]
fn test_hybrid_search_keyword_matching() {
let mut bm25 = BM25::new(1.5, 0.75);
bm25.index_document("doc1".to_string(), "vector database with HNSW indexing");
bm25.index_document("doc2".to_string(), "relational database management system");
bm25.index_document("doc3".to_string(), "vector search and similarity matching");
bm25.build_idf();
// Test candidate retrieval
let candidates = bm25.get_candidate_docs("vector database");
assert!(candidates.contains(&"doc1".to_string()));
assert!(candidates.contains(&"doc3".to_string()));
// Test scoring
let score1 = bm25.score(
"vector database",
&"doc1".to_string(),
"vector database with HNSW indexing",
);
let score2 = bm25.score(
"vector database",
&"doc2".to_string(),
"relational database management system",
);
assert!(score1 > score2); // doc1 matches better
println!(
"✓ Hybrid Search (BM25): {} candidate documents",
candidates.len()
);
}
#[test]
fn test_conformal_prediction_128d() {
let dimensions = 128;
let config = ConformalConfig {
alpha: 0.1, // 90% coverage
calibration_fraction: 0.2,
nonconformity_measure: NonconformityMeasure::Distance,
};
let mut predictor = ConformalPredictor::new(config).unwrap();
// Create calibration data
let calibration_queries = generate_vectors(10, dimensions);
let true_neighbors: Vec<Vec<String>> = (0..10)
.map(|i| vec![format!("vec_{}", i), format!("vec_{}", i + 1)])
.collect();
// Mock search function
let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
Ok((0..k)
.map(|i| SearchResult {
id: format!("vec_{}", i),
score: i as f32 * 0.1,
vector: Some(vec![0.0; dimensions]),
metadata: None,
})
.collect())
};
// Calibrate
predictor
.calibrate(&calibration_queries, &true_neighbors, search_fn)
.unwrap();
assert!(predictor.threshold.is_some());
assert!(!predictor.calibration_scores.is_empty());
// Make prediction
let query = generate_vectors(1, dimensions)[0].clone();
let prediction_set = predictor.predict(&query, search_fn).unwrap();
assert_eq!(prediction_set.confidence, 0.9);
assert!(!prediction_set.results.is_empty());
println!(
"✓ Conformal Prediction 128D: prediction set size = {}",
prediction_set.results.len()
);
}
#[test]
fn test_conformal_prediction_384d() {
let dimensions = 384;
let config = ConformalConfig {
alpha: 0.05, // 95% coverage
calibration_fraction: 0.2,
nonconformity_measure: NonconformityMeasure::NormalizedDistance,
};
let mut predictor = ConformalPredictor::new(config).unwrap();
let calibration_queries = generate_vectors(5, dimensions);
let true_neighbors: Vec<Vec<String>> = (0..5).map(|i| vec![format!("vec_{}", i)]).collect();
let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
Ok((0..k)
.map(|i| SearchResult {
id: format!("vec_{}", i),
score: 0.1 + (i as f32 * 0.05),
vector: Some(vec![0.0; dimensions]),
metadata: None,
})
.collect())
};
predictor
.calibrate(&calibration_queries, &true_neighbors, search_fn)
.unwrap();
// Test calibration statistics
let stats = predictor.get_statistics().unwrap();
assert_eq!(stats.num_samples, 5);
assert!(stats.mean > 0.0);
println!("✓ Conformal Prediction 384D: calibration stats computed");
}
#[test]
fn test_conformal_prediction_adaptive_k() {
let dimensions = 256;
let config = ConformalConfig {
alpha: 0.1,
calibration_fraction: 0.2,
nonconformity_measure: NonconformityMeasure::InverseRank,
};
let mut predictor = ConformalPredictor::new(config).unwrap();
let calibration_queries = generate_vectors(8, dimensions);
let true_neighbors: Vec<Vec<String>> = (0..8).map(|i| vec![format!("vec_{}", i)]).collect();
let search_fn = |_query: &[f32], k: usize| -> Result<Vec<SearchResult>> {
Ok((0..k)
.map(|i| SearchResult {
id: format!("vec_{}", i),
score: i as f32 * 0.08,
vector: Some(vec![0.0; dimensions]),
metadata: None,
})
.collect())
};
predictor
.calibrate(&calibration_queries, &true_neighbors, search_fn)
.unwrap();
// Test adaptive top-k
let query = generate_vectors(1, dimensions)[0].clone();
let adaptive_k = predictor.adaptive_top_k(&query, search_fn).unwrap();
assert!(adaptive_k > 0);
println!("✓ Conformal Prediction: adaptive k = {}", adaptive_k);
}
#[test]
fn test_all_features_integration() {
// Test that all features can work together
let dimensions = 128;
// 1. Enhanced PQ
let pq_config = PQConfig {
num_subspaces: 4,
codebook_size: 16,
num_iterations: 5,
metric: DistanceMetric::Euclidean,
};
let mut pq = EnhancedPQ::new(dimensions, pq_config).unwrap();
let training_data = generate_vectors(50, dimensions);
pq.train(&training_data).unwrap();
// 2. MMR
let mmr_config = MMRConfig::default();
let mmr = MMRSearch::new(mmr_config).unwrap();
// 3. Hybrid Search
let hybrid_config = HybridConfig::default();
let mut hybrid = HybridSearch::new(hybrid_config);
hybrid.index_document("doc1".to_string(), "test document".to_string());
hybrid.finalize_indexing();
// 4. Conformal Prediction
let cp_config = ConformalConfig::default();
let predictor = ConformalPredictor::new(cp_config).unwrap();
println!("✓ All features integrated successfully");
}
#[test]
fn test_pq_recall_384d() {
let dimensions = 384;
let num_vectors = 500;
let k = 10;
let config = PQConfig {
num_subspaces: 8,
codebook_size: 256,
num_iterations: 15,
metric: DistanceMetric::Euclidean,
};
let mut pq = EnhancedPQ::new(dimensions, config).unwrap();
// Generate and train
let vectors = generate_vectors(num_vectors, dimensions);
pq.train(&vectors).unwrap();
// Add vectors
for (i, vector) in vectors.iter().enumerate() {
pq.add_quantized(format!("vec_{}", i), vector).unwrap();
}
// Test search
let query = &vectors[0]; // Use first vector as query
let results = pq.search(query, k).unwrap();
// Verify we got results
assert!(!results.is_empty(), "Search should return results");
assert_eq!(results.len(), k, "Should return k results");
// First result should be among the top candidates (PQ is approximate)
// Due to quantization, the exact match might not be at position 0
// but the distance should be reasonably small relative to random vectors
let min_distance = results
.iter()
.map(|(_, d)| *d)
.fold(f32::INFINITY, f32::min);
// In high dimensions, PQ distances vary based on quantization quality
// Check that we get reasonable results (top result should be closer than random)
assert!(
min_distance < 50.0,
"Minimum distance {} should be reasonable for quantized search",
min_distance
);
println!(
"✓ PQ 384D Recall Test: top-{} results retrieved, min distance = {:.4}",
results.len(),
min_distance
);
}

View File

@@ -0,0 +1,440 @@
//! Concurrent access tests with multiple threads
//!
//! These tests verify thread-safety and correct behavior under concurrent access.
use ruvector_core::types::{DbOptions, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use std::thread;
use tempfile::tempdir;
#[test]
fn test_concurrent_reads() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("concurrent_reads.db")
.to_string_lossy()
.to_string();
options.dimensions = 32;
let db = Arc::new(VectorDB::new(options).unwrap());
// Insert initial data
for i in 0..100 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
metadata: None,
})
.unwrap();
}
// Spawn multiple reader threads
let num_threads = 10;
let mut handles = vec![];
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for i in 0..50 {
let id = format!("vec_{}", (thread_id * 10 + i) % 100);
let result = db_clone.get(&id).unwrap();
assert!(result.is_some(), "Failed to get {}", id);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
#[test]
fn test_concurrent_writes_no_collision() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("concurrent_writes.db")
.to_string_lossy()
.to_string();
options.dimensions = 32;
let db = Arc::new(VectorDB::new(options).unwrap());
// Spawn multiple writer threads with non-overlapping IDs
let num_threads = 10;
let vectors_per_thread = 20;
let mut handles = vec![];
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for i in 0..vectors_per_thread {
let id = format!("thread_{}_{}", thread_id, i);
db_clone
.insert(VectorEntry {
id: Some(id.clone()),
vector: vec![thread_id as f32; 32],
metadata: None,
})
.unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
// Verify all vectors were inserted
assert_eq!(db.len().unwrap(), num_threads * vectors_per_thread);
}
#[test]
fn test_concurrent_delete_and_insert() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("concurrent_delete_insert.db")
.to_string_lossy()
.to_string();
options.dimensions = 16;
let db = Arc::new(VectorDB::new(options).unwrap());
// Insert initial data
for i in 0..100 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32; 16],
metadata: None,
})
.unwrap();
}
let num_threads = 5;
let mut handles = vec![];
// Deleter threads
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for i in 0..10 {
let id = format!("vec_{}", thread_id * 10 + i);
db_clone.delete(&id).unwrap();
}
});
handles.push(handle);
}
// Inserter threads
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for i in 0..10 {
let id = format!("new_{}_{}", thread_id, i);
db_clone
.insert(VectorEntry {
id: Some(id),
vector: vec![(thread_id * 100 + i) as f32; 16],
metadata: None,
})
.unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
// Verify database is in consistent state
let final_len = db.len().unwrap();
assert_eq!(final_len, 100); // 100 original - 50 deleted + 50 inserted
}
#[test]
fn test_concurrent_search_and_insert() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("concurrent_search_insert.db")
.to_string_lossy()
.to_string();
options.dimensions = 64;
options.hnsw_config = Some(HnswConfig::default());
let db = Arc::new(VectorDB::new(options).unwrap());
// Insert initial data
for i in 0..100 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
metadata: None,
})
.unwrap();
}
let num_search_threads = 5;
let num_insert_threads = 2;
let mut handles = vec![];
// Search threads
for search_id in 0..num_search_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for i in 0..20 {
let query: Vec<f32> = (0..64)
.map(|j| ((search_id * 10 + i + j) as f32) * 0.01)
.collect();
let results = db_clone
.search(SearchQuery {
vector: query,
k: 5,
filter: None,
ef_search: None,
})
.unwrap();
// Should always return some results (at least from initial data)
assert!(results.len() > 0);
}
});
handles.push(handle);
}
// Insert threads
for insert_id in 0..num_insert_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for i in 0..50 {
db_clone
.insert(VectorEntry {
id: Some(format!("new_{}_{}", insert_id, i)),
vector: (0..64)
.map(|j| ((insert_id * 1000 + i + j) as f32) * 0.01)
.collect(),
metadata: None,
})
.unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
// Verify final state
assert_eq!(db.len().unwrap(), 200); // 100 initial + 100 new
}
#[test]
fn test_atomicity_of_batch_insert() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("atomic_batch.db")
.to_string_lossy()
.to_string();
options.dimensions = 16;
let db = Arc::new(VectorDB::new(options).unwrap());
// Track successful insertions
let inserted_ids = Arc::new(Mutex::new(HashSet::new()));
let num_threads = 5;
let mut handles = vec![];
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let ids_clone = Arc::clone(&inserted_ids);
let handle = thread::spawn(move || {
for batch_idx in 0..10 {
let vectors: Vec<VectorEntry> = (0..10)
.map(|i| {
let id = format!("t{}_b{}_v{}", thread_id, batch_idx, i);
VectorEntry {
id: Some(id.clone()),
vector: vec![(thread_id * 100 + batch_idx * 10 + i) as f32; 16],
metadata: None,
}
})
.collect();
let ids = db_clone.insert_batch(vectors).unwrap();
let mut lock = ids_clone.lock().unwrap();
for id in ids {
lock.insert(id);
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
// Verify all insertions were recorded
let total_inserted = inserted_ids.lock().unwrap().len();
assert_eq!(total_inserted, num_threads * 10 * 10); // threads * batches * vectors_per_batch
assert_eq!(db.len().unwrap(), total_inserted);
}
#[test]
fn test_read_write_consistency() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("consistency.db")
.to_string_lossy()
.to_string();
options.dimensions = 32;
let db = Arc::new(VectorDB::new(options).unwrap());
// Insert initial vector
db.insert(VectorEntry {
id: Some("test".to_string()),
vector: vec![1.0; 32],
metadata: None,
})
.unwrap();
let num_threads = 10;
let mut handles = vec![];
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for _ in 0..100 {
// Read
let entry = db_clone.get("test").unwrap();
assert!(entry.is_some());
// Verify vector is consistent
let vector = entry.unwrap().vector;
assert_eq!(vector.len(), 32);
// All values should be the same (not corrupted)
let first_val = vector[0];
assert!(vector
.iter()
.all(|&v| v == first_val || (first_val == 1.0 || v == (thread_id as f32))));
// Write (update) - this creates a race condition intentionally
if thread_id % 2 == 0 {
let _ = db_clone.insert(VectorEntry {
id: Some("test".to_string()),
vector: vec![thread_id as f32; 32],
metadata: None,
});
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
// Verify database is still consistent
let final_entry = db.get("test").unwrap();
assert!(final_entry.is_some());
let vector = final_entry.unwrap().vector;
assert_eq!(vector.len(), 32);
// Check no corruption (all values should be the same)
let first_val = vector[0];
assert!(vector.iter().all(|&v| v == first_val));
}
#[test]
fn test_concurrent_metadata_updates() {
use std::collections::HashMap;
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("metadata.db").to_string_lossy().to_string();
options.dimensions = 16;
let db = Arc::new(VectorDB::new(options).unwrap());
// Insert initial vectors
for i in 0..50 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32; 16],
metadata: None,
})
.unwrap();
}
let num_threads = 5;
let mut handles = vec![];
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let handle = thread::spawn(move || {
for i in 0..10 {
let mut metadata = HashMap::new();
metadata.insert(format!("thread_{}", thread_id), serde_json::json!(i));
// Update vector with metadata
let id = format!("vec_{}", i * 5 + thread_id);
db_clone
.insert(VectorEntry {
id: Some(id.clone()),
vector: vec![thread_id as f32; 16],
metadata: Some(metadata),
})
.unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
// Verify some vectors have metadata
let entry = db.get("vec_0").unwrap();
assert!(entry.is_some());
}

View File

@@ -0,0 +1,307 @@
//! Integration tests for embedding providers
use ruvector_core::embeddings::{ApiEmbedding, EmbeddingProvider, HashEmbedding};
use ruvector_core::{types::DbOptions, AgenticDB};
use std::sync::Arc;
use tempfile::tempdir;
#[test]
fn test_hash_embedding_provider() {
let provider = HashEmbedding::new(128);
// Test basic embedding
let emb1 = provider.embed("hello world").unwrap();
assert_eq!(emb1.len(), 128);
// Test consistency
let emb2 = provider.embed("hello world").unwrap();
assert_eq!(emb1, emb2, "Same text should produce same embedding");
// Test different text produces different embeddings
let emb3 = provider.embed("goodbye world").unwrap();
assert_ne!(
emb1, emb3,
"Different text should produce different embeddings"
);
// Test normalization
let norm: f32 = emb1.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(norm - 1.0).abs() < 1e-5,
"Embedding should be normalized to unit length"
);
// Test provider info
assert_eq!(provider.dimensions(), 128);
assert!(provider.name().contains("Hash"));
}
#[test]
fn test_agenticdb_with_hash_embeddings() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 128;
// Create AgenticDB with default hash embeddings
let db = AgenticDB::new(options).unwrap();
assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");
// Test storing a reflexion episode
let episode_id = db
.store_episode(
"Solve a math problem".to_string(),
vec!["read problem".to_string(), "calculate".to_string()],
vec!["got answer 42".to_string()],
"Should have shown intermediate steps".to_string(),
)
.unwrap();
// Test retrieving similar episodes
let episodes = db
.retrieve_similar_episodes("math problem solving", 5)
.unwrap();
assert!(!episodes.is_empty());
assert_eq!(episodes[0].id, episode_id);
}
#[test]
fn test_agenticdb_with_custom_hash_provider() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 256;
// Create custom hash provider
let provider = Arc::new(HashEmbedding::new(256));
// Create AgenticDB with custom provider
let db = AgenticDB::with_embedding_provider(options, provider).unwrap();
assert_eq!(db.embedding_provider_name(), "HashEmbedding (placeholder)");
// Test creating a skill
let mut params = std::collections::HashMap::new();
params.insert("input".to_string(), "string".to_string());
let skill_id = db
.create_skill(
"Parse JSON".to_string(),
"Parse JSON from string".to_string(),
params,
vec!["json.parse()".to_string()],
)
.unwrap();
// Search for skills
let skills = db.search_skills("parse json data", 5).unwrap();
assert!(!skills.is_empty());
assert_eq!(skills[0].id, skill_id);
}
#[test]
fn test_dimension_mismatch_validation() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 128;
// Try to create with mismatched dimensions
let provider = Arc::new(HashEmbedding::new(256)); // Different from options
let result = AgenticDB::with_embedding_provider(options, provider);
assert!(result.is_err(), "Should fail when dimensions don't match");
if let Err(err) = result {
assert!(
err.to_string().contains("do not match"),
"Error should mention dimension mismatch"
);
}
}
#[test]
fn test_api_embedding_provider_construction() {
// Test OpenAI provider construction
let openai_small = ApiEmbedding::openai("sk-test", "text-embedding-3-small");
assert_eq!(openai_small.dimensions(), 1536);
assert_eq!(openai_small.name(), "ApiEmbedding");
let openai_large = ApiEmbedding::openai("sk-test", "text-embedding-3-large");
assert_eq!(openai_large.dimensions(), 3072);
// Test Cohere provider construction
let cohere = ApiEmbedding::cohere("co-test", "embed-english-v3.0");
assert_eq!(cohere.dimensions(), 1024);
// Test Voyage provider construction
let voyage = ApiEmbedding::voyage("vo-test", "voyage-2");
assert_eq!(voyage.dimensions(), 1024);
let voyage_large = ApiEmbedding::voyage("vo-test", "voyage-large-2");
assert_eq!(voyage_large.dimensions(), 1536);
}
#[test]
#[ignore] // Requires API key and network access
fn test_api_embedding_openai() {
let api_key = std::env::var("OPENAI_API_KEY")
.expect("OPENAI_API_KEY environment variable required for this test");
let provider = ApiEmbedding::openai(&api_key, "text-embedding-3-small");
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 1536);
// Check that embeddings are different for different texts
let embedding2 = provider.embed("goodbye world").unwrap();
assert_ne!(embedding, embedding2);
}
#[test]
#[ignore] // Requires API key and network access
fn test_agenticdb_with_openai_embeddings() {
let api_key = std::env::var("OPENAI_API_KEY")
.expect("OPENAI_API_KEY environment variable required for this test");
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 1536; // OpenAI text-embedding-3-small dimensions
let provider = Arc::new(ApiEmbedding::openai(&api_key, "text-embedding-3-small"));
let db = AgenticDB::with_embedding_provider(options, provider).unwrap();
assert_eq!(db.embedding_provider_name(), "ApiEmbedding");
// Test with real semantic embeddings
let _episode1_id = db
.store_episode(
"Solve calculus problem".to_string(),
vec![
"identify function".to_string(),
"take derivative".to_string(),
],
vec!["computed derivative".to_string()],
"Should explain chain rule application".to_string(),
)
.unwrap();
let _episode2_id = db
.store_episode(
"Solve algebra problem".to_string(),
vec!["simplify equation".to_string(), "solve for x".to_string()],
vec!["found x = 5".to_string()],
"Should show all steps".to_string(),
)
.unwrap();
// Search with semantic query - should find calculus episode first
let episodes = db
.retrieve_similar_episodes("derivative calculation", 2)
.unwrap();
assert!(!episodes.is_empty());
// With real embeddings, "derivative" should match calculus better than algebra
println!(
"Found episodes: {:?}",
episodes.iter().map(|e| &e.task).collect::<Vec<_>>()
);
}
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_candle_embedding_provider() {
use ruvector_core::CandleEmbedding;
let provider =
CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap();
assert_eq!(provider.dimensions(), 384);
assert_eq!(provider.name(), "CandleEmbedding (transformer)");
let embedding = provider.embed("hello world").unwrap();
assert_eq!(embedding.len(), 384);
// Check normalization
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-3, "Embedding should be normalized");
// Test semantic similarity
let emb_dog = provider.embed("dog").unwrap();
let emb_cat = provider.embed("cat").unwrap();
let emb_car = provider.embed("car").unwrap();
// Cosine similarity
let similarity_dog_cat: f32 = emb_dog.iter().zip(emb_cat.iter()).map(|(a, b)| a * b).sum();
let similarity_dog_car: f32 = emb_dog.iter().zip(emb_car.iter()).map(|(a, b)| a * b).sum();
// "dog" and "cat" should be more similar than "dog" and "car"
assert!(
similarity_dog_cat > similarity_dog_car,
"Semantic embeddings should show dog-cat more similar than dog-car"
);
}
#[cfg(feature = "real-embeddings")]
#[test]
#[ignore] // Requires model download
fn test_agenticdb_with_candle_embeddings() {
use ruvector_core::CandleEmbedding;
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 384;
let provider = Arc::new(
CandleEmbedding::from_pretrained("sentence-transformers/all-MiniLM-L6-v2", false).unwrap(),
);
let db = AgenticDB::with_embedding_provider(options, provider).unwrap();
assert_eq!(
db.embedding_provider_name(),
"CandleEmbedding (transformer)"
);
// Test with real semantic embeddings
let skill1_id = db
.create_skill(
"File I/O".to_string(),
"Read and write files to disk".to_string(),
std::collections::HashMap::new(),
vec![
"open()".to_string(),
"read()".to_string(),
"write()".to_string(),
],
)
.unwrap();
let skill2_id = db
.create_skill(
"Network I/O".to_string(),
"Send and receive data over network".to_string(),
std::collections::HashMap::new(),
vec![
"connect()".to_string(),
"send()".to_string(),
"recv()".to_string(),
],
)
.unwrap();
// Search with semantic query
let skills = db.search_skills("reading files from storage", 2).unwrap();
assert!(!skills.is_empty());
// With real embeddings, file I/O should match better
println!(
"Found skills: {:?}",
skills.iter().map(|s| &s.name).collect::<Vec<_>>()
);
}

View File

@@ -0,0 +1,495 @@
//! Comprehensive HNSW integration tests with different index sizes
use ruvector_core::index::hnsw::HnswIndex;
use ruvector_core::index::VectorIndex;
use ruvector_core::types::{DistanceMetric, HnswConfig};
use ruvector_core::Result;
fn generate_random_vectors(count: usize, dimensions: usize, seed: u64) -> Vec<Vec<f32>> {
use rand::{Rng, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
(0..count)
.map(|_| {
(0..dimensions)
.map(|_| rng.gen::<f32>() * 2.0 - 1.0)
.collect()
})
.collect()
}
fn normalize_vector(v: &[f32]) -> Vec<f32> {
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
v.iter().map(|x| x / norm).collect()
} else {
v.to_vec()
}
}
fn calculate_recall(ground_truth: &[String], results: &[String]) -> f32 {
let gt_set: std::collections::HashSet<_> = ground_truth.iter().collect();
let found = results.iter().filter(|id| gt_set.contains(id)).count();
found as f32 / ground_truth.len() as f32
}
fn brute_force_search(
query: &[f32],
vectors: &[(String, Vec<f32>)],
k: usize,
metric: DistanceMetric,
) -> Vec<String> {
use ruvector_core::distance::distance;
let mut distances: Vec<_> = vectors
.iter()
.map(|(id, v)| {
let dist = distance(query, v, metric).unwrap();
(id.clone(), dist)
})
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
distances.into_iter().take(k).map(|(id, _)| id).collect()
}
#[test]
fn test_hnsw_100_vectors() -> Result<()> {
let dimensions = 128;
let num_vectors = 100;
let k = 10;
let config = HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 200,
max_elements: 1000,
};
let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;
// Generate and insert vectors
let vectors = generate_random_vectors(num_vectors, dimensions, 42);
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
for (i, vector) in normalized_vectors.iter().enumerate() {
index.add(format!("vec_{}", i), vector.clone())?;
}
assert_eq!(index.len(), num_vectors);
// Test search accuracy with multiple queries
let num_queries = 10;
let mut total_recall = 0.0;
for i in 0..num_queries {
let query_idx = i * (num_vectors / num_queries);
let query = &normalized_vectors[query_idx];
// Get HNSW results
let results = index.search(query, k)?;
let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();
// Get ground truth with brute force
let vectors_with_ids: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
.collect();
let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
let recall = calculate_recall(&ground_truth, &result_ids);
total_recall += recall;
}
let avg_recall = total_recall / num_queries as f32;
println!(
"100 vectors - Average recall@{}: {:.2}%",
k,
avg_recall * 100.0
);
// For small datasets, we expect very high recall
assert!(
avg_recall >= 0.90,
"Recall should be at least 90% for 100 vectors"
);
Ok(())
}
#[test]
fn test_hnsw_1k_vectors() -> Result<()> {
let dimensions = 128;
let num_vectors = 1000;
let k = 10;
let config = HnswConfig {
m: 32,
ef_construction: 200,
ef_search: 200,
max_elements: 10000,
};
let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;
// Generate and insert vectors
let vectors = generate_random_vectors(num_vectors, dimensions, 12345);
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
// Use batch insert for better performance
let entries: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(i, v)| (format!("vec_{}", i), v.clone()))
.collect();
index.add_batch(entries)?;
assert_eq!(index.len(), num_vectors);
// Test search accuracy
let num_queries = 20;
let mut total_recall = 0.0;
for i in 0..num_queries {
let query_idx = i * (num_vectors / num_queries);
let query = &normalized_vectors[query_idx];
let results = index.search(query, k)?;
let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();
let vectors_with_ids: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
.collect();
let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
let recall = calculate_recall(&ground_truth, &result_ids);
total_recall += recall;
}
let avg_recall = total_recall / num_queries as f32;
println!(
"1K vectors - Average recall@{}: {:.2}%",
k,
avg_recall * 100.0
);
// Should achieve at least 95% recall with ef_search=200
assert!(
avg_recall >= 0.95,
"Recall should be at least 95% for 1K vectors with ef_search=200"
);
Ok(())
}
#[test]
fn test_hnsw_10k_vectors() -> Result<()> {
let dimensions = 128;
let num_vectors = 10000;
let k = 10;
let config = HnswConfig {
m: 32,
ef_construction: 200,
ef_search: 200,
max_elements: 100000,
};
let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;
println!("Generating {} vectors...", num_vectors);
let vectors = generate_random_vectors(num_vectors, dimensions, 98765);
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
println!("Inserting vectors in batches...");
// Insert in batches for better performance
let batch_size = 1000;
for batch_start in (0..num_vectors).step_by(batch_size) {
let batch_end = (batch_start + batch_size).min(num_vectors);
let entries: Vec<_> = normalized_vectors[batch_start..batch_end]
.iter()
.enumerate()
.map(|(i, v)| (format!("vec_{}", batch_start + i), v.clone()))
.collect();
index.add_batch(entries)?;
}
assert_eq!(index.len(), num_vectors);
println!("Index built with {} vectors", index.len());
// Prepare all vectors for ground truth computation
let all_vectors: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(i, v)| (format!("vec_{}", i), v.clone()))
.collect();
// Test search accuracy with a sample of queries
let num_queries = 20; // Reduced for faster testing
let mut total_recall = 0.0;
println!("Running {} queries...", num_queries);
for i in 0..num_queries {
let query_idx = i * (num_vectors / num_queries);
let query = &normalized_vectors[query_idx];
let results = index.search(query, k)?;
let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();
// Compare against all vectors for accurate ground truth
let ground_truth = brute_force_search(query, &all_vectors, k, DistanceMetric::Cosine);
let recall = calculate_recall(&ground_truth, &result_ids);
total_recall += recall;
}
let avg_recall = total_recall / num_queries as f32;
println!(
"10K vectors - Average recall@{}: {:.2}%",
k,
avg_recall * 100.0
);
// With ef_search=200 and m=32, we should achieve good recall
assert!(
avg_recall >= 0.70,
"Recall should be at least 70% for 10K vectors, got {:.2}%",
avg_recall * 100.0
);
Ok(())
}
#[test]
fn test_hnsw_ef_search_tuning() -> Result<()> {
let dimensions = 128;
let num_vectors = 500;
let k = 10;
let config = HnswConfig {
m: 32,
ef_construction: 200,
ef_search: 50, // Start with lower ef_search
max_elements: 10000,
};
let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;
let vectors = generate_random_vectors(num_vectors, dimensions, 54321);
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
let entries: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(i, v)| (format!("vec_{}", i), v.clone()))
.collect();
index.add_batch(entries)?;
// Test different ef_search values
let ef_values = vec![50, 100, 200, 500];
for ef in ef_values {
let mut total_recall = 0.0;
let num_queries = 10;
for i in 0..num_queries {
let query_idx = i * 50;
let query = &normalized_vectors[query_idx];
let results = index.search_with_ef(query, k, ef)?;
let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();
let vectors_with_ids: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
.collect();
let ground_truth =
brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
let recall = calculate_recall(&ground_truth, &result_ids);
total_recall += recall;
}
let avg_recall = total_recall / num_queries as f32;
println!(
"ef_search={} - Average recall@{}: {:.2}%",
ef,
k,
avg_recall * 100.0
);
}
// Verify that ef_search=200 achieves at least 95% recall
let mut total_recall = 0.0;
let num_queries = 10;
for i in 0..num_queries {
let query_idx = i * 50;
let query = &normalized_vectors[query_idx];
let results = index.search_with_ef(query, k, 200)?;
let result_ids: Vec<_> = results.iter().map(|r| r.id.clone()).collect();
let vectors_with_ids: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(idx, v)| (format!("vec_{}", idx), v.clone()))
.collect();
let ground_truth = brute_force_search(query, &vectors_with_ids, k, DistanceMetric::Cosine);
let recall = calculate_recall(&ground_truth, &result_ids);
total_recall += recall;
}
let avg_recall = total_recall / num_queries as f32;
assert!(
avg_recall >= 0.95,
"ef_search=200 should achieve at least 95% recall"
);
Ok(())
}
#[test]
fn test_hnsw_serialization_large() -> Result<()> {
let dimensions = 128;
let num_vectors = 500;
let config = HnswConfig {
m: 32,
ef_construction: 200,
ef_search: 100,
max_elements: 10000,
};
let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;
let vectors = generate_random_vectors(num_vectors, dimensions, 11111);
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
let entries: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(i, v)| (format!("vec_{}", i), v.clone()))
.collect();
index.add_batch(entries)?;
// Serialize
println!("Serializing index with {} vectors...", num_vectors);
let bytes = index.serialize()?;
println!(
"Serialized size: {} bytes ({:.2} KB)",
bytes.len(),
bytes.len() as f32 / 1024.0
);
// Deserialize
println!("Deserializing index...");
let restored_index = HnswIndex::deserialize(&bytes)?;
assert_eq!(restored_index.len(), num_vectors);
// Test that search works on restored index
let query = &normalized_vectors[0];
let original_results = index.search(query, 10)?;
let restored_results = restored_index.search(query, 10)?;
// Results should be identical
assert_eq!(original_results.len(), restored_results.len());
println!("Serialization test passed!");
Ok(())
}
#[test]
fn test_hnsw_different_metrics() -> Result<()> {
let dimensions = 128;
let num_vectors = 200;
let k = 5;
// Note: DotProduct can produce negative distances on normalized vectors,
// which causes issues with the underlying hnsw_rs library.
// We test Cosine and Euclidean which are the most commonly used metrics.
let metrics = vec![DistanceMetric::Cosine, DistanceMetric::Euclidean];
for metric in metrics {
println!("Testing metric: {:?}", metric);
let config = HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 100,
max_elements: 1000,
};
let mut index = HnswIndex::new(dimensions, metric, config)?;
let vectors = generate_random_vectors(num_vectors, dimensions, 99999);
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
for (i, vector) in normalized_vectors.iter().enumerate() {
index.add(format!("vec_{}", i), vector.clone())?;
}
// Test search
let query = &normalized_vectors[0];
let results = index.search(query, k)?;
assert!(!results.is_empty());
println!(" Found {} results for metric {:?}", results.len(), metric);
}
Ok(())
}
#[test]
fn test_hnsw_parallel_batch_insert() -> Result<()> {
let dimensions = 128;
let num_vectors = 2000;
let config = HnswConfig {
m: 32,
ef_construction: 200,
ef_search: 100,
max_elements: 10000,
};
let mut index = HnswIndex::new(dimensions, DistanceMetric::Cosine, config)?;
let vectors = generate_random_vectors(num_vectors, dimensions, 77777);
let normalized_vectors: Vec<_> = vectors.iter().map(|v| normalize_vector(v)).collect();
let entries: Vec<_> = normalized_vectors
.iter()
.enumerate()
.map(|(i, v)| (format!("vec_{}", i), v.clone()))
.collect();
// Time the batch insert
let start = std::time::Instant::now();
index.add_batch(entries)?;
let duration = start.elapsed();
println!("Batch inserted {} vectors in {:?}", num_vectors, duration);
println!(
"Throughput: {:.0} vectors/sec",
num_vectors as f64 / duration.as_secs_f64()
);
assert_eq!(index.len(), num_vectors);
// Verify search still works
let query = &normalized_vectors[0];
let results = index.search(query, 10)?;
assert!(!results.is_empty());
Ok(())
}

View File

@@ -0,0 +1,453 @@
//! Integration tests for end-to-end workflows
//!
//! These tests verify that all components work together correctly.
use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
use std::collections::HashMap;
use tempfile::tempdir;
// ============================================================================
// End-to-End Workflow Tests
// ============================================================================
#[test]
fn test_complete_insert_search_workflow() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 128;
options.distance_metric = DistanceMetric::Cosine;
options.hnsw_config = Some(HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 50,
max_elements: 100_000,
});
let db = VectorDB::new(options).unwrap();
// Insert training data
let vectors: Vec<VectorEntry> = (0..100)
.map(|i| {
let mut metadata = HashMap::new();
metadata.insert("index".to_string(), serde_json::json!(i));
VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..128).map(|j| ((i + j) as f32) * 0.01).collect(),
metadata: Some(metadata),
}
})
.collect();
let ids = db.insert_batch(vectors).unwrap();
assert_eq!(ids.len(), 100);
// Search for similar vectors
let query: Vec<f32> = (0..128).map(|j| (j as f32) * 0.01).collect();
let results = db
.search(SearchQuery {
vector: query,
k: 10,
filter: None,
ef_search: Some(100),
})
.unwrap();
assert_eq!(results.len(), 10);
assert!(results[0].vector.is_some());
assert!(results[0].metadata.is_some());
}
#[test]
fn test_batch_operations_10k_vectors() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 384;
options.distance_metric = DistanceMetric::Euclidean;
options.hnsw_config = Some(HnswConfig::default());
let db = VectorDB::new(options).unwrap();
// Generate 10K vectors
println!("Generating 10K vectors...");
let vectors: Vec<VectorEntry> = (0..10_000)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..384).map(|j| ((i + j) as f32) * 0.001).collect(),
metadata: None,
})
.collect();
// Batch insert
println!("Batch inserting 10K vectors...");
let start = std::time::Instant::now();
let ids = db.insert_batch(vectors).unwrap();
let duration = start.elapsed();
println!("Batch insert took: {:?}", duration);
assert_eq!(ids.len(), 10_000);
assert_eq!(db.len().unwrap(), 10_000);
// Perform multiple searches
println!("Performing searches...");
for i in 0..10 {
let query: Vec<f32> = (0..384).map(|j| ((i * 100 + j) as f32) * 0.001).collect();
let results = db
.search(SearchQuery {
vector: query,
k: 10,
filter: None,
ef_search: None,
})
.unwrap();
assert_eq!(results.len(), 10);
}
}
#[test]
fn test_persistence_and_reload() {
let dir = tempdir().unwrap();
let db_path = dir
.path()
.join("persistent.db")
.to_string_lossy()
.to_string();
// Create and populate database
{
let mut options = DbOptions::default();
options.storage_path = db_path.clone();
options.dimensions = 3;
options.hnsw_config = None; // Use flat index for simpler persistence test
let db = VectorDB::new(options).unwrap();
for i in 0..10 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32, (i * 2) as f32, (i * 3) as f32],
metadata: None,
})
.unwrap();
}
assert_eq!(db.len().unwrap(), 10);
}
// Reload database
{
let mut options = DbOptions::default();
options.storage_path = db_path.clone();
options.dimensions = 3;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
// Verify data persisted
assert_eq!(db.len().unwrap(), 10);
let entry = db.get("vec_5").unwrap().unwrap();
assert_eq!(entry.vector, vec![5.0, 10.0, 15.0]);
}
}
#[test]
fn test_mixed_operations_workflow() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 64;
let db = VectorDB::new(options).unwrap();
// Insert initial batch
let initial: Vec<VectorEntry> = (0..50)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..64).map(|j| ((i + j) as f32) * 0.1).collect(),
metadata: None,
})
.collect();
db.insert_batch(initial).unwrap();
assert_eq!(db.len().unwrap(), 50);
// Delete some vectors
for i in 0..10 {
db.delete(&format!("vec_{}", i)).unwrap();
}
assert_eq!(db.len().unwrap(), 40);
// Insert more individual vectors
for i in 50..60 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..64).map(|j| ((i + j) as f32) * 0.1).collect(),
metadata: None,
})
.unwrap();
}
assert_eq!(db.len().unwrap(), 50);
// Search
let query: Vec<f32> = (0..64).map(|j| (j as f32) * 0.1).collect();
let results = db
.search(SearchQuery {
vector: query,
k: 20,
filter: None,
ef_search: None,
})
.unwrap();
assert!(results.len() > 0);
}
// ============================================================================
// Different Distance Metrics
// ============================================================================
#[test]
fn test_all_distance_metrics() {
let metrics = vec![
DistanceMetric::Euclidean,
DistanceMetric::Cosine,
DistanceMetric::DotProduct,
DistanceMetric::Manhattan,
];
for metric in metrics {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 32;
options.distance_metric = metric;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
// Insert test vectors
for i in 0..20 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
metadata: None,
})
.unwrap();
}
// Search
let query: Vec<f32> = (0..32).map(|j| (j as f32) * 0.1).collect();
let results = db
.search(SearchQuery {
vector: query,
k: 5,
filter: None,
ef_search: None,
})
.unwrap();
assert_eq!(results.len(), 5, "Failed for metric {:?}", metric);
// Verify scores are in ascending order (lower is better for distance)
for i in 0..results.len() - 1 {
assert!(
results[i].score <= results[i + 1].score,
"Results not sorted for metric {:?}: {} > {}",
metric,
results[i].score,
results[i + 1].score
);
}
}
}
// ============================================================================
// HNSW Configuration Tests
// ============================================================================
#[test]
fn test_hnsw_different_configurations() {
let configs = vec![
HnswConfig {
m: 8,
ef_construction: 50,
ef_search: 50,
max_elements: 1000,
},
HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 100,
max_elements: 1000,
},
HnswConfig {
m: 32,
ef_construction: 200,
ef_search: 200,
max_elements: 1000,
},
];
for config in configs {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 64;
options.hnsw_config = Some(config.clone());
let db = VectorDB::new(options).unwrap();
// Insert vectors
let vectors: Vec<VectorEntry> = (0..100)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
metadata: None,
})
.collect();
db.insert_batch(vectors).unwrap();
// Search with different ef_search values
let query: Vec<f32> = (0..64).map(|j| (j as f32) * 0.01).collect();
let results = db
.search(SearchQuery {
vector: query,
k: 10,
filter: None,
ef_search: Some(config.ef_search),
})
.unwrap();
assert_eq!(results.len(), 10, "Failed for config M={}", config.m);
}
}
// ============================================================================
// Metadata Filtering Tests
// ============================================================================
#[test]
fn test_complex_metadata_filtering() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 16;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
// Insert vectors with different categories and values
for i in 0..50 {
let mut metadata = HashMap::new();
metadata.insert("category".to_string(), serde_json::json!(i % 3));
metadata.insert("value".to_string(), serde_json::json!(i / 10));
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..16).map(|j| ((i + j) as f32) * 0.1).collect(),
metadata: Some(metadata),
})
.unwrap();
}
// Search with single filter
let mut filter1 = HashMap::new();
filter1.insert("category".to_string(), serde_json::json!(0));
let query: Vec<f32> = (0..16).map(|j| (j as f32) * 0.1).collect();
let results1 = db
.search(SearchQuery {
vector: query.clone(),
k: 100,
filter: Some(filter1),
ef_search: None,
})
.unwrap();
// Should only get vectors where i % 3 == 0
for result in &results1 {
let meta = result.metadata.as_ref().unwrap();
assert_eq!(meta.get("category").unwrap(), &serde_json::json!(0));
}
// Search with different filter
let mut filter2 = HashMap::new();
filter2.insert("value".to_string(), serde_json::json!(2));
let results2 = db
.search(SearchQuery {
vector: query,
k: 100,
filter: Some(filter2),
ef_search: None,
})
.unwrap();
// Should only get vectors where i / 10 == 2 (i.e., i in 20..30)
for result in &results2 {
let meta = result.metadata.as_ref().unwrap();
assert_eq!(meta.get("value").unwrap(), &serde_json::json!(2));
}
}
// ============================================================================
// Error Handling Tests
// ============================================================================
#[test]
fn test_dimension_validation() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 64;
let db = VectorDB::new(options).unwrap();
// Try to insert vector with wrong dimensions
let result = db.insert(VectorEntry {
id: None,
vector: vec![1.0, 2.0, 3.0], // Only 3 dimensions, should be 64
metadata: None,
});
assert!(result.is_err());
}
#[test]
fn test_search_with_wrong_dimension() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 64;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
// Insert some vectors
db.insert(VectorEntry {
id: Some("v1".to_string()),
vector: (0..64).map(|i| i as f32).collect(),
metadata: None,
})
.unwrap();
// Try to search with wrong dimension query
// Note: This might not error in the current implementation, but should be validated
let query = vec![1.0, 2.0, 3.0]; // Wrong dimension
let result = db.search(SearchQuery {
vector: query,
k: 10,
filter: None,
ef_search: None,
});
// Depending on implementation, this might error or return empty results
// The important thing is it doesn't panic
let _ = result;
}

View File

@@ -0,0 +1,345 @@
//! Property-based tests using proptest
//!
//! These tests verify mathematical properties and invariants that should hold
//! for all inputs within a given domain.
use proptest::prelude::*;
use ruvector_core::distance::*;
use ruvector_core::quantization::*;
use ruvector_core::types::DistanceMetric;
// ============================================================================
// Distance Metric Properties
// ============================================================================
// Strategy to generate valid vectors with bounded values to prevent overflow
// Using range that won't overflow when squared: sqrt(f32::MAX) ≈ 1.84e19
// We use a more conservative range for numerical stability in distance calculations
fn vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
prop::collection::vec(-1000.0f32..1000.0f32, dim)
}
// Strategy for normalized vectors (for cosine similarity)
fn normalized_vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
vector_strategy(dim).prop_map(move |v| {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
v.iter().map(|x| x / norm).collect()
} else {
vec![1.0 / (dim as f32).sqrt(); dim]
}
})
}
proptest! {
// Property: Distance to self is zero
#[test]
fn test_euclidean_self_distance_zero(v in vector_strategy(128)) {
let dist = euclidean_distance(&v, &v);
prop_assert!(dist < 0.001, "Distance to self should be ~0, got {}", dist);
}
// Property: Euclidean distance is symmetric
#[test]
fn test_euclidean_symmetry(
a in vector_strategy(64),
b in vector_strategy(64)
) {
let dist_ab = euclidean_distance(&a, &b);
let dist_ba = euclidean_distance(&b, &a);
prop_assert!((dist_ab - dist_ba).abs() < 0.001, "Distance should be symmetric");
}
// Property: Triangle inequality for Euclidean distance
#[test]
fn test_euclidean_triangle_inequality(
a in vector_strategy(32),
b in vector_strategy(32),
c in vector_strategy(32)
) {
let dist_ab = euclidean_distance(&a, &b);
let dist_bc = euclidean_distance(&b, &c);
let dist_ac = euclidean_distance(&a, &c);
// d(a,c) <= d(a,b) + d(b,c)
prop_assert!(
dist_ac <= dist_ab + dist_bc + 0.01, // Small epsilon for floating point
"Triangle inequality violated: {} > {} + {}",
dist_ac, dist_ab, dist_bc
);
}
// Property: Non-negativity of Euclidean distance
#[test]
fn test_euclidean_non_negative(
a in vector_strategy(64),
b in vector_strategy(64)
) {
let dist = euclidean_distance(&a, &b);
prop_assert!(dist >= 0.0, "Distance must be non-negative, got {}", dist);
}
// Property: Cosine distance symmetry
#[test]
fn test_cosine_symmetry(
a in normalized_vector_strategy(64),
b in normalized_vector_strategy(64)
) {
let dist_ab = cosine_distance(&a, &b);
let dist_ba = cosine_distance(&b, &a);
prop_assert!((dist_ab - dist_ba).abs() < 0.01, "Cosine distance should be symmetric");
}
// Property: Cosine distance to self is zero
#[test]
fn test_cosine_self_distance(v in normalized_vector_strategy(64)) {
let dist = cosine_distance(&v, &v);
prop_assert!(dist < 0.01, "Cosine distance to self should be ~0, got {}", dist);
}
// Property: Manhattan distance symmetry
#[test]
fn test_manhattan_symmetry(
a in vector_strategy(64),
b in vector_strategy(64)
) {
let dist_ab = manhattan_distance(&a, &b);
let dist_ba = manhattan_distance(&b, &a);
prop_assert!((dist_ab - dist_ba).abs() < 0.001);
}
// Property: Manhattan distance non-negativity
#[test]
fn test_manhattan_non_negative(
a in vector_strategy(64),
b in vector_strategy(64)
) {
let dist = manhattan_distance(&a, &b);
prop_assert!(dist >= 0.0, "Manhattan distance must be non-negative");
}
// Property: Dot product symmetry
#[test]
fn test_dot_product_symmetry(
a in vector_strategy(64),
b in vector_strategy(64)
) {
let dist_ab = dot_product_distance(&a, &b);
let dist_ba = dot_product_distance(&b, &a);
prop_assert!((dist_ab - dist_ba).abs() < 0.01);
}
}
// ============================================================================
// Quantization Round-Trip Properties
// ============================================================================
proptest! {
// Property: Scalar quantization round-trip preserves approximate values
#[test]
fn test_scalar_quantization_roundtrip(
v in prop::collection::vec(0.0f32..100.0f32, 64)
) {
let quantized = ScalarQuantized::quantize(&v);
let reconstructed = quantized.reconstruct();
prop_assert_eq!(v.len(), reconstructed.len());
// Check reconstruction error is bounded
for (orig, recon) in v.iter().zip(reconstructed.iter()) {
let error = (orig - recon).abs();
let relative_error = if *orig != 0.0 {
error / orig.abs()
} else {
error
};
prop_assert!(
relative_error < 0.5 || error < 1.0,
"Reconstruction error too large: {} vs {}",
orig, recon
);
}
}
// Property: Binary quantization preserves signs
#[test]
fn test_binary_quantization_sign_preservation(
v in prop::collection::vec(-10.0f32..10.0f32, 64)
.prop_filter("No zeros", |v| v.iter().all(|x| *x != 0.0))
) {
let quantized = BinaryQuantized::quantize(&v);
let reconstructed = quantized.reconstruct();
for (orig, recon) in v.iter().zip(reconstructed.iter()) {
prop_assert_eq!(
orig.signum(),
*recon,
"Sign not preserved for {} -> {}",
orig, recon
);
}
}
// Property: Binary quantization distance to self is zero
#[test]
fn test_binary_quantization_self_distance(
v in prop::collection::vec(-10.0f32..10.0f32, 64)
) {
let quantized = BinaryQuantized::quantize(&v);
let dist = quantized.distance(&quantized);
prop_assert_eq!(dist, 0.0, "Distance to self should be 0");
}
// Property: Binary quantization distance is symmetric
#[test]
fn test_binary_quantization_symmetry(
a in prop::collection::vec(-10.0f32..10.0f32, 64),
b in prop::collection::vec(-10.0f32..10.0f32, 64)
) {
let qa = BinaryQuantized::quantize(&a);
let qb = BinaryQuantized::quantize(&b);
let dist_ab = qa.distance(&qb);
let dist_ba = qb.distance(&qa);
prop_assert_eq!(dist_ab, dist_ba, "Distance should be symmetric");
}
// Property: Binary quantization distance is bounded
#[test]
fn test_binary_quantization_distance_bounded(
a in prop::collection::vec(-10.0f32..10.0f32, 64),
b in prop::collection::vec(-10.0f32..10.0f32, 64)
) {
let qa = BinaryQuantized::quantize(&a);
let qb = BinaryQuantized::quantize(&b);
let dist = qa.distance(&qb);
// Hamming distance for 64 bits should be in [0, 64]
prop_assert!(dist >= 0.0 && dist <= 64.0, "Distance {} out of bounds", dist);
}
// Property: Scalar quantization distance is non-negative
#[test]
fn test_scalar_quantization_distance_non_negative(
a in prop::collection::vec(0.0f32..100.0f32, 64),
b in prop::collection::vec(0.0f32..100.0f32, 64)
) {
let qa = ScalarQuantized::quantize(&a);
let qb = ScalarQuantized::quantize(&b);
let dist = qa.distance(&qb);
prop_assert!(dist >= 0.0, "Distance must be non-negative");
}
}
// ============================================================================
// Vector Operation Properties
// ============================================================================
proptest! {
// Property: Scaling preserves direction for cosine similarity
#[test]
fn test_cosine_scale_invariance(
v in normalized_vector_strategy(32),
scale in 0.1f32..10.0f32
) {
let scaled: Vec<f32> = v.iter().map(|x| x * scale).collect();
let dist_original = cosine_distance(&v, &v);
let dist_scaled = cosine_distance(&v, &scaled);
// Cosine distance should be approximately the same (scale invariant)
prop_assert!(
(dist_original - dist_scaled).abs() < 0.1,
"Cosine distance should be scale invariant: {} vs {}",
dist_original, dist_scaled
);
}
// Property: Adding the same vector to both preserves distance
#[test]
fn test_euclidean_translation_invariance(
a in vector_strategy(32),
b in vector_strategy(32),
offset in vector_strategy(32)
) {
let a_offset: Vec<f32> = a.iter().zip(&offset).map(|(x, o)| x + o).collect();
let b_offset: Vec<f32> = b.iter().zip(&offset).map(|(x, o)| x + o).collect();
let dist_original = euclidean_distance(&a, &b);
let dist_offset = euclidean_distance(&a_offset, &b_offset);
prop_assert!(
(dist_original - dist_offset).abs() < 0.01,
"Euclidean distance should be translation invariant"
);
}
}
// ============================================================================
// Batch Operations Properties
// ============================================================================
proptest! {
// Property: Batch distance calculation consistency
#[test]
fn test_batch_distances_consistency(
query in vector_strategy(32),
vectors in prop::collection::vec(vector_strategy(32), 10..20)
) {
// Calculate distances in batch
let batch_dists = batch_distances(&query, &vectors, DistanceMetric::Euclidean).unwrap();
// Calculate distances individually
let individual_dists: Vec<f32> = vectors.iter()
.map(|v| euclidean_distance(&query, v))
.collect();
prop_assert_eq!(batch_dists.len(), individual_dists.len());
for (batch, individual) in batch_dists.iter().zip(individual_dists.iter()) {
prop_assert!(
(batch - individual).abs() < 0.01,
"Batch and individual distances should match: {} vs {}",
batch, individual
);
}
}
}
// ============================================================================
// Dimension Handling Properties
// ============================================================================
proptest! {
// Property: Distance calculation fails on dimension mismatch
#[test]
fn test_dimension_mismatch_error(
dim1 in 1usize..100,
dim2 in 1usize..100
) {
prop_assume!(dim1 != dim2); // Only test when dimensions differ
let a = vec![1.0f32; dim1];
let b = vec![1.0f32; dim2];
let result = distance(&a, &b, DistanceMetric::Euclidean);
prop_assert!(result.is_err(), "Should error on dimension mismatch");
}
// Property: Distance calculation succeeds on matching dimensions
#[test]
fn test_dimension_match_success(
dim in 1usize..200,
a in prop::collection::vec(any::<f32>().prop_filter("Must be finite", |x| x.is_finite()), 1..200),
b in prop::collection::vec(any::<f32>().prop_filter("Must be finite", |x| x.is_finite()), 1..200)
) {
// Ensure same dimensions
let a_resized = vec![1.0f32; dim];
let b_resized = vec![1.0f32; dim];
let result = distance(&a_resized, &b_resized, DistanceMetric::Euclidean);
prop_assert!(result.is_ok(), "Should succeed on matching dimensions");
}
}

View File

@@ -0,0 +1,486 @@
//! Stress tests for scalability, concurrency, and resilience
//!
//! These tests push the system to its limits to verify robustness.
use ruvector_core::types::{DbOptions, HnswConfig, SearchQuery};
use ruvector_core::{VectorDB, VectorEntry};
use std::sync::{Arc, Barrier};
use std::thread;
use tempfile::tempdir;
// ============================================================================
// Large-Scale Insertion Tests
// ============================================================================
#[test]
#[ignore] // Run with: cargo test --test stress_tests -- --ignored --test-threads=1
fn test_million_vector_insertion() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("million.db").to_string_lossy().to_string();
options.dimensions = 128;
options.hnsw_config = Some(HnswConfig {
m: 16,
ef_construction: 100,
ef_search: 50,
max_elements: 2_000_000,
});
let db = VectorDB::new(options).unwrap();
println!("Starting million-vector insertion test...");
let batch_size = 10_000;
let num_batches = 100; // Total: 1M vectors
for batch_idx in 0..num_batches {
println!("Inserting batch {}/{}...", batch_idx + 1, num_batches);
let vectors: Vec<VectorEntry> = (0..batch_size)
.map(|i| {
let global_idx = batch_idx * batch_size + i;
VectorEntry {
id: Some(format!("vec_{}", global_idx)),
vector: (0..128)
.map(|j| ((global_idx + j) as f32) * 0.0001)
.collect(),
metadata: None,
}
})
.collect();
let start = std::time::Instant::now();
db.insert_batch(vectors).unwrap();
let duration = start.elapsed();
println!("Batch {} took: {:?}", batch_idx + 1, duration);
}
println!("Final database size: {}", db.len().unwrap());
assert_eq!(db.len().unwrap(), 1_000_000);
// Perform some searches to verify functionality
println!("Testing search on 1M vectors...");
for i in 0..10 {
let query: Vec<f32> = (0..128)
.map(|j| ((i * 10000 + j) as f32) * 0.0001)
.collect();
let start = std::time::Instant::now();
let results = db
.search(SearchQuery {
vector: query,
k: 10,
filter: None,
ef_search: Some(50),
})
.unwrap();
let duration = start.elapsed();
println!(
"Search {} took: {:?}, found {} results",
i + 1,
duration,
results.len()
);
assert_eq!(results.len(), 10);
}
}
// ============================================================================
// Concurrent Query Tests
// ============================================================================
#[test]
fn test_concurrent_queries() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("concurrent.db")
.to_string_lossy()
.to_string();
options.dimensions = 64;
options.hnsw_config = Some(HnswConfig::default());
let db = Arc::new(VectorDB::new(options).unwrap());
// Insert test data
println!("Inserting test data...");
let vectors: Vec<VectorEntry> = (0..1000)
.map(|i| VectorEntry {
id: Some(format!("vec_{}", i)),
vector: (0..64).map(|j| ((i + j) as f32) * 0.01).collect(),
metadata: None,
})
.collect();
db.insert_batch(vectors).unwrap();
// Spawn multiple threads doing concurrent searches
println!("Starting 10 concurrent query threads...");
let num_threads = 10;
let queries_per_thread = 100;
let barrier = Arc::new(Barrier::new(num_threads));
let mut handles = vec![];
for thread_id in 0..num_threads {
let db_clone = Arc::clone(&db);
let barrier_clone = Arc::clone(&barrier);
let handle = thread::spawn(move || {
// Wait for all threads to be ready
barrier_clone.wait();
let start = std::time::Instant::now();
for i in 0..queries_per_thread {
let query: Vec<f32> = (0..64)
.map(|j| ((thread_id * 1000 + i + j) as f32) * 0.01)
.collect();
let results = db_clone
.search(SearchQuery {
vector: query,
k: 10,
filter: None,
ef_search: None,
})
.unwrap();
assert_eq!(results.len(), 10);
}
let duration = start.elapsed();
println!(
"Thread {} completed {} queries in {:?}",
thread_id, queries_per_thread, duration
);
duration
});
handles.push(handle);
}
// Wait for all threads and collect results
let mut total_duration = std::time::Duration::ZERO;
for handle in handles {
let duration = handle.join().unwrap();
total_duration += duration;
}
let total_queries = num_threads * queries_per_thread;
println!("Total queries: {}", total_queries);
println!(
"Average duration per thread: {:?}",
total_duration / num_threads as u32
);
}
#[test]
fn test_concurrent_inserts_and_queries() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("mixed_concurrent.db")
.to_string_lossy()
.to_string();
options.dimensions = 32;
options.hnsw_config = Some(HnswConfig::default());
let db = Arc::new(VectorDB::new(options).unwrap());
// Initial data
let initial: Vec<VectorEntry> = (0..100)
.map(|i| VectorEntry {
id: Some(format!("initial_{}", i)),
vector: (0..32).map(|j| ((i + j) as f32) * 0.1).collect(),
metadata: None,
})
.collect();
db.insert_batch(initial).unwrap();
// Spawn reader threads
let num_readers = 5;
let num_writers = 2;
let barrier = Arc::new(Barrier::new(num_readers + num_writers));
let mut handles = vec![];
// Reader threads
for reader_id in 0..num_readers {
let db_clone = Arc::clone(&db);
let barrier_clone = Arc::clone(&barrier);
let handle = thread::spawn(move || {
barrier_clone.wait();
for i in 0..50 {
let query: Vec<f32> = (0..32)
.map(|j| ((reader_id * 100 + i + j) as f32) * 0.1)
.collect();
let results = db_clone
.search(SearchQuery {
vector: query,
k: 5,
filter: None,
ef_search: None,
})
.unwrap();
assert!(results.len() > 0 && results.len() <= 5);
}
println!("Reader {} completed", reader_id);
});
handles.push(handle);
}
// Writer threads
for writer_id in 0..num_writers {
let db_clone = Arc::clone(&db);
let barrier_clone = Arc::clone(&barrier);
let handle = thread::spawn(move || {
barrier_clone.wait();
for i in 0..20 {
let entry = VectorEntry {
id: Some(format!("writer_{}_{}", writer_id, i)),
vector: (0..32)
.map(|j| ((writer_id * 1000 + i + j) as f32) * 0.1)
.collect(),
metadata: None,
};
db_clone.insert(entry).unwrap();
}
println!("Writer {} completed", writer_id);
});
handles.push(handle);
}
// Wait for all threads
for handle in handles {
handle.join().unwrap();
}
// Verify final state
let final_len = db.len().unwrap();
println!("Final database size: {}", final_len);
assert!(final_len >= 100); // At least initial data should remain
}
// ============================================================================
// Memory Pressure Tests
// ============================================================================
#[test]
#[ignore] // Run with: cargo test --test stress_tests -- --ignored
fn test_memory_pressure_large_vectors() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir
.path()
.join("large_vectors.db")
.to_string_lossy()
.to_string();
options.dimensions = 2048; // Very large vectors
options.hnsw_config = Some(HnswConfig {
m: 8,
ef_construction: 50,
ef_search: 50,
max_elements: 100_000,
});
let db = VectorDB::new(options).unwrap();
println!("Testing with large 2048-dimensional vectors...");
let num_vectors = 10_000;
let batch_size = 1000;
for batch_idx in 0..(num_vectors / batch_size) {
let vectors: Vec<VectorEntry> = (0..batch_size)
.map(|i| {
let global_idx = batch_idx * batch_size + i;
VectorEntry {
id: Some(format!("vec_{}", global_idx)),
vector: (0..2048)
.map(|j| ((global_idx + j) as f32) * 0.0001)
.collect(),
metadata: None,
}
})
.collect();
db.insert_batch(vectors).unwrap();
println!(
"Inserted batch {}/{}",
batch_idx + 1,
num_vectors / batch_size
);
}
println!("Database size: {}", db.len().unwrap());
assert_eq!(db.len().unwrap(), num_vectors);
// Perform searches
for i in 0..5 {
let query: Vec<f32> = (0..2048)
.map(|j| ((i * 1000 + j) as f32) * 0.0001)
.collect();
let results = db
.search(SearchQuery {
vector: query,
k: 10,
filter: None,
ef_search: None,
})
.unwrap();
assert_eq!(results.len(), 10);
}
}
// ============================================================================
// Error Recovery Tests
// ============================================================================
#[test]
fn test_invalid_operations_dont_crash() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 32;
let db = VectorDB::new(options).unwrap();
// Try various invalid operations
// 1. Delete non-existent vector
let _ = db.delete("nonexistent");
// 2. Get non-existent vector
let _ = db.get("nonexistent");
// 3. Search with k=0
let result = db.search(SearchQuery {
vector: vec![0.0; 32],
k: 0,
filter: None,
ef_search: None,
});
// Should either return empty or error gracefully
let _ = result;
// 4. Insert and immediately delete in rapid succession
for i in 0..100 {
let id = db
.insert(VectorEntry {
id: Some(format!("temp_{}", i)),
vector: vec![1.0; 32],
metadata: None,
})
.unwrap();
db.delete(&id).unwrap();
}
// Database should still be functional
db.insert(VectorEntry {
id: Some("final".to_string()),
vector: vec![1.0; 32],
metadata: None,
})
.unwrap();
assert!(db.get("final").unwrap().is_some());
}
#[test]
fn test_repeated_operations() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 16;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
// Insert the same ID multiple times (should replace or error)
for _ in 0..10 {
let _ = db.insert(VectorEntry {
id: Some("same_id".to_string()),
vector: vec![1.0; 16],
metadata: None,
});
}
// Delete the same ID multiple times
for _ in 0..5 {
let _ = db.delete("same_id");
}
// Search repeatedly with the same query
let query = vec![1.0; 16];
for _ in 0..100 {
let _ = db.search(SearchQuery {
vector: query.clone(),
k: 10,
filter: None,
ef_search: None,
});
}
}
// ============================================================================
// Extreme Parameter Tests
// ============================================================================
#[test]
fn test_extreme_k_values() {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 16;
options.hnsw_config = None;
let db = VectorDB::new(options).unwrap();
// Insert some vectors
for i in 0..10 {
db.insert(VectorEntry {
id: Some(format!("vec_{}", i)),
vector: vec![i as f32; 16],
metadata: None,
})
.unwrap();
}
// Search with k larger than database size
let results = db
.search(SearchQuery {
vector: vec![1.0; 16],
k: 1000,
filter: None,
ef_search: None,
})
.unwrap();
// Should return at most 10 results
assert!(results.len() <= 10);
// Search with k=1
let results = db
.search(SearchQuery {
vector: vec![1.0; 16],
k: 1,
filter: None,
ef_search: None,
})
.unwrap();
assert_eq!(results.len(), 1);
}

View File

@@ -0,0 +1,772 @@
//! Memory Pool and Allocation Tests
//!
//! This module tests the arena allocator and cache-optimized storage
//! for correct memory management, eviction, and performance characteristics.
use ruvector_core::arena::{Arena, ArenaVec};
use ruvector_core::cache_optimized::SoAVectorStorage;
use std::sync::{Arc, Barrier};
use std::thread;
// ============================================================================
// Arena Allocator Tests
// ============================================================================
mod arena_tests {
use super::*;
#[test]
fn test_arena_basic_allocation() {
let arena = Arena::new(1024);
let mut vec: ArenaVec<f32> = arena.alloc_vec(10);
assert_eq!(vec.capacity(), 10);
assert_eq!(vec.len(), 0);
assert!(vec.is_empty());
vec.push(1.0);
vec.push(2.0);
vec.push(3.0);
assert_eq!(vec.len(), 3);
assert!(!vec.is_empty());
assert_eq!(vec[0], 1.0);
assert_eq!(vec[1], 2.0);
assert_eq!(vec[2], 3.0);
}
#[test]
fn test_arena_multiple_allocations() {
let arena = Arena::new(4096);
let vec1: ArenaVec<f32> = arena.alloc_vec(100);
let vec2: ArenaVec<f64> = arena.alloc_vec(50);
let vec3: ArenaVec<u32> = arena.alloc_vec(200);
let vec4: ArenaVec<i64> = arena.alloc_vec(75);
assert_eq!(vec1.capacity(), 100);
assert_eq!(vec2.capacity(), 50);
assert_eq!(vec3.capacity(), 200);
assert_eq!(vec4.capacity(), 75);
}
#[test]
fn test_arena_different_types() {
let arena = Arena::new(2048);
// Allocate different types
let mut floats: ArenaVec<f32> = arena.alloc_vec(10);
let mut doubles: ArenaVec<f64> = arena.alloc_vec(10);
let mut ints: ArenaVec<i32> = arena.alloc_vec(10);
let mut bytes: ArenaVec<u8> = arena.alloc_vec(10);
// Push values
for i in 0..10 {
floats.push(i as f32);
doubles.push(i as f64);
ints.push(i);
bytes.push(i as u8);
}
// Verify
for i in 0..10 {
assert_eq!(floats[i], i as f32);
assert_eq!(doubles[i], i as f64);
assert_eq!(ints[i], i as i32);
assert_eq!(bytes[i], i as u8);
}
}
#[test]
fn test_arena_reset() {
let arena = Arena::new(4096);
// First allocation cycle
{
let mut vec1: ArenaVec<f32> = arena.alloc_vec(100);
let mut vec2: ArenaVec<f32> = arena.alloc_vec(100);
for i in 0..50 {
vec1.push(i as f32);
vec2.push(i as f32 * 2.0);
}
}
let used_before = arena.used_bytes();
assert!(used_before > 0, "Should have used some bytes");
arena.reset();
let used_after = arena.used_bytes();
assert_eq!(used_after, 0, "Reset should set used bytes to 0");
// Allocated bytes should remain (memory is reused, not freed)
let allocated = arena.allocated_bytes();
assert!(allocated > 0, "Allocated bytes should remain after reset");
// Second allocation cycle - should reuse memory
let mut vec3: ArenaVec<f32> = arena.alloc_vec(50);
for i in 0..50 {
vec3.push(i as f32);
}
// Memory was reused
assert!(
arena.allocated_bytes() == allocated,
"Should reuse existing allocation"
);
}
#[test]
fn test_arena_chunk_growth() {
// Small initial chunk size to force growth
let arena = Arena::new(64);
// Allocate more than fits in one chunk
let vec1: ArenaVec<f32> = arena.alloc_vec(100);
let vec2: ArenaVec<f32> = arena.alloc_vec(100);
let vec3: ArenaVec<f32> = arena.alloc_vec(100);
assert_eq!(vec1.capacity(), 100);
assert_eq!(vec2.capacity(), 100);
assert_eq!(vec3.capacity(), 100);
// Should have allocated multiple chunks
let allocated = arena.allocated_bytes();
assert!(allocated > 64 * 3, "Should have grown beyond initial chunk");
}
#[test]
fn test_arena_as_slice() {
let arena = Arena::new(1024);
let mut vec: ArenaVec<f32> = arena.alloc_vec(10);
for i in 0..5 {
vec.push((i * 10) as f32);
}
let slice = vec.as_slice();
assert_eq!(slice, &[0.0, 10.0, 20.0, 30.0, 40.0]);
}
#[test]
fn test_arena_as_mut_slice() {
let arena = Arena::new(1024);
let mut vec: ArenaVec<f32> = arena.alloc_vec(10);
for i in 0..5 {
vec.push((i * 10) as f32);
}
{
let slice = vec.as_mut_slice();
slice[0] = 100.0;
slice[4] = 500.0;
}
assert_eq!(vec[0], 100.0);
assert_eq!(vec[4], 500.0);
}
#[test]
fn test_arena_deref() {
let arena = Arena::new(1024);
let mut vec: ArenaVec<f32> = arena.alloc_vec(10);
vec.push(1.0);
vec.push(2.0);
vec.push(3.0);
// Test Deref trait (can use slice methods)
assert_eq!(vec.len(), 3);
assert_eq!(vec.iter().sum::<f32>(), 6.0);
}
#[test]
fn test_arena_large_allocation() {
let arena = Arena::new(1024);
// Allocate something larger than the chunk size
let large_vec: ArenaVec<f32> = arena.alloc_vec(10000);
assert_eq!(large_vec.capacity(), 10000);
// Should have grown to accommodate
assert!(arena.allocated_bytes() >= 10000 * std::mem::size_of::<f32>());
}
#[test]
fn test_arena_statistics() {
let arena = Arena::new(1024);
let initial_allocated = arena.allocated_bytes();
let initial_used = arena.used_bytes();
assert_eq!(initial_allocated, 0);
assert_eq!(initial_used, 0);
let _vec: ArenaVec<f32> = arena.alloc_vec(100);
assert!(arena.allocated_bytes() > 0);
assert!(arena.used_bytes() > 0);
}
#[test]
#[should_panic(expected = "ArenaVec capacity exceeded")]
fn test_arena_capacity_exceeded() {
let arena = Arena::new(1024);
let mut vec: ArenaVec<f32> = arena.alloc_vec(5);
// Push more than capacity
for i in 0..10 {
vec.push(i as f32);
}
}
#[test]
fn test_arena_with_default_chunk_size() {
let arena = Arena::with_default_chunk_size();
// Default is 1MB
let _vec: ArenaVec<f32> = arena.alloc_vec(1000);
assert!(arena.allocated_bytes() >= 1024 * 1024);
}
}
// ============================================================================
// Cache-Optimized Storage (SoA) Tests
// ============================================================================
mod soa_tests {
use super::*;
#[test]
fn test_soa_basic_operations() {
let mut storage = SoAVectorStorage::new(3, 4);
assert_eq!(storage.len(), 0);
assert!(storage.is_empty());
assert_eq!(storage.dimensions(), 3);
storage.push(&[1.0, 2.0, 3.0]);
storage.push(&[4.0, 5.0, 6.0]);
assert_eq!(storage.len(), 2);
assert!(!storage.is_empty());
}
#[test]
fn test_soa_get_vector() {
let mut storage = SoAVectorStorage::new(4, 8);
storage.push(&[1.0, 2.0, 3.0, 4.0]);
storage.push(&[5.0, 6.0, 7.0, 8.0]);
storage.push(&[9.0, 10.0, 11.0, 12.0]);
let mut output = vec![0.0; 4];
storage.get(0, &mut output);
assert_eq!(output, vec![1.0, 2.0, 3.0, 4.0]);
storage.get(1, &mut output);
assert_eq!(output, vec![5.0, 6.0, 7.0, 8.0]);
storage.get(2, &mut output);
assert_eq!(output, vec![9.0, 10.0, 11.0, 12.0]);
}
#[test]
fn test_soa_dimension_slice() {
let mut storage = SoAVectorStorage::new(3, 8);
storage.push(&[1.0, 10.0, 100.0]);
storage.push(&[2.0, 20.0, 200.0]);
storage.push(&[3.0, 30.0, 300.0]);
storage.push(&[4.0, 40.0, 400.0]);
// Dimension 0: all first elements
let dim0 = storage.dimension_slice(0);
assert_eq!(dim0, &[1.0, 2.0, 3.0, 4.0]);
// Dimension 1: all second elements
let dim1 = storage.dimension_slice(1);
assert_eq!(dim1, &[10.0, 20.0, 30.0, 40.0]);
// Dimension 2: all third elements
let dim2 = storage.dimension_slice(2);
assert_eq!(dim2, &[100.0, 200.0, 300.0, 400.0]);
}
#[test]
fn test_soa_dimension_slice_mut() {
let mut storage = SoAVectorStorage::new(3, 8);
storage.push(&[1.0, 2.0, 3.0]);
storage.push(&[4.0, 5.0, 6.0]);
// Modify dimension 0
{
let dim0 = storage.dimension_slice_mut(0);
dim0[0] = 100.0;
dim0[1] = 400.0;
}
let mut output = vec![0.0; 3];
storage.get(0, &mut output);
assert_eq!(output, vec![100.0, 2.0, 3.0]);
storage.get(1, &mut output);
assert_eq!(output, vec![400.0, 5.0, 6.0]);
}
#[test]
fn test_soa_auto_growth() {
// Start with small capacity
let mut storage = SoAVectorStorage::new(4, 2);
// Push more vectors than initial capacity
for i in 0..100 {
storage.push(&[i as f32, (i * 2) as f32, (i * 3) as f32, (i * 4) as f32]);
}
assert_eq!(storage.len(), 100);
// Verify all values are correct
let mut output = vec![0.0; 4];
for i in 0..100 {
storage.get(i, &mut output);
assert_eq!(
output,
vec![i as f32, (i * 2) as f32, (i * 3) as f32, (i * 4) as f32]
);
}
}
#[test]
fn test_soa_batch_euclidean_distances() {
let mut storage = SoAVectorStorage::new(3, 4);
// Add orthogonal unit vectors
storage.push(&[1.0, 0.0, 0.0]);
storage.push(&[0.0, 1.0, 0.0]);
storage.push(&[0.0, 0.0, 1.0]);
let query = vec![1.0, 0.0, 0.0];
let mut distances = vec![0.0; 3];
storage.batch_euclidean_distances(&query, &mut distances);
// Distance to itself should be 0
assert!(distances[0] < 0.001, "Distance to self should be ~0");
// Distance to orthogonal vectors should be sqrt(2)
let sqrt2 = (2.0_f32).sqrt();
assert!(
(distances[1] - sqrt2).abs() < 0.01,
"Expected sqrt(2), got {}",
distances[1]
);
assert!(
(distances[2] - sqrt2).abs() < 0.01,
"Expected sqrt(2), got {}",
distances[2]
);
}
#[test]
fn test_soa_batch_distances_large() {
let dim = 128;
let num_vectors = 1000;
let mut storage = SoAVectorStorage::new(dim, 16);
// Add random-ish vectors
for i in 0..num_vectors {
let vec: Vec<f32> = (0..dim)
.map(|j| ((i * dim + j) % 100) as f32 * 0.01)
.collect();
storage.push(&vec);
}
let query: Vec<f32> = (0..dim).map(|j| (j % 50) as f32 * 0.02).collect();
let mut distances = vec![0.0; num_vectors];
storage.batch_euclidean_distances(&query, &mut distances);
// Verify all distances are non-negative and finite
for (i, &dist) in distances.iter().enumerate() {
assert!(
dist >= 0.0 && dist.is_finite(),
"Distance {} is invalid: {}",
i,
dist
);
}
}
#[test]
fn test_soa_common_embedding_dimensions() {
// Test common embedding dimensions
for dim in [128, 256, 384, 512, 768, 1024, 1536] {
let mut storage = SoAVectorStorage::new(dim, 4);
let vec: Vec<f32> = (0..dim).map(|i| i as f32 * 0.001).collect();
storage.push(&vec);
let mut output = vec![0.0; dim];
storage.get(0, &mut output);
assert_eq!(output, vec);
}
}
#[test]
#[should_panic(expected = "dimensions must be between")]
fn test_soa_zero_dimensions() {
let _ = SoAVectorStorage::new(0, 4);
}
#[test]
#[should_panic]
fn test_soa_wrong_vector_length() {
let mut storage = SoAVectorStorage::new(3, 4);
storage.push(&[1.0, 2.0]); // Wrong dimension
}
#[test]
#[should_panic]
fn test_soa_get_out_of_bounds() {
let storage = SoAVectorStorage::new(3, 4);
let mut output = vec![0.0; 3];
storage.get(0, &mut output); // No vectors added
}
#[test]
#[should_panic]
fn test_soa_dimension_slice_out_of_bounds() {
let mut storage = SoAVectorStorage::new(3, 4);
storage.push(&[1.0, 2.0, 3.0]);
let _ = storage.dimension_slice(5); // Invalid dimension
}
}
// ============================================================================
// Memory Pressure Tests
// ============================================================================
mod memory_pressure_tests {
use super::*;
#[test]
fn test_arena_many_small_allocations() {
let arena = Arena::new(1024 * 1024); // 1MB
// Many small allocations
for _ in 0..10000 {
let _vec: ArenaVec<f32> = arena.alloc_vec(10);
}
// Should handle without issues
assert!(arena.allocated_bytes() > 0);
}
#[test]
fn test_arena_alternating_sizes() {
let arena = Arena::new(4096);
for i in 0..100 {
let size = if i % 2 == 0 { 10 } else { 1000 };
let _vec: ArenaVec<f32> = arena.alloc_vec(size);
}
}
#[test]
fn test_soa_large_capacity() {
let mut storage = SoAVectorStorage::new(128, 10000);
for i in 0..10000 {
let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.0001).collect();
storage.push(&vec);
}
assert_eq!(storage.len(), 10000);
// Verify random access
let mut output = vec![0.0; 128];
storage.get(5000, &mut output);
assert!((output[0] - (5000 * 128) as f32 * 0.0001).abs() < 0.0001);
}
#[test]
fn test_soa_batch_operations_under_pressure() {
let dim = 512;
let num_vectors = 5000;
let mut storage = SoAVectorStorage::new(dim, 128);
for i in 0..num_vectors {
let vec: Vec<f32> = (0..dim).map(|j| ((i + j) % 1000) as f32 * 0.001).collect();
storage.push(&vec);
}
// Perform batch distance calculations
let query: Vec<f32> = (0..dim).map(|j| (j % 500) as f32 * 0.002).collect();
let mut distances = vec![0.0; num_vectors];
storage.batch_euclidean_distances(&query, &mut distances);
// All distances should be valid
for dist in &distances {
assert!(dist.is_finite() && *dist >= 0.0);
}
}
}
// ============================================================================
// Concurrent Access Tests
// ============================================================================
mod concurrent_tests {
use super::*;
#[test]
fn test_soa_concurrent_reads() {
// Create and populate storage
let mut storage = SoAVectorStorage::new(64, 16);
for i in 0..1000 {
let vec: Vec<f32> = (0..64).map(|j| (i * 64 + j) as f32 * 0.01).collect();
storage.push(&vec);
}
let storage = Arc::new(storage);
let num_threads = 8;
let barrier = Arc::new(Barrier::new(num_threads));
let mut handles = vec![];
for thread_id in 0..num_threads {
let storage_clone = Arc::clone(&storage);
let barrier_clone = Arc::clone(&barrier);
let handle = thread::spawn(move || {
barrier_clone.wait();
// Each thread performs many reads
for i in 0..100 {
let idx = (thread_id * 100 + i) % 1000;
// Read dimension slices
let dim_slice = storage_clone.dimension_slice(idx % 64);
assert!(!dim_slice.is_empty());
}
});
handles.push(handle);
}
for handle in handles {
handle.join().expect("Thread panicked");
}
}
#[test]
fn test_soa_concurrent_batch_distances() {
let mut storage = SoAVectorStorage::new(32, 16);
for i in 0..500 {
let vec: Vec<f32> = (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect();
storage.push(&vec);
}
let storage = Arc::new(storage);
let num_threads = 4;
let barrier = Arc::new(Barrier::new(num_threads));
let mut handles = vec![];
for thread_id in 0..num_threads {
let storage_clone = Arc::clone(&storage);
let barrier_clone = Arc::clone(&barrier);
let handle = thread::spawn(move || {
barrier_clone.wait();
for i in 0..50 {
let query: Vec<f32> = (0..32)
.map(|j| ((thread_id * 50 + i) * 32 + j) as f32 * 0.01)
.collect();
let mut distances = vec![0.0; 500];
storage_clone.batch_euclidean_distances(&query, &mut distances);
// Verify results
for dist in &distances {
assert!(dist.is_finite());
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().expect("Thread panicked");
}
}
}
// ============================================================================
// Edge Cases
// ============================================================================
mod edge_cases {
use super::*;
#[test]
fn test_soa_single_vector() {
let mut storage = SoAVectorStorage::new(3, 1);
storage.push(&[1.0, 2.0, 3.0]);
assert_eq!(storage.len(), 1);
let mut output = vec![0.0; 3];
storage.get(0, &mut output);
assert_eq!(output, vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_soa_single_dimension() {
let mut storage = SoAVectorStorage::new(1, 4);
storage.push(&[1.0]);
storage.push(&[2.0]);
storage.push(&[3.0]);
let dim0 = storage.dimension_slice(0);
assert_eq!(dim0, &[1.0, 2.0, 3.0]);
}
#[test]
fn test_arena_exact_capacity() {
let arena = Arena::new(1024);
let mut vec: ArenaVec<f32> = arena.alloc_vec(5);
// Fill to exactly capacity
for i in 0..5 {
vec.push(i as f32);
}
assert_eq!(vec.len(), 5);
assert_eq!(vec.capacity(), 5);
}
#[test]
fn test_soa_zeros() {
let mut storage = SoAVectorStorage::new(4, 4);
storage.push(&[0.0, 0.0, 0.0, 0.0]);
storage.push(&[0.0, 0.0, 0.0, 0.0]);
let query = vec![0.0; 4];
let mut distances = vec![0.0; 2];
storage.batch_euclidean_distances(&query, &mut distances);
assert!(distances[0] < 1e-6);
assert!(distances[1] < 1e-6);
}
#[test]
fn test_soa_negative_values() {
let mut storage = SoAVectorStorage::new(3, 4);
storage.push(&[-1.0, -2.0, -3.0]);
storage.push(&[-4.0, -5.0, -6.0]);
let mut output = vec![0.0; 3];
storage.get(0, &mut output);
assert_eq!(output, vec![-1.0, -2.0, -3.0]);
}
}
// ============================================================================
// Performance Characteristics Tests
// ============================================================================
mod performance_tests {
use super::*;
#[test]
fn test_arena_allocation_performance() {
// This test verifies that arena allocation is efficient
let arena = Arena::new(1024 * 1024); // 1MB
let start = std::time::Instant::now();
for _ in 0..100000 {
let _vec: ArenaVec<f32> = arena.alloc_vec(10);
}
let duration = start.elapsed();
// Should complete quickly (< 1 second for 100k allocations)
assert!(
duration.as_millis() < 1000,
"Arena allocation took too long: {:?}",
duration
);
}
#[test]
fn test_soa_dimension_access_pattern() {
let mut storage = SoAVectorStorage::new(128, 16);
for i in 0..1000 {
let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32).collect();
storage.push(&vec);
}
// Test dimension-wise access (this should be cache-efficient)
let start = std::time::Instant::now();
for dim in 0..128 {
let slice = storage.dimension_slice(dim);
let _sum: f32 = slice.iter().sum();
}
let duration = start.elapsed();
// Dimension-wise access should be fast due to cache locality
assert!(
duration.as_millis() < 100,
"Dimension access took too long: {:?}",
duration
);
}
#[test]
fn test_soa_batch_distance_performance() {
let mut storage = SoAVectorStorage::new(128, 128);
for i in 0..1000 {
let vec: Vec<f32> = (0..128).map(|j| (i * 128 + j) as f32 * 0.001).collect();
storage.push(&vec);
}
let query: Vec<f32> = (0..128).map(|j| j as f32 * 0.001).collect();
let mut distances = vec![0.0; 1000];
let start = std::time::Instant::now();
for _ in 0..100 {
storage.batch_euclidean_distances(&query, &mut distances);
}
let duration = start.elapsed();
// 100 batch operations on 1000 vectors should be fast
assert!(
duration.as_millis() < 500,
"Batch distance took too long: {:?}",
duration
);
}
}

View File

@@ -0,0 +1,767 @@
//! Quantization Accuracy Tests
//!
//! This module provides comprehensive tests for quantization techniques,
//! verifying accuracy, compression ratios, and distance calculations.
use ruvector_core::quantization::*;
// ============================================================================
// Scalar Quantization Tests
// ============================================================================
mod scalar_quantization_tests {
use super::*;
#[test]
fn test_scalar_quantization_basic() {
let vector = vec![0.0, 0.5, 1.0, 1.5, 2.0];
let quantized = ScalarQuantized::quantize(&vector);
assert_eq!(quantized.data.len(), 5);
assert!(quantized.scale > 0.0, "Scale should be positive");
}
#[test]
fn test_scalar_quantization_min_max() {
let vector = vec![-10.0, -5.0, 0.0, 5.0, 10.0];
let quantized = ScalarQuantized::quantize(&vector);
// Min should be -10.0
assert!((quantized.min - (-10.0)).abs() < 0.001);
// Scale should map range 20 to 255
let expected_scale = 20.0 / 255.0;
assert!(
(quantized.scale - expected_scale).abs() < 0.001,
"Scale mismatch: expected {}, got {}",
expected_scale,
quantized.scale
);
}
#[test]
fn test_scalar_quantization_reconstruction_accuracy() {
let test_vectors = vec![
vec![1.0, 2.0, 3.0, 4.0, 5.0],
vec![0.0, 0.25, 0.5, 0.75, 1.0],
vec![-100.0, 0.0, 100.0],
vec![0.001, 0.002, 0.003, 0.004, 0.005],
];
for vector in test_vectors {
let quantized = ScalarQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
assert_eq!(vector.len(), reconstructed.len());
// Calculate max error based on range
let min = vector.iter().copied().fold(f32::INFINITY, f32::min);
let max = vector.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let max_allowed_error = (max - min) / 128.0; // Allow 2 quantization steps error
for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
let error = (orig - recon).abs();
assert!(
error <= max_allowed_error,
"Reconstruction error {} exceeds max {} for value {}",
error,
max_allowed_error,
orig
);
}
}
}
#[test]
fn test_scalar_quantization_constant_values() {
let constant = vec![5.0, 5.0, 5.0, 5.0, 5.0];
let quantized = ScalarQuantized::quantize(&constant);
let reconstructed = quantized.reconstruct();
for (orig, recon) in constant.iter().zip(reconstructed.iter()) {
assert!(
(orig - recon).abs() < 0.1,
"Constant value reconstruction failed"
);
}
}
#[test]
fn test_scalar_quantization_distance_self() {
let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let quantized = ScalarQuantized::quantize(&vector);
let distance = quantized.distance(&quantized);
assert!(
distance < 0.001,
"Distance to self should be ~0, got {}",
distance
);
}
#[test]
fn test_scalar_quantization_distance_symmetry() {
let v1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let q1 = ScalarQuantized::quantize(&v1);
let q2 = ScalarQuantized::quantize(&v2);
let dist_ab = q1.distance(&q2);
let dist_ba = q2.distance(&q1);
assert!(
(dist_ab - dist_ba).abs() < 0.1,
"Distance not symmetric: {} vs {}",
dist_ab,
dist_ba
);
}
#[test]
fn test_scalar_quantization_distance_triangle_inequality() {
let v1 = vec![1.0, 0.0, 0.0, 0.0];
let v2 = vec![0.0, 1.0, 0.0, 0.0];
let v3 = vec![0.0, 0.0, 1.0, 0.0];
let q1 = ScalarQuantized::quantize(&v1);
let q2 = ScalarQuantized::quantize(&v2);
let q3 = ScalarQuantized::quantize(&v3);
let d12 = q1.distance(&q2);
let d23 = q2.distance(&q3);
let d13 = q1.distance(&q3);
// Triangle inequality: d(1,3) <= d(1,2) + d(2,3)
// Allow some slack for quantization errors
assert!(
d13 <= d12 + d23 + 0.5,
"Triangle inequality violated: {} > {} + {}",
d13,
d12,
d23
);
}
#[test]
fn test_scalar_quantization_common_embedding_sizes() {
for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
let vector: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
let quantized = ScalarQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
assert_eq!(quantized.data.len(), dim);
assert_eq!(reconstructed.len(), dim);
// Verify compression ratio (4x for f32 -> u8)
let original_size = dim * std::mem::size_of::<f32>();
let quantized_size = quantized.data.len() + std::mem::size_of::<f32>() * 2; // data + min + scale
assert!(
quantized_size < original_size,
"No compression achieved for dim {}",
dim
);
}
}
#[test]
fn test_scalar_quantization_extreme_values() {
// Test with large values
let large = vec![1e10, 2e10, 3e10];
let quantized = ScalarQuantized::quantize(&large);
let reconstructed = quantized.reconstruct();
for (orig, recon) in large.iter().zip(reconstructed.iter()) {
let relative_error = (orig - recon).abs() / orig.abs();
assert!(
relative_error < 0.02,
"Large value reconstruction error too high: {}",
relative_error
);
}
// Test with small values
let small = vec![1e-5, 2e-5, 3e-5, 4e-5, 5e-5];
let quantized = ScalarQuantized::quantize(&small);
let reconstructed = quantized.reconstruct();
for (orig, recon) in small.iter().zip(reconstructed.iter()) {
let error = (orig - recon).abs();
let range = 4e-5;
assert!(
error < range / 100.0,
"Small value reconstruction error too high: {}",
error
);
}
}
#[test]
fn test_scalar_quantization_negative_values() {
let negative = vec![-5.0, -4.0, -3.0, -2.0, -1.0];
let quantized = ScalarQuantized::quantize(&negative);
let reconstructed = quantized.reconstruct();
for (orig, recon) in negative.iter().zip(reconstructed.iter()) {
assert!(
(orig - recon).abs() < 0.1,
"Negative value reconstruction failed: {} vs {}",
orig,
recon
);
}
}
}
// ============================================================================
// Binary Quantization Tests
// ============================================================================
mod binary_quantization_tests {
use super::*;
#[test]
fn test_binary_quantization_basic() {
let vector = vec![1.0, -1.0, 0.5, -0.5, 0.1];
let quantized = BinaryQuantized::quantize(&vector);
assert_eq!(quantized.dimensions, 5);
assert_eq!(quantized.bits.len(), 1); // 5 bits fit in 1 byte
}
#[test]
fn test_binary_quantization_packing() {
// Test byte packing
for dim in 1..=32 {
let vector: Vec<f32> = (0..dim)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let quantized = BinaryQuantized::quantize(&vector);
let expected_bytes = (dim + 7) / 8;
assert_eq!(
quantized.bits.len(),
expected_bytes,
"Wrong byte count for dim {}",
dim
);
assert_eq!(quantized.dimensions, dim);
}
}
#[test]
fn test_binary_quantization_sign_preservation() {
let test_vectors = vec![
vec![1.0, -1.0, 2.0, -2.0],
vec![0.001, -0.001, 100.0, -100.0],
vec![f32::MAX / 2.0, f32::MIN / 2.0],
];
for vector in test_vectors {
let quantized = BinaryQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
if *orig > 0.0 {
assert_eq!(*recon, 1.0, "Positive value should reconstruct to 1.0");
} else if *orig < 0.0 {
assert_eq!(*recon, -1.0, "Negative value should reconstruct to -1.0");
}
}
}
}
#[test]
fn test_binary_quantization_zero_handling() {
let vector = vec![0.0, 0.0, 0.0, 0.0];
let quantized = BinaryQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
// Zero maps to negative bit (0), which reconstructs to -1.0
for val in reconstructed {
assert_eq!(val, -1.0);
}
}
#[test]
fn test_binary_quantization_hamming_distance() {
// Test specific Hamming distance cases
let cases = vec![
// (v1, v2, expected_distance)
(vec![1.0, 1.0, 1.0, 1.0], vec![1.0, 1.0, 1.0, 1.0], 0.0), // identical
(vec![1.0, 1.0, 1.0, 1.0], vec![-1.0, -1.0, -1.0, -1.0], 4.0), // opposite
(vec![1.0, 1.0, -1.0, -1.0], vec![1.0, -1.0, -1.0, 1.0], 2.0), // 2 bits differ
(vec![1.0, -1.0, 1.0, -1.0], vec![-1.0, 1.0, -1.0, 1.0], 4.0), // all differ
];
for (v1, v2, expected) in cases {
let q1 = BinaryQuantized::quantize(&v1);
let q2 = BinaryQuantized::quantize(&v2);
let distance = q1.distance(&q2);
assert!(
(distance - expected).abs() < 0.001,
"Hamming distance mismatch: expected {}, got {}",
expected,
distance
);
}
}
#[test]
fn test_binary_quantization_distance_symmetry() {
let v1 = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
let v2 = vec![-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0];
let q1 = BinaryQuantized::quantize(&v1);
let q2 = BinaryQuantized::quantize(&v2);
let d12 = q1.distance(&q2);
let d21 = q2.distance(&q1);
assert_eq!(d12, d21, "Binary distance should be symmetric");
}
#[test]
fn test_binary_quantization_distance_bounds() {
for dim in [8, 16, 32, 64, 128, 256] {
let v1: Vec<f32> = (0..dim)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let v2: Vec<f32> = (0..dim)
.map(|i| if i % 3 == 0 { 1.0 } else { -1.0 })
.collect();
let q1 = BinaryQuantized::quantize(&v1);
let q2 = BinaryQuantized::quantize(&v2);
let distance = q1.distance(&q2);
// Distance should be in [0, dim]
assert!(
distance >= 0.0 && distance <= dim as f32,
"Distance {} out of bounds [0, {}]",
distance,
dim
);
}
}
#[test]
fn test_binary_quantization_compression_ratio() {
for dim in [128, 256, 512, 1024] {
let vector: Vec<f32> = (0..dim)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let quantized = BinaryQuantized::quantize(&vector);
// f32 to 1 bit = theoretical 32x compression for data only
// Actual ratio depends on overhead but should be significant
let original_data_size = dim * std::mem::size_of::<f32>();
let quantized_data_size = quantized.bits.len();
let data_compression_ratio = original_data_size as f32 / quantized_data_size as f32;
assert!(
data_compression_ratio >= 31.0,
"Data compression ratio {} less than expected ~32x for dim {}",
data_compression_ratio,
dim
);
// Verify bits.len() is correct: ceil(dim / 8)
assert_eq!(quantized.bits.len(), (dim + 7) / 8);
}
}
#[test]
fn test_binary_quantization_common_embedding_sizes() {
for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
let vector: Vec<f32> = (0..dim).map(|i| (i as f32 - dim as f32 / 2.0)).collect();
let quantized = BinaryQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
assert_eq!(reconstructed.len(), dim);
// Check all values are +1 or -1
for val in &reconstructed {
assert!(*val == 1.0 || *val == -1.0);
}
}
}
}
// ============================================================================
// Product Quantization Tests
// ============================================================================
mod product_quantization_tests {
use super::*;
#[test]
fn test_product_quantization_training() {
let vectors: Vec<Vec<f32>> = (0..100)
.map(|i| (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect())
.collect();
let num_subspaces = 4;
let codebook_size = 16;
let pq = ProductQuantized::train(&vectors, num_subspaces, codebook_size, 10).unwrap();
assert_eq!(pq.codebooks.len(), num_subspaces);
for codebook in &pq.codebooks {
assert_eq!(codebook.len(), codebook_size);
}
}
#[test]
fn test_product_quantization_encode() {
let vectors: Vec<Vec<f32>> = (0..100)
.map(|i| (0..32).map(|j| (i * 32 + j) as f32 * 0.01).collect())
.collect();
let num_subspaces = 4;
let codebook_size = 16;
let pq = ProductQuantized::train(&vectors, num_subspaces, codebook_size, 10).unwrap();
let test_vector: Vec<f32> = (0..32).map(|i| i as f32 * 0.02).collect();
let codes = pq.encode(&test_vector);
assert_eq!(codes.len(), num_subspaces);
for code in &codes {
assert!(*code < codebook_size as u8);
}
}
#[test]
fn test_product_quantization_empty_input_error() {
let result = ProductQuantized::train(&[], 4, 16, 10);
assert!(result.is_err());
}
#[test]
fn test_product_quantization_codebook_size_limit() {
let vectors: Vec<Vec<f32>> = (0..10)
.map(|i| (0..16).map(|j| (i * 16 + j) as f32).collect())
.collect();
// Codebook size > 256 should error
let result = ProductQuantized::train(&vectors, 4, 300, 10);
assert!(result.is_err());
}
#[test]
fn test_product_quantization_various_subspaces() {
let dim = 64;
let vectors: Vec<Vec<f32>> = (0..200)
.map(|i| (0..dim).map(|j| (i * dim + j) as f32 * 0.001).collect())
.collect();
for num_subspaces in [1, 2, 4, 8, 16] {
let pq = ProductQuantized::train(&vectors, num_subspaces, 16, 5).unwrap();
assert_eq!(pq.codebooks.len(), num_subspaces);
let subspace_dim = dim / num_subspaces;
for codebook in &pq.codebooks {
for centroid in codebook {
assert_eq!(centroid.len(), subspace_dim);
}
}
}
}
}
// ============================================================================
// Comparative Tests
// ============================================================================
mod comparative_tests {
use super::*;
#[test]
fn test_scalar_vs_binary_reconstruction() {
let vector = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0];
let scalar = ScalarQuantized::quantize(&vector);
let binary = BinaryQuantized::quantize(&vector);
let scalar_recon = scalar.reconstruct();
let binary_recon = binary.reconstruct();
// Scalar should have better accuracy
let scalar_error: f32 = vector
.iter()
.zip(scalar_recon.iter())
.map(|(o, r)| (o - r).abs())
.sum::<f32>()
/ vector.len() as f32;
// Binary only preserves sign
for (orig, recon) in vector.iter().zip(binary_recon.iter()) {
assert_eq!(orig.signum(), recon.signum());
}
// Scalar error should be small
assert!(
scalar_error < 0.5,
"Scalar reconstruction error {} too high",
scalar_error
);
}
#[test]
fn test_quantization_preserves_relative_ordering() {
// Test that vectors closest in original space are also closest in quantized space
let v1 = vec![1.0, 0.0, 0.0, 0.0];
let v2 = vec![0.9, 0.1, 0.0, 0.0]; // close to v1
let v3 = vec![0.0, 0.0, 0.0, 1.0]; // far from v1
// For scalar quantization
let q1_s = ScalarQuantized::quantize(&v1);
let q2_s = ScalarQuantized::quantize(&v2);
let q3_s = ScalarQuantized::quantize(&v3);
let d12_s = q1_s.distance(&q2_s);
let d13_s = q1_s.distance(&q3_s);
// v2 should be closer to v1 than v3
assert!(
d12_s < d13_s,
"Scalar: v2 should be closer to v1 than v3: {} vs {}",
d12_s,
d13_s
);
// For binary quantization
let q1_b = BinaryQuantized::quantize(&v1);
let q2_b = BinaryQuantized::quantize(&v2);
let q3_b = BinaryQuantized::quantize(&v3);
let d12_b = q1_b.distance(&q2_b);
let d13_b = q1_b.distance(&q3_b);
// Same relative ordering should hold
assert!(
d12_b <= d13_b,
"Binary: v2 should be at most as far as v3: {} vs {}",
d12_b,
d13_b
);
}
#[test]
fn test_compression_ratios() {
let dim = 512;
let vector: Vec<f32> = (0..dim).map(|i| i as f32 * 0.01).collect();
// Original size
let original_size = dim * std::mem::size_of::<f32>(); // 2048 bytes
// Scalar quantization: u8 per element + 2 floats for min/scale
let scalar = ScalarQuantized::quantize(&vector);
let scalar_size = scalar.data.len() + 2 * std::mem::size_of::<f32>(); // ~520 bytes
let scalar_ratio = original_size as f32 / scalar_size as f32;
// Binary quantization: 1 bit per element + usize for dimensions
let binary = BinaryQuantized::quantize(&vector);
let binary_size = binary.bits.len() + std::mem::size_of::<usize>(); // ~72 bytes
let binary_ratio = original_size as f32 / binary_size as f32;
println!("Original: {} bytes", original_size);
println!(
"Scalar: {} bytes ({:.1}x compression)",
scalar_size, scalar_ratio
);
println!(
"Binary: {} bytes ({:.1}x compression)",
binary_size, binary_ratio
);
// Verify expected ratios
assert!(scalar_ratio > 3.5, "Scalar should achieve ~4x compression");
assert!(
binary_ratio > 25.0,
"Binary should achieve ~32x compression"
);
}
}
// ============================================================================
// Edge Cases and Error Handling
// ============================================================================
mod edge_cases {
use super::*;
#[test]
fn test_single_element_vector() {
let vector = vec![42.0];
let scalar = ScalarQuantized::quantize(&vector);
let binary = BinaryQuantized::quantize(&vector);
assert_eq!(scalar.data.len(), 1);
assert_eq!(binary.bits.len(), 1);
assert_eq!(binary.dimensions, 1);
}
#[test]
fn test_large_vector() {
let dim = 8192;
let vector: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
let scalar = ScalarQuantized::quantize(&vector);
let binary = BinaryQuantized::quantize(&vector);
assert_eq!(scalar.data.len(), dim);
assert_eq!(binary.dimensions, dim);
}
#[test]
fn test_all_positive() {
let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let binary = BinaryQuantized::quantize(&vector);
let reconstructed = binary.reconstruct();
// All values should reconstruct to 1.0
for val in reconstructed {
assert_eq!(val, 1.0);
}
}
#[test]
fn test_all_negative() {
let vector = vec![-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0];
let binary = BinaryQuantized::quantize(&vector);
let reconstructed = binary.reconstruct();
// All values should reconstruct to -1.0
for val in reconstructed {
assert_eq!(val, -1.0);
}
}
#[test]
fn test_alternating_pattern() {
let vector: Vec<f32> = (0..100)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let binary = BinaryQuantized::quantize(&vector);
let reconstructed = binary.reconstruct();
for (i, val) in reconstructed.iter().enumerate() {
let expected = if i % 2 == 0 { 1.0 } else { -1.0 };
assert_eq!(*val, expected);
}
}
#[test]
fn test_quantization_deterministic() {
let vector = vec![1.0, 2.0, 3.0, 4.0, 5.0];
// Quantize multiple times - should get same result
let q1 = ScalarQuantized::quantize(&vector);
let q2 = ScalarQuantized::quantize(&vector);
assert_eq!(q1.data, q2.data);
assert_eq!(q1.min, q2.min);
assert_eq!(q1.scale, q2.scale);
}
}
// ============================================================================
// Performance Characteristic Tests
// ============================================================================
mod performance_tests {
use super::*;
#[test]
fn test_scalar_quantization_speed() {
let vector: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();
let start = std::time::Instant::now();
for _ in 0..10000 {
let _ = ScalarQuantized::quantize(&vector);
}
let duration = start.elapsed();
let ops_per_sec = 10000.0 / duration.as_secs_f64();
println!(
"Scalar quantization: {:.0} ops/sec for 1024-dim vectors",
ops_per_sec
);
// Should be fast
assert!(
duration.as_millis() < 5000,
"Scalar quantization too slow: {:?}",
duration
);
}
#[test]
fn test_binary_quantization_speed() {
let vector: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();
let start = std::time::Instant::now();
for _ in 0..10000 {
let _ = BinaryQuantized::quantize(&vector);
}
let duration = start.elapsed();
let ops_per_sec = 10000.0 / duration.as_secs_f64();
println!(
"Binary quantization: {:.0} ops/sec for 1024-dim vectors",
ops_per_sec
);
// Should be fast
assert!(
duration.as_millis() < 5000,
"Binary quantization too slow: {:?}",
duration
);
}
#[test]
fn test_distance_calculation_speed() {
let v1: Vec<f32> = (0..512).map(|i| i as f32 * 0.01).collect();
let v2: Vec<f32> = (0..512).map(|i| (i as f32 * 0.01) + 0.5).collect();
let q1_s = ScalarQuantized::quantize(&v1);
let q2_s = ScalarQuantized::quantize(&v2);
let q1_b = BinaryQuantized::quantize(&v1);
let q2_b = BinaryQuantized::quantize(&v2);
// Scalar distance
let start = std::time::Instant::now();
for _ in 0..100000 {
let _ = q1_s.distance(&q2_s);
}
let scalar_duration = start.elapsed();
// Binary distance (Hamming)
let start = std::time::Instant::now();
for _ in 0..100000 {
let _ = q1_b.distance(&q2_b);
}
let binary_duration = start.elapsed();
println!("Scalar distance: {:?} for 100k ops", scalar_duration);
println!("Binary distance: {:?} for 100k ops", binary_duration);
// Binary should be faster (just XOR and popcount)
// But both should be fast
assert!(scalar_duration.as_millis() < 1000);
assert!(binary_duration.as_millis() < 1000);
}
}

View File

@@ -0,0 +1,552 @@
//! SIMD Correctness Tests
//!
//! This module verifies that SIMD implementations produce identical results
//! to scalar fallback implementations across various input sizes and edge cases.
use ruvector_core::simd_intrinsics::*;
// ============================================================================
// Helper Functions for Scalar Computations (Ground Truth)
// ============================================================================
fn scalar_euclidean(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let diff = x - y;
diff * diff
})
.sum::<f32>()
.sqrt()
}
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
fn scalar_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a > f32::EPSILON && norm_b > f32::EPSILON {
dot / (norm_a * norm_b)
} else {
0.0
}
}
fn scalar_manhattan(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
}
// ============================================================================
// Euclidean Distance Tests
// ============================================================================
#[test]
fn test_euclidean_simd_vs_scalar_small() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-5,
"Euclidean mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_euclidean_simd_vs_scalar_exact_simd_width() {
// Test with exact AVX2 width (8 floats)
let a: Vec<f32> = (0..8).map(|i| i as f32).collect();
let b: Vec<f32> = (0..8).map(|i| (i + 1) as f32).collect();
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-5,
"8-element Euclidean mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_euclidean_simd_vs_scalar_non_aligned() {
// Test with non-SIMD-aligned sizes
for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65, 127, 129] {
let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 0.01,
"Size {} Euclidean mismatch: SIMD={}, scalar={}",
size,
simd_result,
scalar_result
);
}
}
#[test]
fn test_euclidean_simd_vs_scalar_common_embedding_sizes() {
// Test common embedding dimensions
for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
let a: Vec<f32> = (0..dim).map(|i| ((i % 100) as f32) * 0.01).collect();
let b: Vec<f32> = (0..dim).map(|i| (((i + 50) % 100) as f32) * 0.01).collect();
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 0.1,
"Dim {} Euclidean mismatch: SIMD={}, scalar={}",
dim,
simd_result,
scalar_result
);
}
}
#[test]
fn test_euclidean_simd_identical_vectors() {
let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let result = euclidean_distance_simd(&v, &v);
assert!(
result < 1e-6,
"Distance to self should be ~0, got {}",
result
);
}
#[test]
fn test_euclidean_simd_zero_vectors() {
let zeros = vec![0.0; 16];
let result = euclidean_distance_simd(&zeros, &zeros);
assert!(result < 1e-6, "Distance between zeros should be 0");
}
#[test]
fn test_euclidean_simd_negative_values() {
let a = vec![-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0];
let b = vec![-5.0, -6.0, -7.0, -8.0, -9.0, -10.0, -11.0, -12.0];
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-5,
"Negative values Euclidean mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_euclidean_simd_mixed_signs() {
let a = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
let b = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0];
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-4,
"Mixed signs Euclidean mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
// ============================================================================
// Dot Product Tests
// ============================================================================
#[test]
fn test_dot_product_simd_vs_scalar_small() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let simd_result = dot_product_simd(&a, &b);
let scalar_result = scalar_dot_product(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-4,
"Dot product mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_dot_product_simd_vs_scalar_exact_simd_width() {
let a: Vec<f32> = (1..=8).map(|i| i as f32).collect();
let b: Vec<f32> = (1..=8).map(|i| i as f32).collect();
let simd_result = dot_product_simd(&a, &b);
let scalar_result = scalar_dot_product(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-4,
"8-element dot product mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_dot_product_simd_vs_scalar_non_aligned() {
for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65, 127, 129] {
let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
let simd_result = dot_product_simd(&a, &b);
let scalar_result = scalar_dot_product(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 0.1,
"Size {} dot product mismatch: SIMD={}, scalar={}",
size,
simd_result,
scalar_result
);
}
}
#[test]
fn test_dot_product_simd_common_embedding_sizes() {
for dim in [128, 256, 384, 512, 768, 1024, 1536, 2048] {
let a: Vec<f32> = (0..dim).map(|i| ((i % 10) as f32) * 0.1).collect();
let b: Vec<f32> = (0..dim).map(|i| (((i + 5) % 10) as f32) * 0.1).collect();
let simd_result = dot_product_simd(&a, &b);
let scalar_result = scalar_dot_product(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 0.5,
"Dim {} dot product mismatch: SIMD={}, scalar={}",
dim,
simd_result,
scalar_result
);
}
}
#[test]
fn test_dot_product_simd_orthogonal_vectors() {
let a = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
let result = dot_product_simd(&a, &b);
assert!(result.abs() < 1e-6, "Orthogonal dot product should be 0");
}
// ============================================================================
// Cosine Similarity Tests
// ============================================================================
#[test]
fn test_cosine_simd_vs_scalar_small() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let simd_result = cosine_similarity_simd(&a, &b);
let scalar_result = scalar_cosine_similarity(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-4,
"Cosine mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_cosine_simd_vs_scalar_non_aligned() {
for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65] {
let a: Vec<f32> = (1..=size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (1..=size).map(|i| (i as f32) * 0.2).collect();
let simd_result = cosine_similarity_simd(&a, &b);
let scalar_result = scalar_cosine_similarity(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 0.01,
"Size {} cosine mismatch: SIMD={}, scalar={}",
size,
simd_result,
scalar_result
);
}
}
#[test]
fn test_cosine_simd_identical_vectors() {
let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let result = cosine_similarity_simd(&v, &v);
assert!(
(result - 1.0).abs() < 1e-5,
"Identical vectors should have similarity 1.0, got {}",
result
);
}
#[test]
fn test_cosine_simd_opposite_vectors() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![-1.0, -2.0, -3.0, -4.0];
let result = cosine_similarity_simd(&a, &b);
assert!(
(result + 1.0).abs() < 1e-5,
"Opposite vectors should have similarity -1.0, got {}",
result
);
}
#[test]
fn test_cosine_simd_orthogonal_vectors() {
let a = vec![1.0, 0.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0, 0.0];
let result = cosine_similarity_simd(&a, &b);
assert!(
result.abs() < 1e-5,
"Orthogonal vectors should have similarity 0, got {}",
result
);
}
// ============================================================================
// Manhattan Distance Tests
// ============================================================================
#[test]
fn test_manhattan_simd_vs_scalar_small() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let simd_result = manhattan_distance_simd(&a, &b);
let scalar_result = scalar_manhattan(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-4,
"Manhattan mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_manhattan_simd_vs_scalar_non_aligned() {
for size in [3, 5, 7, 9, 11, 13, 15, 17, 31, 33, 63, 65] {
let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
let simd_result = manhattan_distance_simd(&a, &b);
let scalar_result = scalar_manhattan(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 0.01,
"Size {} Manhattan mismatch: SIMD={}, scalar={}",
size,
simd_result,
scalar_result
);
}
}
#[test]
fn test_manhattan_simd_identical_vectors() {
let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let result = manhattan_distance_simd(&v, &v);
assert!(
result < 1e-6,
"Manhattan to self should be 0, got {}",
result
);
}
// ============================================================================
// Numerical Stability Tests
// ============================================================================
#[test]
fn test_simd_large_values() {
// Test with large but finite values
let large_val = 1e10;
let a: Vec<f32> = (0..16).map(|i| large_val + (i as f32)).collect();
let b: Vec<f32> = (0..16).map(|i| large_val + (i as f32) + 1.0).collect();
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
simd_result.is_finite() && scalar_result.is_finite(),
"Results should be finite for large values"
);
assert!(
(simd_result - scalar_result).abs() < 0.1,
"Large values mismatch: SIMD={}, scalar={}",
simd_result,
scalar_result
);
}
#[test]
fn test_simd_small_values() {
// Test with small values
let small_val = 1e-10;
let a: Vec<f32> = (0..16).map(|i| small_val * (i as f32 + 1.0)).collect();
let b: Vec<f32> = (0..16).map(|i| small_val * (i as f32 + 2.0)).collect();
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
simd_result.is_finite() && scalar_result.is_finite(),
"Results should be finite for small values"
);
}
#[test]
fn test_simd_denormalized_values() {
// Test with denormalized floats
let a = vec![f32::MIN_POSITIVE; 8];
let b = vec![f32::MIN_POSITIVE * 2.0; 8];
let simd_result = euclidean_distance_simd(&a, &b);
let scalar_result = scalar_euclidean(&a, &b);
assert!(
simd_result.is_finite() && scalar_result.is_finite(),
"Results should be finite for denormalized values"
);
}
// ============================================================================
// Legacy Alias Tests
// ============================================================================
#[test]
fn test_legacy_avx2_aliases_match_simd() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0];
// Legacy AVX2 functions should produce same results as SIMD functions
assert_eq!(
euclidean_distance_avx2(&a, &b),
euclidean_distance_simd(&a, &b)
);
assert_eq!(dot_product_avx2(&a, &b), dot_product_simd(&a, &b));
assert_eq!(
cosine_similarity_avx2(&a, &b),
cosine_similarity_simd(&a, &b)
);
}
// ============================================================================
// Batch Operation Tests
// ============================================================================
#[test]
fn test_simd_batch_consistency() {
let query: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
let vectors: Vec<Vec<f32>> = (0..100)
.map(|j| (0..64).map(|i| ((i + j) as f32) * 0.1).collect())
.collect();
// Compute distances using SIMD
let simd_distances: Vec<f32> = vectors
.iter()
.map(|v| euclidean_distance_simd(&query, v))
.collect();
// Compute distances using scalar
let scalar_distances: Vec<f32> = vectors
.iter()
.map(|v| scalar_euclidean(&query, v))
.collect();
// Compare
for (i, (simd, scalar)) in simd_distances
.iter()
.zip(scalar_distances.iter())
.enumerate()
{
assert!(
(simd - scalar).abs() < 0.01,
"Vector {} mismatch: SIMD={}, scalar={}",
i,
simd,
scalar
);
}
}
// ============================================================================
// Edge Case Tests
// ============================================================================
#[test]
fn test_simd_single_element() {
let a = vec![1.0];
let b = vec![2.0];
let euclidean = euclidean_distance_simd(&a, &b);
let dot = dot_product_simd(&a, &b);
let manhattan = manhattan_distance_simd(&a, &b);
assert!((euclidean - 1.0).abs() < 1e-6);
assert!((dot - 2.0).abs() < 1e-6);
assert!((manhattan - 1.0).abs() < 1e-6);
}
#[test]
fn test_simd_two_elements() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let euclidean = euclidean_distance_simd(&a, &b);
let expected = (2.0_f32).sqrt(); // sqrt(1 + 1)
assert!(
(euclidean - expected).abs() < 1e-5,
"Two element test: got {}, expected {}",
euclidean,
expected
);
}
// ============================================================================
// Stress Tests for SIMD
// ============================================================================
#[test]
fn test_simd_many_operations() {
let a: Vec<f32> = (0..512).map(|i| (i as f32) * 0.001).collect();
let b: Vec<f32> = (0..512).map(|i| ((i + 256) as f32) * 0.001).collect();
// Perform many operations to stress test
for _ in 0..1000 {
let _ = euclidean_distance_simd(&a, &b);
let _ = dot_product_simd(&a, &b);
let _ = cosine_similarity_simd(&a, &b);
let _ = manhattan_distance_simd(&a, &b);
}
// Final verification
let result = euclidean_distance_simd(&a, &b);
assert!(
result.is_finite(),
"Result should be finite after stress test"
);
}

View File

@@ -0,0 +1,555 @@
//! Unit tests with mocking using mockall (London School TDD)
//!
//! These tests use mocks to isolate components and test behavior in isolation.
use mockall::mock;
use mockall::predicate::*;
use ruvector_core::error::{Result, RuvectorError};
use ruvector_core::types::*;
use std::collections::HashMap;
// ============================================================================
// Mock Definitions
// ============================================================================
// Mock for storage operations
mock! {
pub Storage {
fn insert(&self, entry: &VectorEntry) -> Result<VectorId>;
fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>>;
fn get(&self, id: &str) -> Result<Option<VectorEntry>>;
fn delete(&self, id: &str) -> Result<bool>;
fn len(&self) -> Result<usize>;
fn is_empty(&self) -> Result<bool>;
fn all_ids(&self) -> Result<Vec<VectorId>>;
}
}
// Mock for index operations
mock! {
pub Index {
fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()>;
fn add_batch(&mut self, entries: Vec<(VectorId, Vec<f32>)>) -> Result<()>;
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
fn remove(&mut self, id: &VectorId) -> Result<bool>;
fn len(&self) -> usize;
fn is_empty(&self) -> bool;
}
}
// ============================================================================
// Distance Metric Tests
// ============================================================================
#[cfg(test)]
mod distance_tests {
use super::*;
use ruvector_core::distance::*;
#[test]
fn test_euclidean_same_vector() {
let v = vec![1.0, 2.0, 3.0];
let dist = euclidean_distance(&v, &v);
assert!(dist < 0.001, "Distance to self should be ~0");
}
#[test]
fn test_euclidean_orthogonal() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0];
let dist = euclidean_distance(&a, &b);
assert!(
(dist - 1.414).abs() < 0.01,
"Expected sqrt(2), got {}",
dist
);
}
#[test]
fn test_cosine_parallel_vectors() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![2.0, 4.0, 6.0]; // Parallel to a
let dist = cosine_distance(&a, &b);
assert!(
dist < 0.01,
"Parallel vectors should have ~0 cosine distance, got {}",
dist
);
}
#[test]
fn test_cosine_orthogonal() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0];
let dist = cosine_distance(&a, &b);
assert!(
dist > 0.9 && dist < 1.1,
"Orthogonal vectors should have distance ~1, got {}",
dist
);
}
#[test]
fn test_dot_product_positive() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.0, 2.0, 3.0];
let dist = dot_product_distance(&a, &b);
assert!(
dist < 0.0,
"Dot product distance should be negative for similar vectors"
);
}
#[test]
fn test_manhattan_distance() {
let a = vec![0.0, 0.0, 0.0];
let b = vec![1.0, 1.0, 1.0];
let dist = manhattan_distance(&a, &b);
assert!(
(dist - 3.0).abs() < 0.001,
"Manhattan distance should be 3.0, got {}",
dist
);
}
#[test]
fn test_dimension_mismatch() {
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0];
let result = distance(&a, &b, DistanceMetric::Euclidean);
assert!(result.is_err(), "Should error on dimension mismatch");
match result {
Err(RuvectorError::DimensionMismatch { expected, actual }) => {
assert_eq!(expected, 2);
assert_eq!(actual, 3);
}
_ => panic!("Expected DimensionMismatch error"),
}
}
}
// ============================================================================
// Quantization Tests
// ============================================================================
#[cfg(test)]
mod quantization_tests {
use super::*;
use ruvector_core::quantization::*;
#[test]
fn test_scalar_quantization_reconstruction() {
let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let quantized = ScalarQuantized::quantize(&original);
let reconstructed = quantized.reconstruct();
assert_eq!(original.len(), reconstructed.len());
for (orig, recon) in original.iter().zip(reconstructed.iter()) {
let error = (orig - recon).abs();
assert!(error < 0.1, "Reconstruction error {} too large", error);
}
}
#[test]
fn test_scalar_quantization_distance() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![1.1, 2.1, 3.1];
let qa = ScalarQuantized::quantize(&a);
let qb = ScalarQuantized::quantize(&b);
let quantized_dist = qa.distance(&qb);
assert!(quantized_dist >= 0.0, "Distance should be non-negative");
}
#[test]
fn test_binary_quantization_sign_preservation() {
let vector = vec![1.0, -1.0, 2.0, -2.0, 0.5, -0.5];
let quantized = BinaryQuantized::quantize(&vector);
let reconstructed = quantized.reconstruct();
for (orig, recon) in vector.iter().zip(reconstructed.iter()) {
assert_eq!(orig.signum(), *recon, "Sign should be preserved");
}
}
#[test]
fn test_binary_quantization_hamming() {
let a = vec![1.0, 1.0, 1.0, 1.0];
let b = vec![1.0, 1.0, -1.0, -1.0];
let qa = BinaryQuantized::quantize(&a);
let qb = BinaryQuantized::quantize(&b);
let dist = qa.distance(&qb);
assert_eq!(dist, 2.0, "Hamming distance should be 2.0");
}
#[test]
fn test_binary_quantization_zero_distance() {
let vector = vec![1.0, 2.0, 3.0, 4.0];
let quantized = BinaryQuantized::quantize(&vector);
let dist = quantized.distance(&quantized);
assert_eq!(dist, 0.0, "Distance to self should be 0");
}
}
// ============================================================================
// Storage Layer Tests
// ============================================================================
#[cfg(test)]
mod storage_tests {
use super::*;
use ruvector_core::storage::VectorStorage;
use tempfile::tempdir;
#[test]
fn test_insert_with_explicit_id() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let entry = VectorEntry {
id: Some("explicit_id".to_string()),
vector: vec![1.0, 2.0, 3.0],
metadata: None,
};
let id = storage.insert(&entry)?;
assert_eq!(id, "explicit_id");
Ok(())
}
#[test]
fn test_insert_auto_generates_id() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let entry = VectorEntry {
id: None,
vector: vec![1.0, 2.0, 3.0],
metadata: None,
};
let id = storage.insert(&entry)?;
assert!(!id.is_empty(), "Should generate a UUID");
assert!(id.contains('-'), "Should be a valid UUID format");
Ok(())
}
#[test]
fn test_insert_with_metadata() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let mut metadata = HashMap::new();
metadata.insert("key".to_string(), serde_json::json!("value"));
let entry = VectorEntry {
id: Some("meta_test".to_string()),
vector: vec![1.0, 2.0, 3.0],
metadata: Some(metadata.clone()),
};
storage.insert(&entry)?;
let retrieved = storage.get("meta_test")?.unwrap();
assert_eq!(retrieved.metadata, Some(metadata));
Ok(())
}
#[test]
fn test_dimension_mismatch_error() {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3).unwrap();
let entry = VectorEntry {
id: None,
vector: vec![1.0, 2.0], // Wrong dimension
metadata: None,
};
let result = storage.insert(&entry);
assert!(result.is_err());
match result {
Err(RuvectorError::DimensionMismatch { expected, actual }) => {
assert_eq!(expected, 3);
assert_eq!(actual, 2);
}
_ => panic!("Expected DimensionMismatch error"),
}
}
#[test]
fn test_get_nonexistent() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let result = storage.get("nonexistent")?;
assert!(result.is_none());
Ok(())
}
#[test]
fn test_delete_nonexistent() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let deleted = storage.delete("nonexistent")?;
assert!(!deleted);
Ok(())
}
#[test]
fn test_batch_insert_empty() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let ids = storage.insert_batch(&[])?;
assert_eq!(ids.len(), 0);
Ok(())
}
#[test]
fn test_batch_insert_dimension_mismatch() {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3).unwrap();
let entries = vec![
VectorEntry {
id: None,
vector: vec![1.0, 2.0, 3.0],
metadata: None,
},
VectorEntry {
id: None,
vector: vec![1.0, 2.0], // Wrong dimension
metadata: None,
},
];
let result = storage.insert_batch(&entries);
assert!(
result.is_err(),
"Should error on dimension mismatch in batch"
);
}
#[test]
fn test_all_ids_empty() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
let ids = storage.all_ids()?;
assert_eq!(ids.len(), 0);
Ok(())
}
#[test]
fn test_all_ids_after_insert() -> Result<()> {
let dir = tempdir().unwrap();
let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;
storage.insert(&VectorEntry {
id: Some("id1".to_string()),
vector: vec![1.0, 2.0, 3.0],
metadata: None,
})?;
storage.insert(&VectorEntry {
id: Some("id2".to_string()),
vector: vec![4.0, 5.0, 6.0],
metadata: None,
})?;
let ids = storage.all_ids()?;
assert_eq!(ids.len(), 2);
assert!(ids.contains(&"id1".to_string()));
assert!(ids.contains(&"id2".to_string()));
Ok(())
}
}
// ============================================================================
// VectorDB Tests (High-level API)
// ============================================================================
#[cfg(test)]
mod vector_db_tests {
use super::*;
use ruvector_core::VectorDB;
use tempfile::tempdir;
#[test]
fn test_empty_database() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
let db = VectorDB::new(options)?;
assert!(db.is_empty()?);
assert_eq!(db.len()?, 0);
Ok(())
}
#[test]
fn test_insert_updates_len() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
let db = VectorDB::new(options)?;
db.insert(VectorEntry {
id: None,
vector: vec![1.0, 2.0, 3.0],
metadata: None,
})?;
assert_eq!(db.len()?, 1);
assert!(!db.is_empty()?);
Ok(())
}
#[test]
fn test_delete_updates_len() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
let db = VectorDB::new(options)?;
let id = db.insert(VectorEntry {
id: Some("test_id".to_string()),
vector: vec![1.0, 2.0, 3.0],
metadata: None,
})?;
assert_eq!(db.len()?, 1);
let deleted = db.delete(&id)?;
assert!(deleted);
assert_eq!(db.len()?, 0);
Ok(())
}
#[test]
fn test_search_empty_database() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
options.hnsw_config = None; // Use flat index
let db = VectorDB::new(options)?;
let results = db.search(SearchQuery {
vector: vec![1.0, 2.0, 3.0],
k: 10,
filter: None,
ef_search: None,
})?;
assert_eq!(results.len(), 0);
Ok(())
}
#[test]
fn test_search_with_filter() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
options.hnsw_config = None;
let db = VectorDB::new(options)?;
// Insert vectors with metadata
let mut meta1 = HashMap::new();
meta1.insert("category".to_string(), serde_json::json!("A"));
let mut meta2 = HashMap::new();
meta2.insert("category".to_string(), serde_json::json!("B"));
db.insert(VectorEntry {
id: Some("v1".to_string()),
vector: vec![1.0, 0.0, 0.0],
metadata: Some(meta1),
})?;
db.insert(VectorEntry {
id: Some("v2".to_string()),
vector: vec![0.9, 0.1, 0.0],
metadata: Some(meta2),
})?;
// Search with filter
let mut filter = HashMap::new();
filter.insert("category".to_string(), serde_json::json!("A"));
let results = db.search(SearchQuery {
vector: vec![1.0, 0.0, 0.0],
k: 10,
filter: Some(filter),
ef_search: None,
})?;
assert_eq!(results.len(), 1);
assert_eq!(results[0].id, "v1");
Ok(())
}
#[test]
fn test_batch_insert() -> Result<()> {
let dir = tempdir().unwrap();
let mut options = DbOptions::default();
options.storage_path = dir.path().join("test.db").to_string_lossy().to_string();
options.dimensions = 3;
let db = VectorDB::new(options)?;
let entries = vec![
VectorEntry {
id: None,
vector: vec![1.0, 0.0, 0.0],
metadata: None,
},
VectorEntry {
id: None,
vector: vec![0.0, 1.0, 0.0],
metadata: None,
},
VectorEntry {
id: None,
vector: vec![0.0, 0.0, 1.0],
metadata: None,
},
];
let ids = db.insert_batch(entries)?;
assert_eq!(ids.len(), 3);
assert_eq!(db.len()?, 3);
Ok(())
}
}